Skip to content

Commit 09cd571

Browse files
committed
rename DEFAULT_CUSTOM_OP_OPSET
1 parent 75e5c03 commit 09cd571

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

tests/test_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
340340
# becomes:
341341
# T output = Identity(T Input)
342342
self.assertEqual(node.type, "Identity")
343-
node.domain = constants.DEFAULT_CUSTOM_OP_OPSET.domain
343+
node.domain = constants.TENSORFLOW_OPSET.domain
344344
self.assertEqual(args[0], "mode")
345345
del node.input[1:]
346346
return node
@@ -352,13 +352,13 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
352352
g = process_tf_graph(sess.graph,
353353
custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
354354
opset=self.config.opset,
355-
extra_opset=[constants.DEFAULT_CUSTOM_OP_OPSET])
355+
extra_opset=[constants.TENSORFLOW_OPSET])
356356
self.assertEqual(
357357
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [domain="ai.onnx.converters.tensorflow" '
358358
'op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
359359
onnx_to_graphviz(g))
360360
self.assertEqual(g.opset, self.config.opset)
361-
self.assertEqual(g.extra_opset, [constants.DEFAULT_CUSTOM_OP_OPSET])
361+
self.assertEqual(g.extra_opset, [constants.TENSORFLOW_OPSET])
362362

363363
def test_extra_opset(self):
364364
extra_opset = [

tf2onnx/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
PREFERRED_OPSET = 7
1717

1818
# Default opset for custom ops
19-
DEFAULT_CUSTOM_OP_OPSET = utils.make_opsetid("ai.onnx.converters.tensorflow", 1)
19+
TENSORFLOW_OPSET = utils.make_opsetid("ai.onnx.converters.tensorflow", 1)
2020

2121
# Target for the generated onnx graph. It possible targets:
2222
# onnx-1.1 = onnx at v1.1 (winml in rs4 is based on this)

tf2onnx/convert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def get_args():
7373

7474

7575
def default_custom_op_handler(ctx, node, name, args):
76-
node.domain = constants.DEFAULT_CUSTOM_OP_OPSET.domain
76+
node.domain = constants.TENSORFLOW_OPSET.domain
7777
return node
7878

7979

@@ -89,7 +89,7 @@ def main():
8989
if args.custom_ops:
9090
# default custom ops for tensorflow-onnx are in the "tf" namespace
9191
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
92-
extra_opset.append(constants.DEFAULT_CUSTOM_OP_OPSET)
92+
extra_opset.append(constants.TENSORFLOW_OPSET)
9393

9494
# get the frozen tensorflow model from graphdef, checkpoint or saved_model.
9595
if args.graphdef:

0 commit comments

Comments
 (0)