Skip to content

Commit e5829bc

Browse files
committed
update test_custom_op
1 parent d0a88a6 commit e5829bc

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tests/test_graph.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def get_attribute_value(attr):
7070
if "broadcast" in attr:
7171
kwarg["broadcast"] = "{}".format(int(attr["broadcast"].i))
7272

73+
# display domain if it is not onnx domain
74+
if node.domain:
75+
kwarg["domain"] = node.domain
76+
7377
g2.node(node.name, op_type=node.type, **kwarg)
7478
for node in g.get_nodes():
7579
for i in node.input:
@@ -349,9 +353,11 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
349353
opset=self.config.opset,
350354
extra_opset=[constants.DEFAULT_CUSTOM_OP_OPSET])
351355
self.assertEqual(
352-
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
353-
'output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
356+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [domain="ai.onnx.converters.tensorflow" '
357+
'op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
354358
onnx_to_graphviz(g))
359+
self.assertEqual(g.opset, self.config.opset)
360+
self.assertEqual(g.extra_opset, [constants.DEFAULT_CUSTOM_OP_OPSET])
355361

356362

357363
if __name__ == '__main__':

0 commit comments

Comments
 (0)