File tree Expand file tree Collapse file tree 2 files changed +19
-2
lines changed Expand file tree Collapse file tree 2 files changed +19
-2
lines changed Original file line number Diff line number Diff line change @@ -257,6 +257,22 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
257
257
std::unique_ptr<ir::Graph> graph) const {
258
258
PDPattern external_pattern, subblock_pattern;
259
259
260
+ // Use the following variables to tell whether this model is RNN1.
261
+ // This fuse can only works on the RNN1 model.
262
+ std::unordered_set<std::string> specified_vars ({" data_lod_attention" ,
263
+ " cell_init" , " hidden_init" ,
264
+ " data" , " week" , " minute" });
265
+ int count = 0 ;
266
+ for (auto * node : graph->Nodes ()) {
267
+ if (node->IsVar () && specified_vars.count (node->Name ())) {
268
+ ++count;
269
+ }
270
+ }
271
+ if (count < specified_vars.size ()) {
272
+ return graph;
273
+ }
274
+
275
+ // Continue to fuse.
260
276
FindWhileOp (graph.get ());
261
277
return graph;
262
278
}
Original file line number Diff line number Diff line change @@ -212,10 +212,11 @@ struct AnalysisConfig : public NativeConfig {
212
212
kExclude // Specify the disabled passes in `ir_passes`.
213
213
};
214
214
215
+ // Determine whether to perform graph optimization.
215
216
bool enable_ir_optim = true ;
217
+ // Manually determine the IR passes to run.
216
218
IrPassMode ir_mode{IrPassMode::kExclude };
217
- // attention lstm fuse works only on some specific models, disable as default.
218
- std::vector<std::string> ir_passes{" attention_lstm_fuse_pass" };
219
+ std::vector<std::string> ir_passes;
219
220
220
221
// NOTE this is just for internal development, please not use it.
221
222
bool _use_mkldnn{false };
You can’t perform that action at this time.
0 commit comments