Skip to content

Commit 29ab979

Browse files
author
wayuanho
authored
Merge pull request #615 from zhijxu-MS/fix_graph_matcher
Fix bug and refactor graph matcher
2 parents ddb4301 + a7a3c73 commit 29ab979

File tree

3 files changed

+91
-40
lines changed

3 files changed

+91
-40
lines changed

tests/test_backend.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
1818
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
1919
from tf2onnx import constants
20+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2021

2122
# pylint: disable=missing-docstring,invalid-name,unused-argument
2223

@@ -87,6 +88,7 @@ def get_conv_getdata(kind=1):
8788
else:
8889
raise ValueError("kind not known")
8990

91+
9092
def get_maxpoolwithargmax_getdata():
9193
data = [
9294
('SAME', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
@@ -99,6 +101,7 @@ def get_maxpoolwithargmax_getdata():
99101
for idx, v in enumerate(data):
100102
yield (idx,) + v
101103

104+
102105
class BackendTests(Tf2OnnxBackendTestBase):
103106
def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
104107
kwargs["convert_var_to_const"] = False
@@ -2014,7 +2017,6 @@ def test_reverse_sequence_time_major(self):
20142017
_ = tf.identity(x_, name=_TFOUTPUT)
20152018
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20162019

2017-
20182020
@check_opset_min_version(10, "ReverseSequence")
20192021
def test_reversev2_constant_axis(self):
20202022
# Tests for constant axis.
@@ -2034,7 +2036,6 @@ def test_reversev2_constant_axis(self):
20342036
_ = tf.identity(x_, name=_TFOUTPUT)
20352037
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20362038

2037-
20382039
@check_opset_min_version(10, "ReverseSequence")
20392040
def test_reversev2_vector_axis(self):
20402041
x_val_shape = [1, 2, 3, 4]
@@ -2060,7 +2061,6 @@ def test_reversev2_vector_axis(self):
20602061
_ = tf.identity(x_, name=_TFOUTPUT)
20612062
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20622063

2063-
20642064
@check_opset_min_version(10, "ReverseSequence")
20652065
def test_reversev2_1D_tensor(self):
20662066
# For tensors with 1 dimension and no axis to reverse.
@@ -2072,7 +2072,6 @@ def test_reversev2_1D_tensor(self):
20722072
_ = tf.identity(x_, name=_TFOUTPUT)
20732073
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20742074

2075-
20762075
@check_opset_min_version(8, "where")
20772076
def test_where(self):
20782077
x_val = np.array([1, 2, -3, 4, -5, -6, -7, 8, 9, 0], dtype=np.float32)
@@ -2552,5 +2551,36 @@ def test_selu(self):
25522551
_ = tf.identity(y, name=_TFOUTPUT)
25532552
self._run_test_case([_OUTPUT], {_INPUT: x_val})
25542553

2554+
def test_graph_matcher(self):
2555+
shape = [2, 6]
2556+
x_val = np.random.random(shape).astype(np.float32)
2557+
y_val = np.random.random(shape).astype(np.float32)
2558+
z_val = np.random.random(shape).astype(np.float32)
2559+
x = tf.placeholder(tf.float32, shape, name=_TFINPUT)
2560+
y = tf.placeholder(tf.float32, shape, name=_TFINPUT1)
2561+
z = tf.placeholder(tf.float32, shape, name=_TFINPUT2)
2562+
tmp1 = x + y
2563+
tmp2 = x - y
2564+
tmp3 = tf.multiply(tmp1, z)
2565+
tmp4 = tf.multiply(tmp2, z)
2566+
_ = tf.add(tmp4, tmp3, name=_TFOUTPUT)
2567+
onnx_graph = self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
2568+
pattern = \
2569+
OpTypePattern('Add', name='output', inputs=[
2570+
OpTypePattern('Mul', inputs=[
2571+
OpTypePattern('Add', name='input1'),
2572+
OpTypePattern('*', name='input2')]),
2573+
OpTypePattern('Mul', inputs=[
2574+
OpTypePattern('Sub', name='input1'),
2575+
OpTypePattern('*', name='input2')])])
2576+
2577+
matcher = GraphMatcher(pattern, allow_reorder=False)
2578+
match_results = list(matcher.match_ops(onnx_graph.get_nodes()))
2579+
self.assertTrue(len(match_results) == 0)
2580+
matcher = GraphMatcher(pattern, allow_reorder=True)
2581+
match_results = list(matcher.match_ops(onnx_graph.get_nodes()))
2582+
self.assertTrue(len(match_results) == 1)
2583+
2584+
25552585
if __name__ == '__main__':
25562586
unittest_main()

tf2onnx/graph_matcher.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919
from __future__ import print_function
2020
from __future__ import unicode_literals
2121

22-
import copy
23-
22+
from itertools import permutations
2423
import six
2524

2625

27-
2826
class OpTypePattern(object):
2927
"""A tree pattern that matches TF expressions with certain op types."""
3028

@@ -137,6 +135,16 @@ def __init__(self, pattern, allow_reorder=False):
137135
self._pattern = pattern
138136
self._allow_reorder = allow_reorder
139137

138+
@staticmethod
139+
def _is_op_type_same(op, pattern):
140+
if pattern.op_type == "*":
141+
return True
142+
143+
if op.type in pattern.op_type.split('|'):
144+
return True
145+
146+
return False
147+
140148
def _match_pattern(self, pattern, op, tensor):
141149
"""Returns whether an TF expression rooted at `op` matches `pattern`.
142150
@@ -151,48 +159,55 @@ def _match_pattern(self, pattern, op, tensor):
151159
pattern tree.
152160
153161
Returns:
154-
True if an TF expression rooted at `op` matches `pattern`.
162+
if matched return True and match_list whose elem is [pattern, op, tensor]
163+
else return False
164+
the condition that op is matched with pattern:
165+
1 op is same:
166+
if pattern.op_type is None or *, then treat as same
167+
or op.type in pattern.op_type.split("|")
168+
2 op.inputs are same with pattern.inputs:
169+
if not pattern.inputs, then treat as same
170+
otherwise, iteratively compare input nodes with pattern.
155171
"""
172+
match_list = []
156173
if pattern.op_type is None:
157-
return True
158-
159-
if pattern.op_type != '*':
160-
if op is None or op.type not in pattern.op_type.split('|'):
161-
return False
174+
return True, match_list
162175

163-
self._match_result.add(pattern, op, tensor)
164-
# print("matched", ",".join([op.type + "|" + op.name for op in self._match_result.get_nodes()]))
176+
if self._is_op_type_same(op, pattern):
177+
match_list.append([pattern, op, tensor])
178+
else:
179+
return False, match_list
165180

166181
if not pattern.inputs:
167182
# If pattern.inputs is empty, skips the rest and accepts all the inputs.
168-
return True
183+
return True, match_list
169184

170185
if not op or len(op.inputs) != len(pattern.inputs):
171-
return False
186+
return False, match_list
172187

173188
if self._allow_reorder:
174-
inputs = [None] * len(op.inputs)
175-
wanted = copy.copy(pattern.inputs)
176-
for idx, i in enumerate(op.inputs):
177-
for j in range(len(wanted)): # pylint: disable=consider-using-enumerate
178-
if i.type == wanted[j].op_type:
179-
inputs[idx] = wanted[j]
180-
del wanted[j]
181-
break
182-
for idx, i in enumerate(inputs):
183-
if i is None:
184-
inputs[idx] = wanted[0]
185-
del wanted[0]
186-
pat = list(zip(op.inputs, inputs))
189+
pattern_inputs_list = permutations(pattern.inputs)
187190
else:
188-
pat = list(zip(op.inputs, pattern.inputs))
189-
190-
ret = []
191-
for input_tensor, input_pattern in pat:
192-
# print("MATCHING", input_pattern.op_type, input_tensor.type)
193-
r = self._match_pattern(input_pattern, input_tensor, input_tensor)
194-
ret.append(r)
195-
return all(ret)
191+
pattern_inputs_list = [pattern.inputs]
192+
193+
for possible_pattern_inputs in pattern_inputs_list:
194+
pat = list(zip(op.inputs, possible_pattern_inputs))
195+
match_flag_of_inputs = []
196+
match_lists_of_inputs = []
197+
for input_tensor, input_pattern in pat:
198+
# print("MATCHING", input_pattern.op_type, input_tensor.type)
199+
flag, match_list_of_input = self._match_pattern(input_pattern, input_tensor, input_tensor)
200+
match_flag_of_inputs.append(flag)
201+
match_lists_of_inputs.extend(match_list_of_input)
202+
203+
if all(match_flag_of_inputs):
204+
match_list.extend(match_lists_of_inputs)
205+
return True, match_list
206+
return False, match_list
207+
208+
def _parse_match_list_to_match_result(self, match_list):
209+
for pattern, op, tensor in match_list:
210+
self._match_result.add(pattern, op, tensor)
196211

197212
def match_op(self, op):
198213
"""Matches `op` against `self._pattern`.
@@ -205,8 +220,10 @@ def match_op(self, op):
205220
None.
206221
"""
207222
self._match_result = MatchResult()
208-
if not self._match_pattern(self._pattern, op, tensor=None):
223+
match_flag, match_list = self._match_pattern(self._pattern, op, tensor=None)
224+
if not match_flag:
209225
return None
226+
self._parse_match_list_to_match_result(match_list)
210227
return self._match_result
211228

212229
def match_ops(self, ops):

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,14 @@ def _get_weight_and_bias_for_lstm_cell(self, context):
125125
def parse_attributes(self, context):
126126
if self.lstm_cell_type == RNNUnitType.LSTMBlockCell:
127127
lstm_block_cell = context.cell_match.get_op("lstm_block_cell")
128-
clip = float(lstm_block_cell.get_attr("cell_clip").f)
128+
clip = lstm_block_cell.get_attr_value("cell_clip")
129129
# current LSTM op cannot handle clip
130130
if clip > 0:
131131
return False
132+
133+
use_peephole = lstm_block_cell.get_attr_value("use_peephole")
134+
if use_peephole:
135+
return False
132136
return True
133137

134138
def _ct_variable_finder(self, context):

0 commit comments

Comments
 (0)