Skip to content

Commit 7100133

Browse files
committed
add pattern for lstm in 1.15.0
1 parent 59fed17 commit 7100133

File tree

1 file changed

+47
-13
lines changed

1 file changed

+47
-13
lines changed

tf2onnx/rewriter/rnn_utils.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,22 @@ class REWRITER_RESULT(Enum):
3030

3131

3232
# TensorFlow LSTMCell/BasicLSTMCell computation graph matching
33-
xc_pattern = OpTypePattern('Split', inputs=[
34-
OpTypePattern("Const"), # axis for split
35-
OpTypePattern("BiasAdd", name="bias_add", inputs=[
36-
OpTypePattern("MatMul", inputs=[
37-
OpTypePattern("ConcatV2|Concat", name="xh"),
33+
34+
xc_pattern = \
35+
OpTypePattern('Split', inputs=[
36+
OpTypePattern("Const"), # axis for split
37+
OpTypePattern("BiasAdd", name="bias_add", inputs=[
38+
OpTypePattern("MatMul", inputs=[
39+
OpTypePattern("ConcatV2|Concat", name="xh"),
40+
OpTypePattern("Enter", inputs=[
41+
OpTypePattern("*", name="cell_kernel"),
42+
]),
43+
]),
3844
OpTypePattern("Enter", inputs=[
39-
OpTypePattern("*", name="cell_kernel"),
45+
OpTypePattern("*", name="cell_bias"),
4046
]),
4147
]),
42-
OpTypePattern("Enter", inputs=[
43-
OpTypePattern("*", name="cell_bias"),
44-
]),
45-
]),
46-
])
47-
48+
])
4849

4950
lstmcell_pattern = \
5051
OpTypePattern('Mul', name='ht', inputs=[
@@ -68,6 +69,39 @@ class REWRITER_RESULT(Enum):
6869
]),
6970
])
7071

72+
xc_pattern_optimized = \
73+
OpTypePattern('Split', inputs=[
74+
OpTypePattern("Const"),
75+
OpTypePattern("Identity", inputs=[
76+
OpTypePattern("MatMul", inputs=[
77+
OpTypePattern("ConcatV2|Concat", name="xh"),
78+
OpTypePattern("Const", name="cell_kernel"),
79+
]),
80+
]),
81+
])
82+
83+
lstmcell_pattern_optimized = \
84+
OpTypePattern('Mul', name='ht', inputs=[
85+
OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern_optimized]),
86+
OpTypePattern('Tanh', inputs=[
87+
OpTypePattern("Add|AddV2", name="ct", inputs=[
88+
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
89+
OpTypePattern("Sigmoid", name="ft", inputs=[
90+
OpTypePattern("Add|AddV2", inputs=[
91+
xc_pattern_optimized,
92+
OpTypePattern("*", name="ft_bias"),
93+
]),
94+
]),
95+
OpTypePattern("*"),
96+
]),
97+
OpTypePattern("Mul", inputs=[
98+
OpTypePattern("Sigmoid", name="it", inputs=[xc_pattern_optimized]),
99+
OpTypePattern("Tanh", name="gt", inputs=[xc_pattern_optimized]),
100+
]),
101+
]),
102+
]),
103+
])
104+
71105
# input sequence: top to down, left to right
72106
# split into update gate and reset gate
73107
gru_split_pattern = \
@@ -237,7 +271,7 @@ class RNNUnitType(Enum):
237271

238272

239273
rnn_cell_patterns = {
240-
RNNUnitType.LSTMCell: [lstmcell_pattern],
274+
RNNUnitType.LSTMCell: [lstmcell_pattern, lstmcell_pattern_optimized],
241275
RNNUnitType.LSTMBlockCell: [lstmblockcell_pattern],
242276
RNNUnitType.GRUCell: [grucell_pattern],
243277
RNNUnitType.GRUBlockCell: [grublockcell_pattern0, grublockcell_pattern1],

0 commit comments

Comments
 (0)