Skip to content

Commit 5081424

Browse files
authored
Merge pull request #305 from jiafatom/customized_arg
Pass args to custom_op converter
2 parents 73afa70 + 598453f commit 5081424

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@ def print_handler(ctx, node, name, args):
227227
# T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
228228
# becomes:
229229
# T output = Identity(T Input)
230-
node.type = "Identity"
231230
node.domain = _TENSORFLOW_DOMAIN
232231
del node.input[1:]
233232
return node
@@ -239,7 +238,7 @@ with tf.Session() as sess:
239238
x_ = tf.Print(x, [x], "hello")
240239
_ = tf.identity(x_, name="output")
241240
onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
242-
custom_op_handlers={"Print": print_handler},
241+
custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
243242
extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)],
244243
input_names=["input:0"],
245244
output_names=["output:0"])

tests/test_graph.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,9 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
333333
# T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
334334
# becomes:
335335
# T output = Identity(T Input)
336-
node.type = "Identity"
336+
self.assertEqual(node.type, "Identity")
337337
node.domain = _TENSORFLOW_DOMAIN
338+
self.assertEqual(args[0], "mode")
338339
del node.input[1:]
339340
return node
340341

@@ -343,7 +344,7 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
343344
x_ = tf.Print(x, [x], "hello")
344345
_ = tf.identity(x_, name="output")
345346
g = process_tf_graph(sess.graph,
346-
custom_op_handlers={"Print": print_handler},
347+
custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
347348
extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
348349
self.assertEqual(
349350
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2284,7 +2284,7 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
22842284
# apply custom ops on top of the assembled opset. We can either completment the opset
22852285
# or override existing ops with a custom op.
22862286
if custom_op_handlers is not None:
2287-
custom_opset = {k: [v, []] for k, v in custom_op_handlers.items()}
2287+
custom_opset = {k: v for k, v in custom_op_handlers.items()}
22882288
ops_mapping.update(custom_opset)
22892289

22902290
ops = g.get_nodes()

0 commit comments

Comments
 (0)