|
19 | 19 |
|
20 | 20 | import tf2onnx
|
21 | 21 | from tf2onnx import constants
|
| 22 | +from tf2onnx.graph import GraphUtil |
22 | 23 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
|
23 | 24 | from tf2onnx.tfonnx import process_tf_graph
|
24 | 25 | from common import get_test_config, unittest_main
|
@@ -359,6 +360,28 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
|
359 | 360 | self.assertEqual(g.opset, self.config.opset)
|
360 | 361 | self.assertEqual(g.extra_opset, [constants.DEFAULT_CUSTOM_OP_OPSET])
|
361 | 362 |
|
| 363 | + def test_extra_opset(self): |
| 364 | + extra_opset = [ |
| 365 | + helper.make_opsetid(constants.MICROSOFT_DOMAIN, 1), |
| 366 | + helper.make_opsetid("my.domain", 1024), |
| 367 | + ] |
| 368 | + with tf.Session() as sess: |
| 369 | + x = tf.placeholder(tf.float32, [2, 3], name="input1") |
| 370 | + x_ = tf.add(x, x) |
| 371 | + _ = tf.identity(x_, name="output") |
| 372 | + g = process_tf_graph(sess.graph, |
| 373 | + opset=self.config.opset, |
| 374 | + extra_opset=extra_opset) |
| 375 | + self.assertEqual(g.opset, self.config.opset) |
| 376 | + self.assertEqual(g.extra_opset, extra_opset) |
| 377 | + |
| 378 | + # convert between graph and model proto, make sure extra opset is preserved |
| 379 | + model_proto = g.make_model("test") |
| 380 | + model_proto = GraphUtil.optimize_model_proto(model_proto) |
| 381 | + g = GraphUtil.create_graph_from_onnx_model(model_proto) |
| 382 | + self.assertEqual(g.opset, self.config.opset) |
| 383 | + self.assertEqual(g.extra_opset, extra_opset) |
| 384 | + |
362 | 385 |
|
363 | 386 | if __name__ == '__main__':
|
364 | 387 | unittest_main()
|
0 commit comments