Skip to content

Commit 3cc7a0c

Browse files
committed
support for old custom op registration, fix graph unit test
1 parent 86b208e commit 3cc7a0c

File tree

3 files changed

+70
-7
lines changed

3 files changed

+70
-7
lines changed

tests/test_graph.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
from tf2onnx.graph import GraphUtil
2323
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2424
from tf2onnx.tfonnx import process_tf_graph
25+
from tf2onnx.handler import tf_op
26+
2527
from common import get_test_config, unittest_main
2628

2729

28-
# pylint: disable=missing-docstring
30+
# pylint: disable=missing-docstring,unused-argument,unused-variable
2931

3032
def onnx_to_graphviz(g, include_attrs=False):
3133
"""Return dot for graph."""
@@ -331,10 +333,10 @@ def rewrite_test(g, ops):
331333
'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
332334
onnx_to_graphviz(g))
333335

334-
def test_custom_op(self):
335-
"""Custom op test."""
336+
def test_custom_op_depreciated(self):
337+
"""Custom op test using old depreciated api."""
336338

337-
def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
339+
def print_handler(ctx, node, name, args):
338340
# replace tf.Print() with Identity
339341
# T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
340342
# becomes:
@@ -360,6 +362,32 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
360362
self.assertEqual(g.opset, self.config.opset)
361363
self.assertEqual(g.extra_opset, [constants.TENSORFLOW_OPSET])
362364

365+
def test_custom_op(self):
366+
"""Custom op test."""
367+
368+
@tf_op("Print", type_map={"Print": "Identity"})
369+
class Print:
370+
@classmethod
371+
def version_1(cls, ctx, node, **kwargs):
372+
self.assertEqual(node.type, "Identity")
373+
node.domain = constants.TENSORFLOW_OPSET.domain
374+
del node.input[1:]
375+
return node
376+
377+
with tf.Session() as sess:
378+
x = tf.placeholder(tf.float32, [2, 3], name="input1")
379+
x_ = tf.Print(x, [x], "hello")
380+
_ = tf.identity(x_, name="output")
381+
g = process_tf_graph(sess.graph,
382+
opset=self.config.opset,
383+
extra_opset=[constants.TENSORFLOW_OPSET])
384+
self.assertEqual(
385+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [domain="ai.onnx.converters.tensorflow" '
386+
'op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
387+
onnx_to_graphviz(g))
388+
self.assertEqual(g.opset, self.config.opset)
389+
self.assertEqual(g.extra_opset, [constants.TENSORFLOW_OPSET])
390+
363391
def test_extra_opset(self):
364392
extra_opset = [
365393
utils.make_opsetid(constants.MICROSOFT_DOMAIN, 1),

tf2onnx/handler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,20 @@ def __call__(self, func):
3838
for name in self.name:
3939
opset_dict[name] = (v, self.kwargs)
4040

41+
def register_compat_handler(self, func, version):
42+
opset = tf_op._OPSETS.get(self.domain)
43+
if not opset:
44+
opset = []
45+
tf_op._OPSETS[self.domain] = opset
46+
while version >= len(opset):
47+
opset.append({})
48+
opset_dict = opset[version]
49+
opset_dict[self.name[0]] = (func, self.kwargs)
50+
4151
@staticmethod
4252
def get_opsets():
4353
return tf_op._OPSETS
4454

45-
4655
@staticmethod
4756
def create_mapping(max_opset, extra_opsets):
4857
mapping = {"onnx": max_opset}
@@ -60,7 +69,6 @@ def create_mapping(max_opset, extra_opsets):
6069
tf_op._MAPPING = ops_mapping
6170
return ops_mapping
6271

63-
6472
@staticmethod
6573
def find_op(name):
6674
map_info = tf_op._MAPPING.get(name)

tf2onnx/tfonnx.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,34 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
717717
# apply custom ops on top of the assembled opset. We can either complement the opset
718718
# or override existing ops with a custom op.
719719
if custom_op_handlers is not None:
720-
custom_opset = {k: v for k, v in custom_op_handlers.items()}
720+
# below is a bit tricky since there are a few api's:
721+
# 1. the future way we want custom ops to be registered with the @tf_op decorator. Those handlers will be
722+
# registered via the decorator on load of the module ... nothing is required here.
723+
# 2. the old custom op api: a dictionary of {name: (func, args[])}
724+
# We deal with this by using a compat_handler that wraps the old handler with a new style handler.
725+
# This is temporary, to give people time to move to the new api; after tf2onnx-1.5 we want to remove this
726+
custom_opset = {}
727+
for k, v in custom_op_handlers.items():
728+
# FIXME: remove this after tf2onnx-1.5
729+
def compat_handler(ctx, node, **kwargs):
730+
# wrap old handler
731+
name = node.name
732+
args = kwargs["args"]
733+
func = kwargs["func"]
734+
return func(ctx, node, name, args)
735+
736+
args = v[1]
737+
kwargs = {"func": v[0]}
738+
if args:
739+
type_map = {k: args[0]}
740+
kwargs["type_map"] = type_map
741+
args = args[1:]
742+
kwargs["args"] = args
743+
new_handler = handler.tf_op(k,
744+
domain=constants.TENSORFLOW_OPSET.domain,
745+
kwargs=kwargs)
746+
new_handler.register_compat_handler(compat_handler, 1)
747+
custom_opset[k] = (compat_handler, kwargs)
721748
ops_mapping.update(custom_opset)
722749

723750
infer_shape_for_graph(g)

0 commit comments

Comments
 (0)