Skip to content

Commit 9e8d372

Browse files
authored
hide attention lstm fuse (#13615)
1 parent e9bc5fa commit 9e8d372

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,22 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
257257
std::unique_ptr<ir::Graph> graph) const {
258258
PDPattern external_pattern, subblock_pattern;
259259

260+
// Use the following variables to tell whether this model is RNN1.
261+
// This fuse can only work on the RNN1 model.
262+
std::unordered_set<std::string> specified_vars({"data_lod_attention",
263+
"cell_init", "hidden_init",
264+
"data", "week", "minute"});
265+
int count = 0;
266+
for (auto* node : graph->Nodes()) {
267+
if (node->IsVar() && specified_vars.count(node->Name())) {
268+
++count;
269+
}
270+
}
271+
if (count < specified_vars.size()) {
272+
return graph;
273+
}
274+
275+
// Continue to fuse.
260276
FindWhileOp(graph.get());
261277
return graph;
262278
}

paddle/fluid/inference/api/paddle_inference_api.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,11 @@ struct AnalysisConfig : public NativeConfig {
212212
kExclude // Specify the disabled passes in `ir_passes`.
213213
};
214214

215+
// Determine whether to perform graph optimization.
215216
bool enable_ir_optim = true;
217+
// Manually determine the IR passes to run.
216218
IrPassMode ir_mode{IrPassMode::kExclude};
217-
// attention lstm fuse works only on some specific models, disable as default.
218-
std::vector<std::string> ir_passes{"attention_lstm_fuse_pass"};
219+
std::vector<std::string> ir_passes;
219220

220221
// NOTE this is just for internal development, please not use it.
221222
bool _use_mkldnn{false};

0 commit comments

Comments
 (0)