
Commit 6945a80

cherry-pick 22551. test=develop test=release/1.7 (#22609)
[cherry-pick] #22551 When a model contains multiple fc_lstm subgraphs whose fc ops share the same persistable bias, the bias node must not be deleted; only the non-persistable nodes should be removed.
1 parent a06883c commit 6945a80
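For context, the fix amounts to one rule: when a fuse pass cleans up a matched subgraph, it may only delete intermediate, non-persistable nodes; a persistable parameter such as a bias shared by several fc ops must stay in the graph, because other subgraphs still reference it. A minimal sketch of that rule, assuming Paddle's ir::Node API and a hypothetical CollectRemovableNodes helper (not part of this commit), could look like this:

    #include <unordered_set>
    #include <vector>
    #include "paddle/fluid/framework/ir/node.h"

    namespace paddle {
    namespace framework {
    namespace ir {

    // Sketch only: gather fuse-pass cleanup candidates while keeping
    // persistable variables (e.g. a bias shared across fc_lstm subgraphs).
    std::unordered_set<const Node*> CollectRemovableNodes(
        const std::vector<Node*>& candidates) {
      std::unordered_set<const Node*> marked_nodes;
      for (Node* n : candidates) {
        // Persistable vars are model parameters; deleting one would break
        // every other subgraph that shares it, so skip them.
        if (n->IsVar() && n->Var() != nullptr && n->Var()->Persistable()) {
          continue;
        }
        marked_nodes.insert(n);
      }
      return marked_nodes;
    }

    }  // namespace ir
    }  // namespace framework
    }  // namespace paddle

In the passes changed below, the same effect is obtained more directly: fc_bias is simply left out of marked_nodes, and only non-persistable intermediate outputs (mul_out, BatchGate, BatchResetHiddenPrev, BatchHidden, BatchCellPreAct) are handed to GraphSafeRemoveNodes.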

File tree

6 files changed: +301 −6 lines

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,8 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
 cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
 cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
+cc_test(test_fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass_tester.cc DEPS fc_lstm_fuse_pass framework_proto)
+cc_test(test_fc_gru_fuse_pass SRCS fc_gru_fuse_pass_tester.cc DEPS fc_gru_fuse_pass framework_proto)
 cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
 cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
 cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto)

paddle/fluid/framework/ir/fc_gru_fuse_pass.cc

Lines changed: 4 additions & 3 deletions
@@ -127,8 +127,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern);
     // nodes need be removed
     GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchResetHiddenPrev,
+                              gru_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchHidden, gru_pattern);

     if (with_fc_bias) {
       GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
@@ -138,7 +139,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
       gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
       // Remove unneeded nodes.
       std::unordered_set<const Node*> marked_nodes(
-          {mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate,
+          {mul, gru, elementwise_add, fc_out, mul_out, BatchGate,
            BatchResetHiddenPrev, BatchHidden});
       GraphSafeRemoveNodes(graph, marked_nodes);
     } else {

paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.cc

Lines changed: 116 additions & 0 deletions (new file)
@@ -0,0 +1,116 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void AddVarToScope(Scope* param_scope, const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(platform::CPUPlace());
+}
+
+Scope* CreateParamScope() {
+  auto param_scope = new Scope();
+  AddVarToScope(param_scope, "gru_fc_w", {});
+  AddVarToScope(param_scope, "gru_fc_b", {});
+  AddVarToScope(param_scope, "gru_w", {});
+  AddVarToScope(param_scope, "gru_b", {});
+  AddVarToScope(param_scope, "gru_batch_gate_0", {});
+  AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_0", {});
+  AddVarToScope(param_scope, "gru_batch_hidden_0", {});
+  AddVarToScope(param_scope, "gru_hidden_0", {});
+  AddVarToScope(param_scope, "gru_batch_gate_1", {});
+  AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_1", {});
+  AddVarToScope(param_scope, "gru_batch_hidden_1", {});
+  AddVarToScope(param_scope, "gru_hidden_1", {});
+  return param_scope;
+}
+
+TEST(FCFusePass, basic) {
+  // inputs                      operator           output
+  // --------------------------------------------------------
+  // (a, gru_fc_w)               mul             -> fc_0_tmp_0
+  // (fc_0_tmp_0, gru_fc_b)      elementwise_add -> fc_0_tmp_1
+  // (fc_0_tmp_1, gru_w, gru_b)  gru             -> gru_out_0
+
+  // (b, gru_fc_w)               mul             -> fc_1_tmp_0
+  // (fc_1_tmp_0, gru_fc_b)      elementwise_add -> fc_1_tmp_1
+  // (fc_1_tmp_1, gru_w, gru_b)  gru             -> gru_out_1
+  Layers layers;
+  auto* a = layers.data("a");
+  auto* b = layers.data("b");
+  auto* fc_w = layers.data("gru_fc_w", {}, true);
+  auto* fc_b = layers.data("gru_fc_b", {}, true);
+  auto* gru_w = layers.data("gru_w", {}, true);
+  auto* gru_b = layers.data("gru_b", {}, true);
+  auto* fc_0_tmp0 = layers.mul(a, fc_w);
+  auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
+  auto* gru_batch_gate_0 = layers.data("gru_batch_gate_0", {}, false);
+  auto* gru_batch_reset_hidden_prev_0 =
+      layers.data("gru_batch_reset_hidden_prev_0", {}, false);
+  auto* gru_batch_hidden_0 = layers.data("gru_batch_hidden_0", {}, false);
+  auto* gru_hidden_0 = layers.data("gru_hidden_0", {}, false);
+  layers.gru(fc_0_tmp1, gru_w, gru_b, gru_batch_gate_0,
+             gru_batch_reset_hidden_prev_0, gru_batch_hidden_0, gru_hidden_0);
+
+  auto* fc_1_tmp0 = layers.mul(b, fc_w);
+  auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
+  auto* gru_batch_gate_1 = layers.data("gru_batch_gate_1", {}, false);
+  auto* gru_batch_reset_hidden_prev_1 =
+      layers.data("gru_batch_reset_hidden_prev_1", {}, false);
+  auto* gru_batch_hidden_1 = layers.data("gru_batch_hidden_1", {}, false);
+  auto* gru_hidden_1 = layers.data("gru_hidden_1", {}, false);
+  layers.gru(fc_1_tmp1, gru_w, gru_b, gru_batch_gate_1,
+             gru_batch_reset_hidden_prev_1, gru_batch_hidden_1, gru_hidden_1);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("fc_gru_fuse_pass");
+  pass->Set("use_gpu", new bool(true));
+  graph->Set("__param_scope__", CreateParamScope());
+  int num_nodes_before = graph->Nodes().size();
+  int num_gru_nodes_before = GetNumOpNodes(graph, "gru");
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  int num_fuse_gru_nodes_after = GetNumOpNodes(graph, "fusion_gru");
+  VLOG(3) << DebugString(graph);
+
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6,
+                    platform::errors::PreconditionNotMet(
+                        "The number of nodes before and after "
+                        "the fuse does not meet expectations"));
+  PADDLE_ENFORCE_EQ(
+      num_fuse_gru_nodes_after, 2,
+      platform::errors::PreconditionNotMet("The number of gru nodes before the "
+                                           "fuse does not meet expectations"));
+  PADDLE_ENFORCE_EQ(num_gru_nodes_before, num_fuse_gru_nodes_after,
+                    platform::errors::PreconditionNotMet(
+                        "The number of fusion_gru nodes does not meet "
+                        "expectations after fuse"));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(fc_gru_fuse_pass);

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 7 additions & 3 deletions
@@ -133,26 +133,30 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
     GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchCellPreAct, BatchCellPreAct, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
     if (with_fc_bias) {
       GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
       lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
                    fc_bias);
       // Remove unneeded nodes.
       std::unordered_set<const Node*> marked_nodes(
-          {mul, lstm, elementwise_add, fc_bias});
+          {mul, lstm, elementwise_add, mul_out, BatchGate, BatchCellPreAct});
       GraphSafeRemoveNodes(graph, marked_nodes);
     } else {
       GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
       lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
                    nullptr);
       // Remove unneeded nodes.
-      std::unordered_set<const Node*> marked_nodes({mul, lstm});
+      std::unordered_set<const Node*> marked_nodes(
+          {mul, lstm, BatchGate, BatchCellPreAct});
       GraphSafeRemoveNodes(graph, marked_nodes);
     }

paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.cc

Lines changed: 116 additions & 0 deletions (new file)
@@ -0,0 +1,116 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void AddVarToScope(Scope* param_scope, const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(platform::CPUPlace());
+}
+
+Scope* CreateParamScope() {
+  auto param_scope = new Scope();
+  AddVarToScope(param_scope, "lstm_fc_w", {});
+  AddVarToScope(param_scope, "lstm_fc_b", {});
+  AddVarToScope(param_scope, "lstm_w", {});
+  AddVarToScope(param_scope, "lstm_b", {});
+  AddVarToScope(param_scope, "lstm_cell_0", {});
+  AddVarToScope(param_scope, "lstm_batch_gate_0", {});
+  AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_0", {});
+  AddVarToScope(param_scope, "lstm_hidden_0", {});
+  AddVarToScope(param_scope, "lstm_cell_1", {});
+  AddVarToScope(param_scope, "lstm_batch_gate_1", {});
+  AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_1", {});
+  AddVarToScope(param_scope, "lstm_hidden_1", {});
+  return param_scope;
+}
+
+TEST(FCLSTMFusePass, basic) {
+  // inputs                        operator           output
+  // --------------------------------------------------------
+  // (a, lstm_fc_w)                mul             -> fc_0_tmp_0
+  // (fc_0_tmp_0, lstm_fc_b)       elementwise_add -> fc_0_tmp_1
+  // (fc_0_tmp_1, lstm_w, lstm_b)  lstm            -> lstm_out_0
+
+  // (b, lstm_fc_w)                mul             -> fc_1_tmp_0
+  // (fc_1_tmp_0, lstm_fc_b)       elementwise_add -> fc_1_tmp_1
+  // (fc_1_tmp_1, lstm_w, lstm_b)  lstm            -> lstm_out_1
+  Layers layers;
+  auto* a = layers.data("a");
+  auto* b = layers.data("b");
+  auto* fc_w = layers.data("lstm_fc_w", {}, true);
+  auto* fc_b = layers.data("lstm_fc_b", {}, true);
+  auto* lstm_w = layers.data("lstm_w", {}, true);
+  auto* lstm_b = layers.data("lstm_b", {}, true);
+  auto* fc_0_tmp0 = layers.mul(a, fc_w);
+  auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
+  auto* lstm_cell_0 = layers.data("lstm_cell_0", {}, false);
+  auto* lstm_batch_gate_0 = layers.data("lstm_batch_gate_0", {}, false);
+  auto* lstm_batch_cell_pre_gate_0 =
+      layers.data("lstm_batch_cell_pre_gate_0", {}, false);
+  auto* lstm_hidden_0 = layers.data("lstm_hidden_0", {}, false);
+  layers.lstm(fc_0_tmp1, lstm_w, lstm_b, lstm_cell_0, lstm_batch_gate_0,
+              lstm_hidden_0, lstm_batch_cell_pre_gate_0);
+
+  auto* fc_1_tmp0 = layers.mul(b, fc_w);
+  auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
+  auto* lstm_cell_1 = layers.data("lstm_cell_1", {}, false);
+  auto* lstm_batch_gate_1 = layers.data("lstm_batch_gate_1", {}, false);
+  auto* lstm_batch_cell_pre_gate_1 =
+      layers.data("lstm_batch_cell_pre_gate_1", {}, false);
+  auto* lstm_hidden_1 = layers.data("lstm_hidden_1", {}, false);
+  layers.lstm(fc_1_tmp1, lstm_w, lstm_b, lstm_cell_1, lstm_batch_gate_1,
+              lstm_hidden_1, lstm_batch_cell_pre_gate_1);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("fc_lstm_fuse_pass");
+  pass->Set("use_gpu", new bool(false));
+  graph->Set("__param_scope__", CreateParamScope());
+  int num_nodes_before = graph->Nodes().size();
+  int num_lstm_nodes_before = GetNumOpNodes(graph, "lstm");
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  int num_fusion_lstm_nodes_after = GetNumOpNodes(graph, "fusion_lstm");
+  VLOG(3) << DebugString(graph);
+
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after - 6,
+                    platform::errors::PreconditionNotMet(
+                        "The number of nodes before and after "
+                        "the fuse does not meet expectations"));
+  PADDLE_ENFORCE_EQ(num_fusion_lstm_nodes_after, 2,
+                    platform::errors::PreconditionNotMet(
+                        "The number of lstm nodes before the "
+                        "fuse does not meet expectations"));
+  PADDLE_ENFORCE_EQ(num_lstm_nodes_before, num_fusion_lstm_nodes_after,
+                    platform::errors::PreconditionNotMet(
+                        "The number of fusion_gru nodes does "
+                        "not meet expectations after fuse"));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(fc_lstm_fuse_pass);

paddle/fluid/framework/ir/pass_tester_helper.h

Lines changed: 56 additions & 0 deletions
@@ -120,6 +120,62 @@ struct Layers {
     return out;
   }

+  void lstm(VarDesc* input, VarDesc* w, VarDesc* bias, VarDesc* cell,
+            VarDesc* batch_gate, VarDesc* hidden, VarDesc* batch_cell_pre_act,
+            VarDesc* h0 = nullptr, VarDesc* c0 = nullptr,
+            bool use_peepholes = true, bool is_reverse = false,
+            std::string gate_activation = "sigmoid",
+            std::string cell_activation = "tanh",
+            std::string candidate_activation = "tanh") {
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("lstm");
+    op->SetInput("Input", {input->Name()});
+    op->SetInput("Weight", {w->Name()});
+    op->SetInput("Bias", {bias->Name()});
+    if (h0) {
+      op->SetInput("H0", {h0->Name()});
+    }
+    if (c0) {
+      op->SetInput("C0", {c0->Name()});
+    }
+    op->SetOutput("Hidden", {hidden->Name()});
+    op->SetOutput("Cell", {cell->Name()});
+    op->SetOutput("BatchGate", {batch_gate->Name()});
+    op->SetOutput("BatchCellPreAct", {batch_cell_pre_act->Name()});
+    op->SetAttr("use_peepholes", use_peepholes);
+    op->SetAttr("is_reverse", is_reverse);
+    op->SetAttr("gate_activation", gate_activation);
+    op->SetAttr("cell_activation", cell_activation);
+    op->SetAttr("candidate_activation", candidate_activation);
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+  }
+
+  void gru(VarDesc* input, VarDesc* w, VarDesc* bias, VarDesc* batch_gate,
+           VarDesc* batch_reset_hidden_prev, VarDesc* batch_hidden,
+           VarDesc* hidden, VarDesc* h0 = nullptr, bool origin_mode = false,
+           bool is_reverse = false, std::string activation = "tanh",
+           std::string gate_activation = "sigmoid") {
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("gru");
+    op->SetInput("Input", {input->Name()});
+    op->SetInput("Weight", {w->Name()});
+    op->SetInput("Bias", {bias->Name()});
+    if (h0) {
+      op->SetInput("H0", {h0->Name()});
+    }
+    op->SetOutput("BatchGate", {batch_gate->Name()});
+    op->SetOutput("BatchResetHiddenPrev", {batch_reset_hidden_prev->Name()});
+    op->SetOutput("BatchHidden", {batch_hidden->Name()});
+    op->SetOutput("Hidden", {hidden->Name()});
+    op->SetAttr("origin_mode", origin_mode);
+    op->SetAttr("is_reverse", is_reverse);
+    op->SetAttr("activation", activation);
+    op->SetAttr("gate_activation", gate_activation);
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+  }
+
   VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
                int x_num_col_dims = 1) {
     AttributeMap attrs;
