Commit 3a6cc57

Fuse multi transformer layer pass (#47541) (#47830)
* add fuse_multi_transformer_layer_pass
1 parent 2e9e65d commit 3a6cc57

10 files changed, +690 -1 lines
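This commit adds a graph pass that merges the chain of fused_multi_transformer ops produced by the fused_multi_transformer_encoder/decoder passes into a single op. A minimal sketch of how the pass could be enabled explicitly through the C++ inference API is shown below; it assumes the standard paddle_infer::Config / pass_builder()->AppendPass() interface, a hypothetical model directory, and that the passes are not already scheduled by the default GPU pass strategy:

#include "paddle_inference_api.h"

int main() {
  // Hypothetical model directory; replace with a real exported model.
  paddle_infer::Config config("./multi_transformer_model");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/1000, /*device_id=*/0);

  // Run the per-layer fusion first, then the layer-merging pass added here
  // (assumption: neither pass is in the default pass list of this build).
  config.pass_builder()->AppendPass("fused_multi_transformer_encoder_pass");
  config.pass_builder()->AppendPass("fuse_multi_transformer_layer_pass");

  auto predictor = paddle_infer::CreatePredictor(config);
  (void)predictor;  // bind inputs and call predictor->Run() as usual
  return 0;
}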

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 5 additions & 0 deletions

@@ -154,6 +154,7 @@ pass_library(skip_layernorm_fuse_pass base)
 pass_library(multihead_matmul_fuse_pass inference)
 pass_library(fused_multi_transformer_encoder_pass inference)
 pass_library(fused_multi_transformer_decoder_pass inference)
+pass_library(fuse_multi_transformer_layer_pass inference)
 pass_library(adaptive_pool2d_convert_global_pass inference)
 pass_library(unsqueeze2_eltwise_fuse_pass inference)
 pass_library(yolo_box_fuse_pass inference)

@@ -368,6 +369,10 @@ cc_test(
   test_fused_multi_transformer_decoder_pass
   SRCS fused_multi_transformer_decoder_pass_tester.cc
   DEPS fused_multi_transformer_decoder_pass)
+cc_test(
+  test_fuse_multi_transformer_layer_pass
+  SRCS fuse_multi_transformer_layer_pass_tester.cc
+  DEPS fuse_multi_transformer_layer_pass)
 cc_test(
   test_conv_bn_fuse_pass_cc
   SRCS conv_bn_fuse_pass_tester.cc
paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc

Lines changed: 325 additions & 0 deletions

@@ -0,0 +1,325 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h"

#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

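// Matches a chain of num_fused_op fused_multi_transformer (or, when
// enable_int8 is set, fused_multi_transformer_int8) ops: every layer shares
// the same SrcMask, and the Out of layer i becomes the X of layer i+1.
// Encoder layers read a CacheKV filled by fill_constant_batch_size_like;
// decoder layers read an extra input computed from SrcMask via shape + slice.
// Returns a map from readable keys such as "fuse_op_0" / "out_0" to the
// PDNode repr names registered in the pattern.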
std::unordered_map<std::string, std::string>
MultiTransformerLayerPattern::operator()(bool enable_int8,
                                         int num_fused_op,
                                         bool is_decoder) {
  std::string fused_multi_transformer_name =
      enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer";

  std::unordered_map<std::string, std::string> node_reprs;

  // x0 and src_mask are the unique inputs of the subgraph
  auto* x0 = pattern->NewNode(x0_repr());
  x0->assert_is_op_input(fused_multi_transformer_name, "X")->AsInput();
  auto* src_mask = pattern->NewNode(src_mask_repr());
  src_mask->assert_is_op_input(fused_multi_transformer_name, "SrcMask")
      ->AsInput();

  for (int i = 0; i < num_fused_op; ++i) {
    auto fuse_op_repr =
        PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i));
    node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr;
    auto* fused_multi_transformer =
        pattern->NewNode(fuse_op_repr)
            ->assert_is_op(fused_multi_transformer_name);

    auto out_repr =
        PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i));
    node_reprs["out_" + std::to_string(i)] = out_repr;
    auto* out = pattern->NewNode(out_repr)->assert_is_op_output(
        fused_multi_transformer_name, "Out");

    if (is_decoder) {
      auto shape_repr =
          PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i));
      node_reprs["shape_" + std::to_string(i)] = shape_repr;
      auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape");

      auto shape_out_repr =
          PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i));
      node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr;
      auto* shape_out =
          pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out");

      shape->LinksFrom({src_mask}).LinksTo({shape_out});

      auto slice_repr =
          PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i));
      node_reprs["slice_" + std::to_string(i)] = slice_repr;
      auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice");

      auto slice_out_repr =
          PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i));
      node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr;
      auto* slice_out =
          pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out");

      slice->LinksFrom({shape_out}).LinksTo({slice_out});

      fused_multi_transformer->LinksFrom({x0, src_mask, slice_out})
          .LinksTo({out});
    } else {
      auto cache_kv_repr =
          PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i));
      node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr;
      auto* cache_kv = pattern->NewNode(cache_kv_repr);
      cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV");
      cache_kv->AsInput();

      auto fill_const_op_repr =
          PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i));
      node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr;
      auto fill_const_op = pattern->NewNode(fill_const_op_repr)
                               ->assert_is_op("fill_constant_batch_size_like");

      fused_multi_transformer->LinksFrom({x0, src_mask, cache_kv})
          .LinksTo({out});
      fill_const_op->LinksFrom({x0}).LinksTo({cache_kv});
    }
    x0 = out;
  }
  x0->AsOutput();
  return node_reprs;
}
}  // namespace patterns

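// MergeInput gathers the variable names bound to one input slot (e.g. "QKVW")
// from every matched op, in layer order, and writes the concatenated list back
// into the given op. MergeAttrs does the same for a std::vector<T> attribute,
// storing the merged vector on the first op.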
inline void MergeInput(OpDesc* op,
                       const std::vector<VariableNameMap>& input_name_maps,
                       const std::string& input_name) {
  std::vector<std::string> tmp = input_name_maps[0].at(input_name);
  for (size_t i = 1; i < input_name_maps.size(); ++i) {
    tmp.insert(tmp.end(),
               input_name_maps[i].at(input_name).begin(),
               input_name_maps[i].at(input_name).end());
  }
  op->SetInput(input_name, tmp);
}

template <typename T>
inline void MergeAttrs(const std::vector<OpDesc*>& ops,
                       const std::string& attr_name) {
  std::vector<T> res;
  for (size_t i = 0; i < ops.size(); ++i) {
    auto scale_vec =
        PADDLE_GET_CONST(std::vector<T>, ops[i]->GetAttr(attr_name));
    res.insert(res.end(), scale_vec.begin(), scale_vec.end());
  }
  ops[0]->SetAttr(attr_name, res);
}

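// BuildFusion reads the layer count that the fused_multi_transformer
// encoder/decoder passes recorded on the graph, matches the corresponding
// chain pattern, merges the per-layer inputs, outputs and links into the
// first fused_multi_transformer op, and removes the remaining ops. It
// returns the number of subgraphs that were fused.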
int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
                                               const std::string& name_scope,
                                               Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // TODO(wufeisheng): Get enable_int8 attr from graph after
  // fused_multi_transformer pass with int8 merged
  bool enable_int8 = false;

  int num_fuse_op = 0;
  bool is_decoder = false;

  if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) {
    num_fuse_op = graph->Get<int>(kFusedMultiTransformerEncoderFusionCount);
    is_decoder = false;
  } else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) {
    num_fuse_op = graph->Get<int>(kFusedMultiTransformerDecoderFusionCount);
    is_decoder = true;
  }
  if (num_fuse_op == 0) {
    VLOG(4) << "fuse_multi_transformer_layer_pass will be skipped "
               "because num_fuse_op has not been set or is set to 0";
    return 0;
  }
  if (!is_decoder) {
    VLOG(4) << "fuse_multi_transformer_layer_pass will match encoder pattern";
  } else {
    VLOG(4) << "fuse_multi_transformer_layer_pass will match decoder pattern";
  }

  patterns::MultiTransformerLayerPattern multi_layer_pattern(pattern,
                                                             name_scope);
  auto node_reprs = multi_layer_pattern(enable_int8, num_fuse_op, is_decoder);

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    ///////////////////
    //// Get nodes ////
    ///////////////////

    GET_IR_NODE_FROM_SUBGRAPH(src_mask, src_mask, multi_layer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x0, x0, multi_layer_pattern);

    std::vector<Node*> fuse_op_nodes;
    std::vector<Node*> out_nodes;

    std::vector<std::string> unused_node_prefixes = {
        "shape_", "shape_out_", "slice_", "slice_out_"};
    std::vector<Node*> unused_nodes;

    std::vector<OpDesc*> fuse_op_descs;
    std::vector<VariableNameMap> fuse_op_input_var_name_maps;
    std::vector<VariableNameMap> fuse_op_output_var_name_maps;

    for (int i = 0; i < num_fuse_op; ++i) {
      PDNode* fuse_op_pdnode =
          multi_layer_pattern.PatternBase::pattern->RetrieveNode(
              node_reprs["fuse_op_" + std::to_string(i)]);
      Node* fuse_op_node = subgraph.at(fuse_op_pdnode);
      fuse_op_nodes.push_back(fuse_op_node);
      fuse_op_descs.push_back(fuse_op_node->Op());
      fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs());
      fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs());

      PDNode* out_pdnode =
          multi_layer_pattern.PatternBase::pattern->RetrieveNode(
              node_reprs["out_" + std::to_string(i)]);
      out_nodes.push_back(subgraph.at(out_pdnode));

      // fill_const op uses x0 as input
      if (!is_decoder && i != 0) {
        PDNode* fill_op_pdnode =
            multi_layer_pattern.PatternBase::pattern->RetrieveNode(
                node_reprs["fill_op_" + std::to_string(i)]);
        Node* fill_op_node = subgraph.at(fill_op_pdnode);
        fill_op_node->Op()->SetInput("Input", {x0->Name()});
        IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node);
        IR_NODE_LINK_TO(x0, fill_op_node);
      } else if (is_decoder && i != 0) {
        for (const auto& unused_node_prefix : unused_node_prefixes) {
          PDNode* unused_pdnode =
              multi_layer_pattern.PatternBase::pattern->RetrieveNode(
                  node_reprs[unused_node_prefix + std::to_string(i)]);
          Node* unused_node = subgraph.at(unused_pdnode);
          unused_nodes.push_back(unused_node);
        }
      }
    }

    ///////////////
    //// Merge ////
    ///////////////

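    // The per-op variable lists for every input slot named below are
    // concatenated in layer order onto the first op, so the surviving
    // fused_multi_transformer op carries the weights, biases and caches of
    // all matched ops; Out is taken from the last op and the CacheKVOut
    // lists are concatenated the same way.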
    // Merge inputs
    std::vector<std::string> inputs_names = {"CacheKV",
                                             "FFN1Bias",
                                             "FFN1Weight",
                                             "FFN2Bias",
                                             "FFN2Weight",
                                             "FFNLnBias",
                                             "FFNLnScale",
                                             "LnBias",
                                             "LnScale",
                                             "OutLinearBias",
                                             "OutLinearW",
                                             "QKVBias",
                                             "QKVW"};

    for (const auto& input_name : inputs_names) {
      MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name);
    }

    // Merge outputs
    fuse_op_descs[0]->SetOutput(
        "Out", fuse_op_output_var_name_maps[num_fuse_op - 1]["Out"]);
    auto& merged_cache_kv_out_names =
        fuse_op_output_var_name_maps[0]["CacheKVOut"];
    for (int i = 1; i < num_fuse_op; ++i) {
      const auto& out_var_names = fuse_op_output_var_name_maps[i]["CacheKVOut"];
      merged_cache_kv_out_names.insert(merged_cache_kv_out_names.end(),
                                       out_var_names.begin(),
                                       out_var_names.end());
    }
    fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names);

    ////////////////
    //// ReLink ////
    ////////////////
    // Before relinking, all out nodes except the last one (indices 0 to
    // num_layer-2) should be removed
    std::unordered_set<const Node*> marked_out_nodes(out_nodes.begin(),
                                                     out_nodes.end() - 1);
    GraphSafeRemoveNodes(graph, marked_out_nodes);

    // Relink all input nodes of fused_multi_transformer ops to the first op
    auto& merged_inputs = fuse_op_nodes[0]->inputs;
    for (int i = 1; i < num_fuse_op; ++i) {
      merged_inputs.insert(merged_inputs.end(),
                           fuse_op_nodes[i]->inputs.begin(),
                           fuse_op_nodes[i]->inputs.end());
    }

    // Relink fuse op -> out
    IR_NODE_UNLINK(fuse_op_nodes[num_fuse_op - 1], out_nodes[num_fuse_op - 1]);
    IR_NODE_LINK_TO(fuse_op_nodes[0], out_nodes[num_fuse_op - 1]);

    /////////////////////////////
    //// Delete unused nodes ////
    /////////////////////////////
    // Delete the fused_multi_transformer ops except for the first one
    std::unordered_set<const Node*> marked_fuse_op_nodes(
        fuse_op_nodes.begin() + 1, fuse_op_nodes.end());

    if (is_decoder) {
      marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end());
    }

    GraphSafeRemoveNodes(graph, marked_fuse_op_nodes);
    ++fusion_count;
  };

  gpd(graph, handler);
  return fusion_count;
}

void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal("During the fuse_multi_transformer_layer pass, "
                              "the scope should not be null."));
  int fusion_count = BuildFusion(graph, name_scope_, scope);

  AddStatis(fusion_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fuse_multi_transformer_layer_pass,
              paddle::framework::ir::FuseMultiTransformerLayerPass);
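
REGISTER_PASS records the pass in the global pass registry under the name fuse_multi_transformer_layer_pass; that registered name is what a pass strategy (for example the AppendPass calls sketched above) uses to construct and run it.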
