
Commit 902f19b

fea/fuse attention lstm simplify. with fusion lstm. with sequence expand (#13006)
1 parent 55f240b commit 902f19b

40 files changed: +1507, -211 lines

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 7 additions & 5 deletions
@@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
-cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
-cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
+cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
+cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
 cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
-
+cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
+cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)

 cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
 cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
-cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
-cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
+cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
+cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"

namespace paddle {
namespace framework {
namespace ir {

struct Param {
  std::string X = "concat_0.tmp_0";
  std::string C0 = "cell_init";
  std::string H0 = "hidden_init";
  std::string AttentionWeight = "attention_fc.w_0";
  std::string AttentionBias = "attention_fc.b_0";
  std::string AttentionScalar = "attention_output.w_0";
  std::string AttentionScalarBias = "attention_output.b_0";
  std::string LSTMWeight = "attention_w.new";
  std::string LSTMBias = "attention_b.new";
  std::string Hidden = "array_to_lod_tensor_0.tmp_0";
  std::string Cell = "at.cell.new";
  std::string AttentionedX = "at.x.new";
  std::string AttentionFCOut = "at.fc.new";
  std::string LSTMX = "at.lstmx.new";
  std::string LSTMOUT = "at.lstmout.new";
};

void PrepareParameters(Graph* graph, const Param& param);

void FindWhileOp(Graph* graph) {
  GraphPatternDetector gpd;
  std::unordered_set<int> fused_external_ops(
      {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
       57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});

  gpd.mutable_pattern()->NewNode(
      [&](Node* n) { return fused_external_ops.count(n->id()); }, "while");

  if (!graph->Has(kGraphvizMarkedNodeAttr)) {
    graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
  }
  auto& marked_nodes =
      graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);

  auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                    Graph* g) {
    auto* while_pat_node = gpd.pattern().RetriveNode("while");
    auto* while_node = subgraph.at(while_pat_node);
    marked_nodes.insert(while_node);
  };
  gpd(graph, handle);

  Param param;
  // Add AttentionLSTM node
  OpDesc op_desc;
  op_desc.SetType("attention_lstm");

#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
  OP_SET_IN(X);
  OP_SET_IN(C0);
  OP_SET_IN(H0);
  OP_SET_IN(AttentionWeight);
  OP_SET_IN(AttentionBias);
  OP_SET_IN(AttentionScalar);
  OP_SET_IN(AttentionScalarBias);
  OP_SET_IN(LSTMWeight);
  OP_SET_IN(LSTMBias);

  OP_SET_OUT(Hidden);
  OP_SET_OUT(Cell);
  OP_SET_OUT(AttentionedX);
  OP_SET_OUT(AttentionFCOut);
  OP_SET_OUT(LSTMX);
  OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT

  auto* X = graph->RetriveNode(34);
  auto* LSTMOUT = graph->RetriveNode(81);
  auto* cell_init = graph->RetriveNode(6);
  auto* hidden_init = graph->RetriveNode(8);

#define LINK_TO(node0, node1)      \
  node0->outputs.push_back(node1); \
  node1->inputs.push_back(node0);

  auto* lstm_op = graph->CreateOpNode(&op_desc);
  PrepareParameters(graph, param);

  LINK_TO(X, lstm_op);
  LINK_TO(cell_init, lstm_op);
  LINK_TO(hidden_init, lstm_op);
  LINK_TO(lstm_op, LSTMOUT);

  GraphSafeRemoveNodes(graph, marked_nodes);
}

#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
  CHECK_P1(x0);          \
  CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
  CHECK_P2(x0, x1);          \
  CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
  CHECK_P3(x0, x1, x2);          \
  CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
  CHECK_P4(x0, x1, x2, x3);          \
  CHECK_P1(x4);

void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
                       const LoDTensor& W_cell_w1, LoDTensor* out);

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out);

void PrepareParameters(Graph* graph, const Param& param) {
  // Check parameters
  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
  auto* scope = graph->Get<Scope*>(kParamScopeAttr);

  // Create new parameters.
  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
  scope->Var(param.Cell)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();

#define GATE_W(name__)                                               \
  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");            \
  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");            \
  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");            \
  CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);       \
  VLOG(4) << #name__ "_w0"                                           \
          << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_w1"                                           \
          << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_b0"                                           \
          << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
  auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>();       \
  auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>();       \
  auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();

  GATE_W(forget);
  GATE_W(input);
  GATE_W(output);
  GATE_W(c);
#undef GATE_W

  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
  auto* attention_output_w = scope->FindVar("attention_output.w_0");
  auto* attention_output_b = scope->FindVar("attention_output.b_0");
  CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
           attention_output_b);

  auto* lstm_weight = scope->Var(param.LSTMWeight);
  auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
  auto* lstm_bias = scope->Var(param.LSTMBias);
  auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();

  // reshape attention_bias
  auto* attention_bias_t =
      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
  attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));

  auto* attention_scalar_bias_t =
      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
  attention_scalar_bias_t->Resize(
      make_ddim({1, attention_scalar_bias_t->dims()[0]}));

  PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
                    W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
                    lstm_weight_t);
  PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
                  lstm_bias_t);
}

// Prepare parameters
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
                       const LoDTensor& W_cell_w1, LoDTensor* out) {
  int D = W_forget_w0.dims()[0];
  int M = W_forget_w1.dims()[0];
  out->Resize(make_ddim({D + M, 4 * D}));
  VLOG(3) << "LSTMWeight resized to " << out->dims();

  float* out_data = out->mutable_data<float>(platform::CPUPlace());
  std::array<const float*, 4> tensors(
      {W_forget_w0.data<float>(), W_input_w0.data<float>(),
       W_output_w0.data<float>(), W_cell_w0.data<float>()});
  std::array<const float*, 4> tensors1(
      {W_forget_w1.data<float>(), W_input_w1.data<float>(),
       W_output_w1.data<float>(), W_cell_w1.data<float>()});

  for (int row = 0; row < D; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * row + D * col;
      const float* src = tensors[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }

  for (int row = 0; row < M; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * (D + row) + D * col;
      const float* src = tensors1[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }
}

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out) {
  std::array<const float*, 4> tensors(
      {B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
       B_cell.data<float>()});

  PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
  int D = B_forget.dims()[0];
  out->Resize(make_ddim({1, 4 * D}));
  auto* out_data = out->mutable_data<float>(platform::CPUPlace());
  for (size_t i = 0; i < tensors.size(); i++) {
    memcpy(out_data + D * i, tensors[i], D * sizeof(float));
  }
}

// Parameters

std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PDPattern external_pattern, subblock_pattern;

  FindWhileOp(graph.get());
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(attention_lstm_fuse_pass,
              paddle::framework::ir::AttentionLSTMFusePass);
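
A note on the fused weight layout built by PrepareLSTMWeight above: the four per-gate blocks (forget, input, output, cell) are packed into one matrix of shape [D + M, 4 * D], with the D rows of each gate's *_w0 block first, the M rows of its *_w1 block below them, and the four gate blocks sitting side by side (D columns each) within every row; PrepareLSTMBias likewise concatenates the four biases into a [1, 4 * D] row. The sketch below reproduces the per-row copy with plain std::vector buffers so it compiles without Paddle; PackGates and its arguments are illustrative names, not part of this commit.

#include <array>
#include <cstdio>
#include <cstring>
#include <vector>

// Copy four row-major [rows, D] gate blocks into one fused [rows, 4 * D]
// matrix: within every row the gate blocks appear back to back, in the same
// forget/input/output/cell order used by PrepareLSTMWeight.
std::vector<float> PackGates(const std::array<std::vector<float>, 4>& gates,
                             int rows, int D) {
  std::vector<float> out(static_cast<size_t>(rows) * 4 * D);
  for (int row = 0; row < rows; ++row) {
    for (int col = 0; col < 4; ++col) {
      std::memcpy(out.data() + 4 * D * row + D * col,
                  gates[col].data() + D * row, D * sizeof(float));
    }
  }
  return out;
}

int main() {
  const int rows = 1, D = 2;
  const std::array<std::vector<float>, 4> gates = {{{1.f, 1.f},    // forget
                                                    {2.f, 2.f},    // input
                                                    {3.f, 3.f},    // output
                                                    {4.f, 4.f}}};  // cell
  const std::vector<float> fused = PackGates(gates, rows, D);
  for (float v : fused) std::printf("%.0f ", v);  // prints: 1 1 2 2 3 3 4 4
  std::printf("\n");
  return 0;
}

Packing every gate into one contiguous buffer is what lets the new attention_lstm op take a single LSTMWeight/LSTMBias pair as input instead of the separate per-gate FC parameters.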
paddle/fluid/framework/ir/attention_lstm_fuse_pass.h

Lines changed: 14 additions & 7 deletions
@@ -1,4 +1,4 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,12 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/inference/analysis/dot.h"
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"

 namespace paddle {
-namespace inference {
-namespace analysis {
-size_t Dot::counter = 0;
-}  // namespace analysis
-}  // namespace inference
+namespace framework {
+namespace ir {
+
+class AttentionLSTMFusePass : public FusePassBase {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 6 additions & 8 deletions
@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
       },
       "elementwise_add_out");

-  pattern->AddEdge(mul_parameter_var, mul_op);
-  pattern->AddEdge(mul_tmp_input_var, mul_op);
-  pattern->AddEdge(mul_op, mul_out_var);
-  pattern->AddEdge(mul_out_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
-  pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
+  mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
+      .LinksTo({mul_out_var});
+  elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
+      .LinksTo({elementwise_add_out_var});
 }

 // Replace the node `from` in the links to `to`
@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(

   std::unordered_set<Node*> nodes2delete;

-  GraphPatternDetecter gpd;
+  GraphPatternDetector gpd;
   BuildFCPattern(gpd.mutable_pattern());

 #define GET_NODE(id) \
@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);

-  auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle FC fuse";
     // Currently, there is no FC op available, so I will just simulate the
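
The first hunk above swaps six explicit pattern->AddEdge(...) calls for the chained LinksFrom/LinksTo helpers on the two operator nodes; both forms declare the same six directed edges. The snippet below is a self-contained mock, not the Paddle PDPattern API: MockNode and the global edge list are hypothetical stand-ins used only to show that the chained form records the same edge set.

// Mock illustration only -- not paddle::framework::ir::PDPattern.
#include <cassert>
#include <initializer_list>
#include <string>
#include <utility>
#include <vector>

struct MockNode;
using Edge = std::pair<MockNode*, MockNode*>;  // (from, to)
static std::vector<Edge> g_edges;  // stands in for the pattern's edge list

struct MockNode {
  std::string name;
  explicit MockNode(std::string n) : name(std::move(n)) {}
  // Record an edge src -> this for every source node.
  MockNode& LinksFrom(std::initializer_list<MockNode*> srcs) {
    for (auto* s : srcs) g_edges.emplace_back(s, this);
    return *this;
  }
  // Record an edge this -> dst for every destination node.
  MockNode& LinksTo(std::initializer_list<MockNode*> dsts) {
    for (auto* d : dsts) g_edges.emplace_back(this, d);
    return *this;
  }
};

int main() {
  MockNode mul_w("mul_w"), mul_x("mul_x"), mul_op("mul"), mul_out("mul_out");
  MockNode add_b("add_b"), add_op("elementwise_add"), add_out("add_out");

  // Chained form, mirroring the updated BuildFCPattern.
  mul_op.LinksFrom({&mul_w, &mul_x}).LinksTo({&mul_out});
  add_op.LinksFrom({&mul_out, &add_b}).LinksTo({&add_out});

  // Exactly the six edges the removed AddEdge calls created one by one.
  assert(g_edges.size() == 6);
  return 0;
}

Chaining keeps each operator's fan-in and fan-out on a single statement, which makes the pattern definition easier to read.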

paddle/fluid/framework/ir/fc_fuse_pass.h

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
