@@ -31,7 +31,8 @@ namespace framework {
31
31
namespace ir {
32
32
namespace patterns {
33
33
34
- static PDNode* create_emb_vars (PDPattern* pattern, const std::string& name,
34
+ static PDNode* create_emb_vars (PDPattern* pattern,
35
+ const std::string& name,
35
36
const std::string& arg,
36
37
bool is_persist = false ) {
37
38
std::unordered_set<std::string> embedding_ops{" lookup_table" ,
@@ -41,7 +42,8 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
41
42
if (is_persist) return node->assert_is_persistable_var ();
42
43
return node;
43
44
}
44
- static PDNode* create_emb_out_vars (PDPattern* pattern, const std::string& name,
45
+ static PDNode* create_emb_out_vars (PDPattern* pattern,
46
+ const std::string& name,
45
47
const std::string& arg) {
46
48
std::unordered_set<std::string> embedding_ops{" lookup_table" ,
47
49
" lookup_table_v2" };
@@ -62,6 +64,9 @@ void Embedding2Eltwise1Pattern::operator()() {
62
64
create_emb_vars (pattern, lookup_table2_w_repr (), " W" , true );
63
65
std::unordered_set<std::string> embedding_ops{" lookup_table" ,
64
66
" lookup_table_v2" };
67
+
68
+ auto * feed1 = pattern->NewNode (feed1_repr ())->assert_is_op (" feed" );
69
+ auto * feed2 = pattern->NewNode (feed2_repr ())->assert_is_op (" feed" );
65
70
auto * lookup_table1 =
66
71
pattern->NewNode (lookup_table1_repr ())->assert_is_ops (embedding_ops);
67
72
auto * lookup_table2 =
@@ -74,8 +79,10 @@ void Embedding2Eltwise1Pattern::operator()() {
74
79
pattern->NewNode (eltwise_add_repr ())->assert_is_op (" elementwise_add" );
75
80
auto * eltwise_add_out = pattern->NewNode (eltwise_add_out_repr ())
76
81
->assert_is_op_output (" elementwise_add" );
82
+ feed1->LinksTo ({lookup_table1_x});
77
83
lookup_table1->LinksFrom ({lookup_table1_x, lookup_table1_w})
78
84
.LinksTo ({lookup_table1_out});
85
+ feed2->LinksTo ({lookup_table2_x});
79
86
lookup_table2->LinksFrom ({lookup_table2_x, lookup_table2_w})
80
87
.LinksTo ({lookup_table2_out});
81
88
eltwise_add->LinksFrom ({lookup_table1_out, lookup_table2_out})
@@ -88,6 +95,7 @@ void Embedding1Eltwise1Pattern::operator()() {
88
95
create_emb_vars (pattern, lookup_table1_w_repr (), " W" , true );
89
96
std::unordered_set<std::string> embedding_ops{" lookup_table" ,
90
97
" lookup_table_v2" };
98
+ auto * feed1 = pattern->NewNode (feed1_repr ())->assert_is_op (" feed" );
91
99
auto * lookup_table1 =
92
100
pattern->NewNode (lookup_table1_repr ())->assert_is_ops (embedding_ops);
93
101
auto * lookup_table1_out =
@@ -99,6 +107,7 @@ void Embedding1Eltwise1Pattern::operator()() {
99
107
->assert_is_op_output (" elementwise_add" );
100
108
auto * eltwise_add_out = pattern->NewNode (eltwise_add_out_repr ())
101
109
->assert_is_op_output (" elementwise_add" );
110
+ feed1->LinksTo ({lookup_table1_x});
102
111
lookup_table1->LinksFrom ({lookup_table1_x, lookup_table1_w})
103
112
.LinksTo ({lookup_table1_out});
104
113
eltwise_add->LinksFrom ({lookup_table1_out, eltwise_add_in})
@@ -161,10 +170,10 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
161
170
GET_IR_NODE_FROM_SUBGRAPH (lookup_table2_w, lookup_table2_w, start_pattern);
162
171
GET_IR_NODE_FROM_SUBGRAPH (lookup_table1, lookup_table1, start_pattern);
163
172
GET_IR_NODE_FROM_SUBGRAPH (lookup_table2, lookup_table2, start_pattern);
164
- GET_IR_NODE_FROM_SUBGRAPH (lookup_table1_out, lookup_table1_out,
165
- start_pattern);
166
- GET_IR_NODE_FROM_SUBGRAPH (lookup_table2_out, lookup_table2_out,
167
- start_pattern);
173
+ GET_IR_NODE_FROM_SUBGRAPH (
174
+ lookup_table1_out, lookup_table1_out, start_pattern);
175
+ GET_IR_NODE_FROM_SUBGRAPH (
176
+ lookup_table2_out, lookup_table2_out, start_pattern);
168
177
GET_IR_NODE_FROM_SUBGRAPH (eltwise_add, eltwise_add, start_pattern);
169
178
GET_IR_NODE_FROM_SUBGRAPH (eltwise_add_out, eltwise_add_out, start_pattern);
170
179
if (!IsCompat (subgraph, graph)) {
@@ -178,8 +187,12 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
178
187
start_pattern_out_node.push_back (eltwise_add_out);
179
188
180
189
std::unordered_set<Node*> rm_nodes;
181
- rm_nodes.insert ({lookup_table1, lookup_table2, lookup_table1_out,
182
- lookup_table2_out, eltwise_add, eltwise_add_out});
190
+ rm_nodes.insert ({lookup_table1,
191
+ lookup_table2,
192
+ lookup_table1_out,
193
+ lookup_table2_out,
194
+ eltwise_add,
195
+ eltwise_add_out});
183
196
start_pattern_remove_nodes.push_back (rm_nodes);
184
197
};
185
198
gpd (graph, handler);
@@ -199,8 +212,8 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
199
212
GET_IR_NODE_FROM_SUBGRAPH (lookup_table1_x, lookup_table1_x, second_pattern);
200
213
GET_IR_NODE_FROM_SUBGRAPH (lookup_table1_w, lookup_table1_w, second_pattern);
201
214
GET_IR_NODE_FROM_SUBGRAPH (lookup_table1, lookup_table1, second_pattern);
202
- GET_IR_NODE_FROM_SUBGRAPH (lookup_table1_out, lookup_table1_out,
203
- second_pattern);
215
+ GET_IR_NODE_FROM_SUBGRAPH (
216
+ lookup_table1_out, lookup_table1_out, second_pattern);
204
217
GET_IR_NODE_FROM_SUBGRAPH (eltwise_add_in, eltwise_add_in, second_pattern);
205
218
GET_IR_NODE_FROM_SUBGRAPH (eltwise_add, eltwise_add, second_pattern);
206
219
GET_IR_NODE_FROM_SUBGRAPH (eltwise_add_out, eltwise_add_out, second_pattern);
@@ -234,19 +247,19 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
234
247
auto handler3 = [&](const GraphPatternDetector::subgraph_t & subgraph,
235
248
Graph* g) {
236
249
GET_IR_NODE_FROM_SUBGRAPH (eltwise_add, eltwise_add, skip_layernorm_pattern);
237
- GET_IR_NODE_FROM_SUBGRAPH (eltwise_add_out, eltwise_add_out,
238
- skip_layernorm_pattern);
250
+ GET_IR_NODE_FROM_SUBGRAPH (
251
+ eltwise_add_out, eltwise_add_out, skip_layernorm_pattern);
239
252
GET_IR_NODE_FROM_SUBGRAPH (layer_norm, layer_norm, skip_layernorm_pattern);
240
- GET_IR_NODE_FROM_SUBGRAPH (layer_norm_out, layer_norm_out,
241
- skip_layernorm_pattern);
242
- GET_IR_NODE_FROM_SUBGRAPH (layer_norm_bias, layer_norm_bias,
243
- skip_layernorm_pattern);
244
- GET_IR_NODE_FROM_SUBGRAPH (layer_norm_scale, layer_norm_scale,
245
- skip_layernorm_pattern);
246
- GET_IR_NODE_FROM_SUBGRAPH (layer_norm_mean, layer_norm_mean,
247
- skip_layernorm_pattern);
248
- GET_IR_NODE_FROM_SUBGRAPH (layer_norm_variance, layer_norm_variance,
249
- skip_layernorm_pattern);
253
+ GET_IR_NODE_FROM_SUBGRAPH (
254
+ layer_norm_out, layer_norm_out, skip_layernorm_pattern);
255
+ GET_IR_NODE_FROM_SUBGRAPH (
256
+ layer_norm_bias, layer_norm_bias, skip_layernorm_pattern);
257
+ GET_IR_NODE_FROM_SUBGRAPH (
258
+ layer_norm_scale, layer_norm_scale, skip_layernorm_pattern);
259
+ GET_IR_NODE_FROM_SUBGRAPH (
260
+ layer_norm_mean, layer_norm_mean, skip_layernorm_pattern);
261
+ GET_IR_NODE_FROM_SUBGRAPH (
262
+ layer_norm_variance, layer_norm_variance, skip_layernorm_pattern);
250
263
if (!IsCompat (subgraph, graph)) {
251
264
LOG (WARNING) << " Pass(SkipLayerNorm) in op compat failed." ;
252
265
return ;
0 commit comments