Skip to content

Commit fcd3ab1

Browse files
committed
add test_extra_opset
1 parent e5829bc commit fcd3ab1

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_graph.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import tf2onnx
2121
from tf2onnx import constants
22+
from tf2onnx.graph import GraphUtil
2223
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2324
from tf2onnx.tfonnx import process_tf_graph
2425
from common import get_test_config, unittest_main
@@ -359,6 +360,28 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
359360
self.assertEqual(g.opset, self.config.opset)
360361
self.assertEqual(g.extra_opset, [constants.DEFAULT_CUSTOM_OP_OPSET])
361362

363+
def test_extra_opset(self):
364+
extra_opset = [
365+
helper.make_opsetid(constants.MICROSOFT_DOMAIN, 1),
366+
helper.make_opsetid("my.domain", 1024),
367+
]
368+
with tf.Session() as sess:
369+
x = tf.placeholder(tf.float32, [2, 3], name="input1")
370+
x_ = tf.add(x, x)
371+
_ = tf.identity(x_, name="output")
372+
g = process_tf_graph(sess.graph,
373+
opset=self.config.opset,
374+
extra_opset=extra_opset)
375+
self.assertEqual(g.opset, self.config.opset)
376+
self.assertEqual(g.extra_opset, extra_opset)
377+
378+
# convert between graph and model proto, make sure extra opset is preserved
379+
model_proto = g.make_model("test")
380+
model_proto = GraphUtil.optimize_model_proto(model_proto)
381+
g = GraphUtil.create_graph_from_onnx_model(model_proto)
382+
self.assertEqual(g.opset, self.config.opset)
383+
self.assertEqual(g.extra_opset, extra_opset)
384+
362385

363386
if __name__ == '__main__':
364387
unittest_main()

0 commit comments

Comments
 (0)