Skip to content

Commit c79acd9

Browse files
Pei Yang and NHZlX authored
[Fix BUGs]: fix multihead matmul pass's unstable bug (#25123) (#25975)
* fix multihead matmul's instability test=develop
* fix multihead matmul bug test=develop
* fix coverage problem test=develop
Co-authored-by: Zhaolong Xing <[email protected]>
1 parent 67e3629 commit c79acd9

File tree

5 files changed

+72
-84
lines changed

5 files changed

+72
-84
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ void GraphPatternDetector::ValidateByNodeRole(
141141
subgraphs->begin(), subgraphs->end(),
142142
[](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
143143
// Collect the inputs and outputs.
144-
std::unordered_set<Node *> ios;
144+
std::set<Node *> ios;
145145
for (auto &item : subgraph) {
146146
if (!item.first->IsIntermediate()) {
147147
ios.insert(item.second);
@@ -167,7 +167,7 @@ void GraphPatternDetector::ValidateByNodeRole(
167167
}
168168

169169
struct HitGroup {
170-
std::unordered_map<PDNode *, Node *> roles;
170+
std::map<PDNode *, Node *> roles;
171171

172172
bool Match(Node *node, PDNode *pat) {
173173
if (nodes_.count(node)) {
@@ -185,7 +185,7 @@ struct HitGroup {
185185
}
186186

187187
private:
188-
std::unordered_set<Node *> nodes_;
188+
std::set<Node *> nodes_;
189189
};
190190

191191
// Tell whether Node a links to b.
@@ -284,7 +284,7 @@ void GraphPatternDetector::UniquePatterns(
284284
if (subgraphs->empty()) return;
285285
std::vector<GraphPatternDetector::subgraph_t> result;
286286

287-
std::unordered_set<size_t> set;
287+
std::set<size_t> set;
288288
std::hash<std::string> hasher;
289289
for (auto &g : *subgraphs) {
290290
// Sort the items in the sub-graph, and transform to a string key.
@@ -306,7 +306,7 @@ void GraphPatternDetector::UniquePatterns(
306306
void GraphPatternDetector::RemoveOverlappedMatch(
307307
std::vector<subgraph_t> *subgraphs) {
308308
std::vector<subgraph_t> result;
309-
std::unordered_set<Node *> node_set;
309+
std::set<Node *> node_set;
310310

311311
for (const auto &subgraph : *subgraphs) {
312312
bool valid = true;

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class PDPattern {
231231

232232
std::vector<std::unique_ptr<PDNode>> nodes_;
233233
std::vector<edge_t> edges_;
234-
std::unordered_map<std::string, PDNode*> node_map_;
234+
std::map<std::string, PDNode*> node_map_;
235235
static size_t id_;
236236
};
237237

@@ -263,7 +263,7 @@ class PDPattern {
263263
*/
264264
class GraphPatternDetector {
265265
public:
266-
using subgraph_t = std::unordered_map<PDNode*, Node*>;
266+
using subgraph_t = std::map<PDNode*, Node*>;
267267

268268
// Operate on the detected pattern.
269269
using handle_t =

paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc

Lines changed: 62 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
4545
// Create pattern.
4646
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
4747

48-
PDNode* x =
49-
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
50-
51-
multihead_pattern(x);
48+
multihead_pattern();
5249
// Create New OpDesc
5350
auto fuse_creater = [&](
54-
Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
51+
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
5552
Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
5653
Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
5754
Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
115112
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
116113
Graph* g) {
117114
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
118-
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
115+
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
119116

120117
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
121118
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
185182
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
186183
multihead_pattern);
187184

188-
fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
185+
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
189186
eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
190187
reshape2_qkv_out, scale, scale_out);
191188

@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
232229
return fusion_count;
233230
}
234231

235-
PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
236-
// Create shared nodes.
237-
auto* layer_norm = pattern->NewNode(layer_norm_repr());
238-
239-
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr());
240-
layer_norm_out_var->assert_is_op_input("mul");
232+
PDNode* MultiHeadMatmulPattern::operator()() {
233+
auto* input0 = pattern->NewNode(input0_repr());
234+
input0->assert_is_op_input("mul");
241235

242236
// First path with scale
243237
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
390384
transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
391385
"matmul"); // link to matmul qkv
392386

393-
// Link all nodes together
394-
layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var});
395387
// Q path
396-
mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var});
388+
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
397389
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
398390

399391
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
400392
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
401393
scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
402394
// K path
403-
mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var});
395+
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
404396
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
405397
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
406398
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
411403
.LinksTo({eltadd_qk_out_var});
412404
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
413405
// V path
414-
mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var});
406+
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
415407
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
416408
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
417409
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
434426
// Create pattern.
435427
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
436428

437-
PDNode* x =
438-
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
439-
440-
multihead_pattern(x);
429+
multihead_pattern();
441430
// Create New OpDesc
442431
auto fuse_creater = [&](
443-
Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
432+
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
444433
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
445434
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
446435
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
471460
framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
472461
auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
473462

474-
// create a new var in scope
475-
VarDesc combined_w_desc(
476-
patterns::PDNodeName(name_scope, "multi_head_combined_weight"));
477-
combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
478-
combined_w_desc.SetDataType(wq_tensor->type());
479-
combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel());
480-
combined_w_desc.SetPersistable(true);
481-
482-
// create a new var in scope
483-
VarDesc combined_bias_desc(
484-
patterns::PDNodeName(name_scope, "multi_head_combined_bias"));
485-
combined_bias_desc.SetShape({3, bq_tensor->dims()[0]});
486-
combined_bias_desc.SetDataType(bq_tensor->type());
487-
combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
488-
combined_bias_desc.SetPersistable(true);
489-
490-
auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
491-
auto* combined_w_tensor =
492-
scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
493-
494-
combined_w_tensor->Resize(combined_w_dims);
495-
auto* combined_w_data =
496-
combined_w_tensor->mutable_data<float>(platform::CPUPlace());
463+
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
464+
auto* combined_w_desc = mul0_w->Var();
465+
combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
466+
combined_w_desc->SetPersistable(true);
467+
468+
auto* combined_bias_desc = eltadd0_b->Var();
469+
combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
470+
combined_bias_desc->SetPersistable(true);
471+
472+
framework::LoDTensor tmp_combined_w_tensor;
473+
tmp_combined_w_tensor.Resize(combined_w_dims);
474+
auto* tmp_combined_w_data =
475+
tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
476+
497477
std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
498478
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
499479
// Combine the three fc weights together.
@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
502482
for (int k = 0; k < dims_w; k++) {
503483
int out_index = i * (3 * dims_w) + j * dims_w + k;
504484
int in_index = i * dims_w + k;
505-
combined_w_data[out_index] = w_vec[j][in_index];
485+
tmp_combined_w_data[out_index] = w_vec[j][in_index];
506486
}
507487
}
508488
}
509-
scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
510-
auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc);
511-
auto* combined_bias_tensor =
512-
scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>();
513-
514-
combined_bias_tensor->Resize(combined_bias_dims);
515-
auto* combined_bias_data =
516-
combined_bias_tensor->mutable_data<float>(platform::CPUPlace());
489+
490+
wq_tensor->Resize(combined_w_dims);
491+
auto* new_combined_w_data =
492+
wq_tensor->mutable_data<float>(platform::CPUPlace());
493+
memcpy(new_combined_w_data, tmp_combined_w_data,
494+
sizeof(float) * wq_tensor->numel());
495+
496+
scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
497+
498+
framework::LoDTensor tmp_combined_bias_tensor;
499+
tmp_combined_bias_tensor.Resize(combined_bias_dims);
500+
auto* tmp_combined_bias_data =
501+
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
502+
517503
size_t bias_size = bq_tensor->numel();
518-
memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size);
519-
memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
520-
memcpy(combined_bias_data + 2 * bias_size, bv_data,
504+
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
505+
memcpy(tmp_combined_bias_data + bias_size, bk_data,
506+
sizeof(float) * bias_size);
507+
memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
521508
sizeof(float) * bias_size);
522509

523-
scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()});
510+
bq_tensor->Resize(combined_bias_dims);
511+
auto* new_combined_bias_data =
512+
bq_tensor->mutable_data<float>(platform::CPUPlace());
513+
memcpy(new_combined_bias_data, tmp_combined_bias_data,
514+
sizeof(float) * bq_tensor->numel());
515+
516+
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
524517

525518
auto reshape_desc = reshape2->Op();
526519
int head_number =
@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
529522
OpDesc multihead_op_desc;
530523
multihead_op_desc.SetType("multihead_matmul");
531524

532-
multihead_op_desc.SetInput("Input", {layer_norm_out->Name()});
533-
multihead_op_desc.SetInput("W", {combined_w_node->Name()});
534-
multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()});
525+
multihead_op_desc.SetInput("Input", {input0->Name()});
526+
multihead_op_desc.SetInput("W", {mul0_w->Name()});
527+
multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
535528
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
536529

537530
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
540533

541534
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
542535

543-
IR_NODE_LINK_TO(layer_norm_out, multihead);
544-
IR_NODE_LINK_TO(combined_w_node, multihead);
545-
IR_NODE_LINK_TO(combined_bias_node, multihead);
536+
IR_NODE_LINK_TO(input0, multihead);
537+
IR_NODE_LINK_TO(mul0_w, multihead);
538+
IR_NODE_LINK_TO(eltadd0_b, multihead);
546539
IR_NODE_LINK_TO(eltadd_qk_b, multihead);
547540

548541
IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
552545
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
553546
Graph* g) {
554547
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
555-
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
556-
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
557-
multihead_pattern);
548+
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
558549

559550
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
560551
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
624615
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
625616
multihead_pattern);
626617

627-
fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
628-
mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
629-
eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
618+
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
619+
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
620+
reshape2_0, reshape2_qkv_out, scale, scale_out);
630621

631622
std::unordered_set<const Node*> marked_nodes({eltadd0,
632623
eltadd1,
633624
eltadd2,
634-
eltadd0_b,
635625
eltadd1_b,
636626
eltadd2_b,
637627
eltadd0_out,
@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
665655
mul0_out,
666656
mul1_out,
667657
mul2_out,
668-
mul0_w,
669658
mul1_w,
670659
mul2_w,
671660
reshape2_qkv,

paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@ struct MultiHeadMatmulPattern : public PatternBase {
2929
MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
3030
: PatternBase(pattern, name_scope, "multihead_matmul") {}
3131

32-
PDNode* operator()(PDNode* x);
32+
PDNode* operator()();
3333

3434
// declare operator node's name
35-
PATTERN_DECL_NODE(layer_norm);
36-
PATTERN_DECL_NODE(layer_norm_out);
35+
PATTERN_DECL_NODE(input0);
3736
PATTERN_DECL_NODE(mul0);
3837
PATTERN_DECL_NODE(mul1);
3938
PATTERN_DECL_NODE(mul2);

paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
167167
ret.nbDims = 5;
168168
ret.d[0] = inputs[0].d[0];
169169
ret.d[1] = inputs[0].d[1];
170-
ret.d[2] = expr_builder.constant(hidden_);
170+
ret.d[2] = expr_builder.constant(head_size_ * head_number_);
171171
ret.d[3] = expr_builder.constant(1);
172172
ret.d[4] = expr_builder.constant(1);
173173
return ret;

0 commit comments

Comments (0)