Skip to content

Commit a7a3c73

Browse files
committed
add test for graph matcher
1 parent 8a3c573 commit a7a3c73

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

tests/test_backend.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
1818
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
1919
from tf2onnx import constants
20+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2021

2122
# pylint: disable=missing-docstring,invalid-name,unused-argument
2223

@@ -87,6 +88,7 @@ def get_conv_getdata(kind=1):
8788
else:
8889
raise ValueError("kind not known")
8990

91+
9092
def get_maxpoolwithargmax_getdata():
9193
data = [
9294
('SAME', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
@@ -99,6 +101,7 @@ def get_maxpoolwithargmax_getdata():
99101
for idx, v in enumerate(data):
100102
yield (idx,) + v
101103

104+
102105
class BackendTests(Tf2OnnxBackendTestBase):
103106
def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
104107
kwargs["convert_var_to_const"] = False
@@ -2015,7 +2018,6 @@ def test_reverse_sequence_time_major(self):
20152018
_ = tf.identity(x_, name=_TFOUTPUT)
20162019
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20172020

2018-
20192021
@check_opset_min_version(10, "ReverseSequence")
20202022
def test_reversev2_constant_axis(self):
20212023
# Tests for constant axis.
@@ -2035,7 +2037,6 @@ def test_reversev2_constant_axis(self):
20352037
_ = tf.identity(x_, name=_TFOUTPUT)
20362038
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20372039

2038-
20392040
@check_opset_min_version(10, "ReverseSequence")
20402041
def test_reversev2_vector_axis(self):
20412042
x_val_shape = [1, 2, 3, 4]
@@ -2061,7 +2062,6 @@ def test_reversev2_vector_axis(self):
20612062
_ = tf.identity(x_, name=_TFOUTPUT)
20622063
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20632064

2064-
20652065
@check_opset_min_version(10, "ReverseSequence")
20662066
def test_reversev2_1D_tensor(self):
20672067
# For tensors with 1 dimension and no axis to reverse.
@@ -2073,7 +2073,6 @@ def test_reversev2_1D_tensor(self):
20732073
_ = tf.identity(x_, name=_TFOUTPUT)
20742074
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20752075

2076-
20772076
@check_opset_min_version(8, "where")
20782077
def test_where(self):
20792078
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.float32)
@@ -2553,5 +2552,36 @@ def test_selu(self):
25532552
_ = tf.identity(y, name=_TFOUTPUT)
25542553
self._run_test_case([_OUTPUT], {_INPUT: x_val})
25552554

2555+
def test_graph_matcher(self):
2556+
shape = [2, 6]
2557+
x_val = np.random.random(shape).astype(np.float32)
2558+
y_val = np.random.random(shape).astype(np.float32)
2559+
z_val = np.random.random(shape).astype(np.float32)
2560+
x = tf.placeholder(tf.float32, shape, name=_TFINPUT)
2561+
y = tf.placeholder(tf.float32, shape, name=_TFINPUT1)
2562+
z = tf.placeholder(tf.float32, shape, name=_TFINPUT2)
2563+
tmp1 = x + y
2564+
tmp2 = x - y
2565+
tmp3 = tf.multiply(tmp1, z)
2566+
tmp4 = tf.multiply(tmp2, z)
2567+
_ = tf.add(tmp4, tmp3, name=_TFOUTPUT)
2568+
onnx_graph = self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
2569+
pattern = \
2570+
OpTypePattern('Add', name='output', inputs=[
2571+
OpTypePattern('Mul', inputs=[
2572+
OpTypePattern('Add', name='input1'),
2573+
OpTypePattern('*', name='input2')]),
2574+
OpTypePattern('Mul', inputs=[
2575+
OpTypePattern('Sub', name='input1'),
2576+
OpTypePattern('*', name='input2')])])
2577+
2578+
matcher = GraphMatcher(pattern, allow_reorder=False)
2579+
match_results = list(matcher.match_ops(onnx_graph.get_nodes()))
2580+
self.assertTrue(len(match_results) == 0)
2581+
matcher = GraphMatcher(pattern, allow_reorder=True)
2582+
match_results = list(matcher.match_ops(onnx_graph.get_nodes()))
2583+
self.assertTrue(len(match_results) == 1)
2584+
2585+
25562586
if __name__ == '__main__':
25572587
unittest_main()

0 commit comments

Comments
 (0)