Skip to content

Commit ae67dce

Browse files
authored
Merge pull request #13366 from luotao1/fusion_lstm_bug
fix fusion_lstm unique_name bug
2 parents f6cbe10 + b12322c commit ae67dce

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
5151
if (with_fc_bias) {
5252
// Add FC-bias with LSTM-bias and create a new weight
5353
PADDLE_ENFORCE(scope);
54-
const std::string& new_bias_var = name_scope + "_bias.new";
54+
const std::string& new_bias_var = patterns::UniqueKey("NewBias");
5555
auto* bias_var = scope->Var(new_bias_var);
5656
PADDLE_ENFORCE(bias_var);
5757
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
120120

121121
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
122122
Graph* g) {
123-
124123
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
125124
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
126125
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
136135
fc_bias);
137136
// Remove unneeded nodes.
138137
std::unordered_set<const Node*> marked_nodes(
139-
{mul, lstm, elementwise_add});
138+
{mul, lstm, elementwise_add, fc_bias});
140139
GraphSafeRemoveNodes(graph, marked_nodes);
141140
} else {
142141
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);

paddle/fluid/inference/analysis/ir_pass_manager.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
1616
#include <string>
17+
#include <vector>
1718
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
1819
#include "paddle/fluid/framework/ir/graph.h"
1920
#include "paddle/fluid/framework/scope.h"
@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
3738
void IRPassManager::Apply(const std::vector<std::string> &passes) {
3839
// Apply all the passes
3940
std::string pre_pass;
41+
int pass_num = 0;
4042
for (const std::string &pass_name : passes) {
4143
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name);
4244
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
4345
if (pass_name == "graph_viz_pass") {
44-
std::string dot_file_path =
45-
"ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot";
46+
std::string dot_file_path = std::to_string(pass_num) + "_ir_" +
47+
(pre_pass.empty() ? "origin" : pre_pass) +
48+
".dot";
4649
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
50+
pass_num++;
4751
}
4852
graph_ = pass->Apply(std::move(graph_));
4953
pre_pass = pass_name;

0 commit comments

Comments
 (0)