17
17
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
18
18
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
19
19
from tf2onnx import constants
20
+ from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
20
21
21
22
# pylint: disable=missing-docstring,invalid-name,unused-argument
22
23
@@ -87,6 +88,7 @@ def get_conv_getdata(kind=1):
87
88
else :
88
89
raise ValueError ("kind not known" )
89
90
91
+
90
92
def get_maxpoolwithargmax_getdata ():
91
93
data = [
92
94
('SAME' , [1 , 3 , 3 , 1 ], [1 , 3 , 3 , 1 ], [1 , 2 , 2 , 1 ]),
@@ -99,6 +101,7 @@ def get_maxpoolwithargmax_getdata():
99
101
for idx , v in enumerate (data ):
100
102
yield (idx ,) + v
101
103
104
+
102
105
class BackendTests (Tf2OnnxBackendTestBase ):
103
106
def _run_test_case (self , output_names_with_port , feed_dict , ** kwargs ):
104
107
kwargs ["convert_var_to_const" ] = False
@@ -2015,7 +2018,6 @@ def test_reverse_sequence_time_major(self):
2015
2018
_ = tf .identity (x_ , name = _TFOUTPUT )
2016
2019
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
2017
2020
2018
-
2019
2021
@check_opset_min_version (10 , "ReverseSequence" )
2020
2022
def test_reversev2_constant_axis (self ):
2021
2023
# Tests for constant axis.
@@ -2035,7 +2037,6 @@ def test_reversev2_constant_axis(self):
2035
2037
_ = tf .identity (x_ , name = _TFOUTPUT )
2036
2038
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
2037
2039
2038
-
2039
2040
@check_opset_min_version (10 , "ReverseSequence" )
2040
2041
def test_reversev2_vector_axis (self ):
2041
2042
x_val_shape = [1 , 2 , 3 , 4 ]
@@ -2061,7 +2062,6 @@ def test_reversev2_vector_axis(self):
2061
2062
_ = tf .identity (x_ , name = _TFOUTPUT )
2062
2063
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
2063
2064
2064
-
2065
2065
@check_opset_min_version (10 , "ReverseSequence" )
2066
2066
def test_reversev2_1D_tensor (self ):
2067
2067
# For tensors with 1 dimension and no axis to reverse.
@@ -2073,7 +2073,6 @@ def test_reversev2_1D_tensor(self):
2073
2073
_ = tf .identity (x_ , name = _TFOUTPUT )
2074
2074
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
2075
2075
2076
-
2077
2076
@check_opset_min_version (8 , "where" )
2078
2077
def test_where (self ):
2079
2078
x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .float32 )
@@ -2553,5 +2552,36 @@ def test_selu(self):
2553
2552
_ = tf .identity (y , name = _TFOUTPUT )
2554
2553
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
2555
2554
2555
+ def test_graph_matcher (self ):
2556
+ shape = [2 , 6 ]
2557
+ x_val = np .random .random (shape ).astype (np .float32 )
2558
+ y_val = np .random .random (shape ).astype (np .float32 )
2559
+ z_val = np .random .random (shape ).astype (np .float32 )
2560
+ x = tf .placeholder (tf .float32 , shape , name = _TFINPUT )
2561
+ y = tf .placeholder (tf .float32 , shape , name = _TFINPUT1 )
2562
+ z = tf .placeholder (tf .float32 , shape , name = _TFINPUT2 )
2563
+ tmp1 = x + y
2564
+ tmp2 = x - y
2565
+ tmp3 = tf .multiply (tmp1 , z )
2566
+ tmp4 = tf .multiply (tmp2 , z )
2567
+ _ = tf .add (tmp4 , tmp3 , name = _TFOUTPUT )
2568
+ onnx_graph = self ._run_test_case ([_OUTPUT ], {_INPUT : x_val , _INPUT1 : y_val , _INPUT2 : z_val })
2569
+ pattern = \
2570
+ OpTypePattern ('Add' , name = 'output' , inputs = [
2571
+ OpTypePattern ('Mul' , inputs = [
2572
+ OpTypePattern ('Add' , name = 'input1' ),
2573
+ OpTypePattern ('*' , name = 'input2' )]),
2574
+ OpTypePattern ('Mul' , inputs = [
2575
+ OpTypePattern ('Sub' , name = 'input1' ),
2576
+ OpTypePattern ('*' , name = 'input2' )])])
2577
+
2578
+ matcher = GraphMatcher (pattern , allow_reorder = False )
2579
+ match_results = list (matcher .match_ops (onnx_graph .get_nodes ()))
2580
+ self .assertTrue (len (match_results ) == 0 )
2581
+ matcher = GraphMatcher (pattern , allow_reorder = True )
2582
+ match_results = list (matcher .match_ops (onnx_graph .get_nodes ()))
2583
+ self .assertTrue (len (match_results ) == 1 )
2584
+
2585
+
2556
2586
if __name__ == '__main__' :
2557
2587
unittest_main ()
0 commit comments