Skip to content

Commit 9ec80aa

Browse files
Merge pull request #913 from RandySheriffH/rashuai/FixLSTM
Fix LSTM pattern matching for TensorFlow versions between 1.15.0 and 2.x.
2 parents 8b8d5ea + 8c292e0 commit 9ec80aa

File tree

3 files changed

+58
-21
lines changed

3 files changed

+58
-21
lines changed

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -332,7 +332,7 @@ def check_op_count(graph, op_type, expected_count):
332332

333333

334334
def check_lstm_count(graph, expected_count):
335-
return check_op_count(graph, "LSTM", expected_count)
335+
return len(group_nodes_by_type(graph)["LSTM"]) == expected_count
336336

337337

338338
def check_gru_count(graph, expected_count):

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -98,15 +98,18 @@ def _get_weight_and_bias_for_lstm_cell(self, context):
9898
# check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
9999
# for bias_add data format
100100
bias_add = match.get_op("bias_add")
101-
if bias_add.data_format != "NHWC":
101+
if bias_add is not None and bias_add.data_format != "NHWC":
102102
logger.debug("BiasAdd data_format is not NHWC, SKIP")
103103
return None
104104

105105
b_e = match.get_op("cell_bias")
106-
b = get_weights_from_const_node(self.g, b_e)
107-
if b is None or b.shape[0] != w.shape[1]:
108-
logger.warning("cell_kernel and cell_bias's dimensions does not match, skip")
109-
return None
106+
if b_e is None:
107+
b = np.array([0 for i in range(len(w[0]))]).astype(w.dtype)
108+
else:
109+
b = get_weights_from_const_node(self.g, b_e)
110+
if b is None or b.shape[0] != w.shape[1]:
111+
logger.warning("cell_kernel and cell_bias's dimensions does not match, skip")
112+
return None
110113

111114
ft_bias_node = match.get_op("ft_bias")
112115
ft_bias = get_weights_from_const_node(self.g, ft_bias_node)

tf2onnx/rewriter/rnn_utils.py

Lines changed: 49 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -30,30 +30,31 @@ class REWRITER_RESULT(Enum):
3030

3131

3232
# TensorFlow LSTMCell/BasicLSTMCell computation graph matching
33-
xc_pattern = OpTypePattern('Split', inputs=[
34-
OpTypePattern("Const"), # axis for split
35-
OpTypePattern("BiasAdd", name="bias_add", inputs=[
36-
OpTypePattern("MatMul", inputs=[
37-
OpTypePattern("ConcatV2|Concat", name="xh"),
33+
34+
xc_pattern = \
35+
OpTypePattern('Split', inputs=[
36+
OpTypePattern("Const"), # axis for split
37+
OpTypePattern("BiasAdd", name="bias_add", inputs=[
38+
OpTypePattern("MatMul", inputs=[
39+
OpTypePattern("ConcatV2|Concat", name="xh"),
40+
OpTypePattern("Enter", inputs=[
41+
OpTypePattern("*", name="cell_kernel"),
42+
]),
43+
]),
3844
OpTypePattern("Enter", inputs=[
39-
OpTypePattern("*", name="cell_kernel"),
45+
OpTypePattern("*", name="cell_bias"),
4046
]),
4147
]),
42-
OpTypePattern("Enter", inputs=[
43-
OpTypePattern("*", name="cell_bias"),
44-
]),
45-
]),
46-
])
47-
48+
])
4849

4950
lstmcell_pattern = \
5051
OpTypePattern('Mul', name='ht', inputs=[
5152
OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern]),
5253
OpTypePattern('Tanh', inputs=[
53-
OpTypePattern("Add", name="ct", inputs=[
54+
OpTypePattern("Add|AddV2", name="ct", inputs=[
5455
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
5556
OpTypePattern("Sigmoid", name="ft", inputs=[
56-
OpTypePattern("Add", inputs=[
57+
OpTypePattern("Add|AddV2", inputs=[
5758
xc_pattern,
5859
OpTypePattern("*", name="ft_bias"),
5960
]),
@@ -68,6 +69,39 @@ class REWRITER_RESULT(Enum):
6869
]),
6970
])
7071

72+
xc_pattern_optimized = \
73+
OpTypePattern('Split', inputs=[
74+
OpTypePattern("Const"),
75+
OpTypePattern("Identity", inputs=[
76+
OpTypePattern("MatMul", inputs=[
77+
OpTypePattern("ConcatV2|Concat", name="xh"),
78+
OpTypePattern("Const", name="cell_kernel"),
79+
]),
80+
]),
81+
])
82+
83+
lstmcell_pattern_optimized = \
84+
OpTypePattern('Mul', name='ht', inputs=[
85+
OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern_optimized]),
86+
OpTypePattern('Tanh', inputs=[
87+
OpTypePattern("Add|AddV2", name="ct", inputs=[
88+
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
89+
OpTypePattern("Sigmoid", name="ft", inputs=[
90+
OpTypePattern("Add|AddV2", inputs=[
91+
xc_pattern_optimized,
92+
OpTypePattern("*", name="ft_bias"),
93+
]),
94+
]),
95+
OpTypePattern("*"),
96+
]),
97+
OpTypePattern("Mul", inputs=[
98+
OpTypePattern("Sigmoid", name="it", inputs=[xc_pattern_optimized]),
99+
OpTypePattern("Tanh", name="gt", inputs=[xc_pattern_optimized]),
100+
]),
101+
]),
102+
]),
103+
])
104+
71105
# input sequence: top to down, left to right
72106
# split into update gate and reset gate
73107
gru_split_pattern = \
@@ -237,7 +271,7 @@ class RNNUnitType(Enum):
237271

238272

239273
rnn_cell_patterns = {
240-
RNNUnitType.LSTMCell: [lstmcell_pattern],
274+
RNNUnitType.LSTMCell: [lstmcell_pattern, lstmcell_pattern_optimized],
241275
RNNUnitType.LSTMBlockCell: [lstmblockcell_pattern],
242276
RNNUnitType.GRUCell: [grucell_pattern],
243277
RNNUnitType.GRUBlockCell: [grublockcell_pattern0, grublockcell_pattern1],

0 commit comments

Comments
 (0)