@@ -30,30 +30,31 @@ class REWRITER_RESULT(Enum):
30
30
31
31
32
32
# TensorFlow LSTMCell/BasicLSTMCell computation graph matching
33
- xc_pattern = OpTypePattern ('Split' , inputs = [
34
- OpTypePattern ("Const" ), # axis for split
35
- OpTypePattern ("BiasAdd" , name = "bias_add" , inputs = [
36
- OpTypePattern ("MatMul" , inputs = [
37
- OpTypePattern ("ConcatV2|Concat" , name = "xh" ),
33
+
34
+ xc_pattern = \
35
+ OpTypePattern ('Split' , inputs = [
36
+ OpTypePattern ("Const" ), # axis for split
37
+ OpTypePattern ("BiasAdd" , name = "bias_add" , inputs = [
38
+ OpTypePattern ("MatMul" , inputs = [
39
+ OpTypePattern ("ConcatV2|Concat" , name = "xh" ),
40
+ OpTypePattern ("Enter" , inputs = [
41
+ OpTypePattern ("*" , name = "cell_kernel" ),
42
+ ]),
43
+ ]),
38
44
OpTypePattern ("Enter" , inputs = [
39
- OpTypePattern ("*" , name = "cell_kernel " ),
45
+ OpTypePattern ("*" , name = "cell_bias " ),
40
46
]),
41
47
]),
42
- OpTypePattern ("Enter" , inputs = [
43
- OpTypePattern ("*" , name = "cell_bias" ),
44
- ]),
45
- ]),
46
- ])
47
-
48
+ ])
48
49
49
50
lstmcell_pattern = \
50
51
OpTypePattern ('Mul' , name = 'ht' , inputs = [
51
52
OpTypePattern ("Sigmoid" , name = "ot" , inputs = [xc_pattern ]),
52
53
OpTypePattern ('Tanh' , inputs = [
53
- OpTypePattern ("Add" , name = "ct" , inputs = [
54
+ OpTypePattern ("Add|AddV2 " , name = "ct" , inputs = [
54
55
OpTypePattern ("Mul" , name = "ct_identity_consumer" , inputs = [
55
56
OpTypePattern ("Sigmoid" , name = "ft" , inputs = [
56
- OpTypePattern ("Add" , inputs = [
57
+ OpTypePattern ("Add|AddV2 " , inputs = [
57
58
xc_pattern ,
58
59
OpTypePattern ("*" , name = "ft_bias" ),
59
60
]),
@@ -68,6 +69,39 @@ class REWRITER_RESULT(Enum):
68
69
]),
69
70
])
70
71
72
+ xc_pattern_optimized = \
73
+ OpTypePattern ('Split' , inputs = [
74
+ OpTypePattern ("Const" ),
75
+ OpTypePattern ("Identity" , inputs = [
76
+ OpTypePattern ("MatMul" , inputs = [
77
+ OpTypePattern ("ConcatV2|Concat" , name = "xh" ),
78
+ OpTypePattern ("Const" , name = "cell_kernel" ),
79
+ ]),
80
+ ]),
81
+ ])
82
+
83
+ lstmcell_pattern_optimized = \
84
+ OpTypePattern ('Mul' , name = 'ht' , inputs = [
85
+ OpTypePattern ("Sigmoid" , name = "ot" , inputs = [xc_pattern_optimized ]),
86
+ OpTypePattern ('Tanh' , inputs = [
87
+ OpTypePattern ("Add|AddV2" , name = "ct" , inputs = [
88
+ OpTypePattern ("Mul" , name = "ct_identity_consumer" , inputs = [
89
+ OpTypePattern ("Sigmoid" , name = "ft" , inputs = [
90
+ OpTypePattern ("Add|AddV2" , inputs = [
91
+ xc_pattern_optimized ,
92
+ OpTypePattern ("*" , name = "ft_bias" ),
93
+ ]),
94
+ ]),
95
+ OpTypePattern ("*" ),
96
+ ]),
97
+ OpTypePattern ("Mul" , inputs = [
98
+ OpTypePattern ("Sigmoid" , name = "it" , inputs = [xc_pattern_optimized ]),
99
+ OpTypePattern ("Tanh" , name = "gt" , inputs = [xc_pattern_optimized ]),
100
+ ]),
101
+ ]),
102
+ ]),
103
+ ])
104
+
71
105
# input sequence: top to down, left to right
72
106
# split into update gate and reset gate
73
107
gru_split_pattern = \
@@ -237,7 +271,7 @@ class RNNUnitType(Enum):
237
271
238
272
239
273
rnn_cell_patterns = {
240
- RNNUnitType .LSTMCell : [lstmcell_pattern ],
274
+ RNNUnitType .LSTMCell : [lstmcell_pattern , lstmcell_pattern_optimized ],
241
275
RNNUnitType .LSTMBlockCell : [lstmblockcell_pattern ],
242
276
RNNUnitType .GRUCell : [grucell_pattern ],
243
277
RNNUnitType .GRUBlockCell : [grublockcell_pattern0 , grublockcell_pattern1 ],
0 commit comments