22
22
from tf2onnx .graph import GraphUtil
23
23
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
24
24
from tf2onnx .tfonnx import process_tf_graph
25
+ from tf2onnx .handler import tf_op
26
+
25
27
from common import get_test_config , unittest_main
26
28
27
29
28
- # pylint: disable=missing-docstring
30
+ # pylint: disable=missing-docstring,unused-argument,unused-variable
29
31
30
32
def onnx_to_graphviz (g , include_attrs = False ):
31
33
"""Return dot for graph."""
@@ -331,10 +333,10 @@ def rewrite_test(g, ops):
331
333
'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }' ,
332
334
onnx_to_graphviz (g ))
333
335
334
- def test_custom_op (self ):
335
- """Custom op test."""
336
+ def test_custom_op_depreciated (self ):
337
+ """Custom op test using old depreciated api ."""
336
338
337
- def print_handler (ctx , node , name , args ): # pylint: disable=unused-argument
339
+ def print_handler (ctx , node , name , args ):
338
340
# replace tf.Print() with Identity
339
341
# T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
340
342
# becomes:
@@ -360,6 +362,32 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
360
362
self .assertEqual (g .opset , self .config .opset )
361
363
self .assertEqual (g .extra_opset , [constants .TENSORFLOW_OPSET ])
362
364
365
+ def test_custom_op (self ):
366
+ """Custom op test."""
367
+
368
+ @tf_op ("Print" , type_map = {"Print" : "Identity" })
369
+ class Print :
370
+ @classmethod
371
+ def version_1 (cls , ctx , node , ** kwargs ):
372
+ self .assertEqual (node .type , "Identity" )
373
+ node .domain = constants .TENSORFLOW_OPSET .domain
374
+ del node .input [1 :]
375
+ return node
376
+
377
+ with tf .Session () as sess :
378
+ x = tf .placeholder (tf .float32 , [2 , 3 ], name = "input1" )
379
+ x_ = tf .Print (x , [x ], "hello" )
380
+ _ = tf .identity (x_ , name = "output" )
381
+ g = process_tf_graph (sess .graph ,
382
+ opset = self .config .opset ,
383
+ extra_opset = [constants .TENSORFLOW_OPSET ])
384
+ self .assertEqual (
385
+ 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [domain="ai.onnx.converters.tensorflow" '
386
+ 'op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }' ,
387
+ onnx_to_graphviz (g ))
388
+ self .assertEqual (g .opset , self .config .opset )
389
+ self .assertEqual (g .extra_opset , [constants .TENSORFLOW_OPSET ])
390
+
363
391
def test_extra_opset (self ):
364
392
extra_opset = [
365
393
utils .make_opsetid (constants .MICROSOFT_DOMAIN , 1 ),
0 commit comments