@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   // Create pattern.
   MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
 
-  PDNode* x =
-      pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
-
-  multihead_pattern(x);
+  multihead_pattern();
   // Create New OpDesc
   auto fuse_creater = [&](
-      Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
+      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
       Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
       Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
       Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                               multihead_pattern);
 
-    fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
+    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
                  eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
                  reshape2_qkv_out, scale, scale_out);
 
@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   return fusion_count;
 }
 
-PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
-  // Create shared nodes.
-  auto* layer_norm = pattern->NewNode(layer_norm_repr());
-
-  auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr());
-  layer_norm_out_var->assert_is_op_input("mul");
+PDNode* MultiHeadMatmulPattern::operator()() {
+  auto* input0 = pattern->NewNode(input0_repr());
+  input0->assert_is_op_input("mul");
 
   // First path with scale
   auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
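Note: after this hunk the pattern no longer matches, and therefore no longer requires, a leading layer_norm. Its entry point is now a pattern-owned input0 var whose only constraint is that it feeds a "mul" op, so the fusion applies regardless of which op produces the attention input. A rough sketch of the matched subgraph after the change (illustration only, not code from the pass):

//            input0   (any var that is an input of "mul")
//           /      |       \
//        mul0    mul1     mul2        <- Q, K, V projections
//         ...     ...      ...
//       matmul(QK) -> softmax -> matmul(QKV) -> transpose/reshape out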
@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
   transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
       "matmul");  // link to matmul qkv
 
-  // Link all nodes together
-  layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var});
   // Q path
-  mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var});
+  mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
   eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
 
   reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
   transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
   scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
   // K path
-  mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var});
+  mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
   eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
   reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
   transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
       .LinksTo({eltadd_qk_out_var});
   softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
   // V path
-  mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var});
+  mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
   eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
   reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
   transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   // Create pattern.
   MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
 
-  PDNode* x =
-      pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
-
-  multihead_pattern(x);
+  multihead_pattern();
   // Create New OpDesc
   auto fuse_creater = [&](
-      Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
+      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
       Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
       Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
       Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
         framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
     auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
 
-    // create a new var in scope
-    VarDesc combined_w_desc(
-        patterns::PDNodeName(name_scope, "multi_head_combined_weight"));
-    combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
-    combined_w_desc.SetDataType(wq_tensor->type());
-    combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel());
-    combined_w_desc.SetPersistable(true);
-
-    // create a new var in scope
-    VarDesc combined_bias_desc(
-        patterns::PDNodeName(name_scope, "multi_head_combined_bias"));
-    combined_bias_desc.SetShape({3, bq_tensor->dims()[0]});
-    combined_bias_desc.SetDataType(bq_tensor->type());
-    combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
-    combined_bias_desc.SetPersistable(true);
-
-    auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
-    auto* combined_w_tensor =
-        scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
-
-    combined_w_tensor->Resize(combined_w_dims);
-    auto* combined_w_data =
-        combined_w_tensor->mutable_data<float>(platform::CPUPlace());
+    // reuse the mul0_w and eltadd0_b nodes for the combined nodes.
+    auto* combined_w_desc = mul0_w->Var();
+    combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
+    combined_w_desc->SetPersistable(true);
+
+    auto* combined_bias_desc = eltadd0_b->Var();
+    combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
+    combined_bias_desc->SetPersistable(true);
+
+    framework::LoDTensor tmp_combined_w_tensor;
+    tmp_combined_w_tensor.Resize(combined_w_dims);
+    auto* tmp_combined_w_data =
+        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
+
     std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
     int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
     // Combine the three fc weights together.
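Note: rather than allocating brand-new combined_weight/combined_bias vars and erasing all six originals (as the removed lines did), the pass now rewrites the existing mul0_w and eltadd0_b var descs in place and keeps their nodes in the graph. The combined data is staged in a stack-local tmp_combined_w_tensor first; this appears necessary because wq_data aliases wq_tensor's buffer, which is resized and rewritten in the next hunk, so assembling the interleaved layout directly inside wq_tensor would read from memory that is being overwritten.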
@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
         for (int k = 0; k < dims_w; k++) {
           int out_index = i * (3 * dims_w) + j * dims_w + k;
           int in_index = i * dims_w + k;
-          combined_w_data[out_index] = w_vec[j][in_index];
+          tmp_combined_w_data[out_index] = w_vec[j][in_index];
         }
       }
     }
-    scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
-    auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc);
-    auto* combined_bias_tensor =
-        scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>();
-
-    combined_bias_tensor->Resize(combined_bias_dims);
-    auto* combined_bias_data =
-        combined_bias_tensor->mutable_data<float>(platform::CPUPlace());
+
+    wq_tensor->Resize(combined_w_dims);
+    auto* new_combined_w_data =
+        wq_tensor->mutable_data<float>(platform::CPUPlace());
+    memcpy(new_combined_w_data, tmp_combined_w_data,
+           sizeof(float) * wq_tensor->numel());
+
+    scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
+
+    framework::LoDTensor tmp_combined_bias_tensor;
+    tmp_combined_bias_tensor.Resize(combined_bias_dims);
+    auto* tmp_combined_bias_data =
+        tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
+
     size_t bias_size = bq_tensor->numel();
-    memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size);
-    memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
-    memcpy(combined_bias_data + 2 * bias_size, bv_data,
+    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
+    memcpy(tmp_combined_bias_data + bias_size, bk_data,
+           sizeof(float) * bias_size);
+    memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
            sizeof(float) * bias_size);
 
-    scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()});
+    bq_tensor->Resize(combined_bias_dims);
+    auto* new_combined_bias_data =
+        bq_tensor->mutable_data<float>(platform::CPUPlace());
+    memcpy(new_combined_bias_data, tmp_combined_bias_data,
+           sizeof(float) * bq_tensor->numel());
+
+    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
 
     auto reshape_desc = reshape2->Op();
     int head_number =
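The interleaving loop above packs the three [dims_h, dims_w] projection weights into one [dims_h, 3, dims_w] tensor, so each row of the combined weight holds the matching Q, K, and V rows back to back. A self-contained sketch of the same index arithmetic, using hypothetical sizes dims_h = 2 and dims_w = 3 (the real shapes come from wq_tensor->dims()), not code from the pass:

// Standalone illustration of the weight-combining index mapping.
#include <cstdio>

int main() {
  const int dims_h = 2, dims_w = 3;
  // Three [dims_h x dims_w] weights, filled with recognizable values.
  float wq[dims_h * dims_w], wk[dims_h * dims_w], wv[dims_h * dims_w];
  for (int i = 0; i < dims_h * dims_w; ++i) {
    wq[i] = i;        // 0..5
    wk[i] = 100 + i;  // 100..105
    wv[i] = 200 + i;  // 200..205
  }
  float* w_vec[3] = {wq, wk, wv};

  // Same mapping as the pass: combined[i][j][k] = w_vec[j][i][k],
  // i.e. the combined shape is [dims_h, 3, dims_w].
  float combined[dims_h * 3 * dims_w];
  for (int i = 0; i < dims_h; ++i)
    for (int j = 0; j < 3; ++j)
      for (int k = 0; k < dims_w; ++k)
        combined[i * (3 * dims_w) + j * dims_w + k] = w_vec[j][i * dims_w + k];

  // Each row of the combined tensor holds [q_row | k_row | v_row]; prints:
  // 0 1 2 100 101 102 200 201 202 3 4 5 103 104 105 203 204 205
  for (int i = 0; i < dims_h * 3 * dims_w; ++i) printf("%g ", combined[i]);
  printf("\n");
  return 0;
}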
@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
     OpDesc multihead_op_desc;
     multihead_op_desc.SetType("multihead_matmul");
 
-    multihead_op_desc.SetInput("Input", {layer_norm_out->Name()});
-    multihead_op_desc.SetInput("W", {combined_w_node->Name()});
-    multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()});
+    multihead_op_desc.SetInput("Input", {input0->Name()});
+    multihead_op_desc.SetInput("W", {mul0_w->Name()});
+    multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
     multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
 
     multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
 
     auto* multihead = graph->CreateOpNode(&multihead_op_desc);
 
-    IR_NODE_LINK_TO(layer_norm_out, multihead);
-    IR_NODE_LINK_TO(combined_w_node, multihead);
-    IR_NODE_LINK_TO(combined_bias_node, multihead);
+    IR_NODE_LINK_TO(input0, multihead);
+    IR_NODE_LINK_TO(mul0_w, multihead);
+    IR_NODE_LINK_TO(eltadd0_b, multihead);
     IR_NODE_LINK_TO(eltadd_qk_b, multihead);
 
     IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
-                              multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
     GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                               multihead_pattern);
 
-    fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
-                 mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
-                 eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
+    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
+                 mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
+                 reshape2_0, reshape2_qkv_out, scale, scale_out);
 
     std::unordered_set<const Node*> marked_nodes({eltadd0,
                                                   eltadd1,
                                                   eltadd2,
-                                                  eltadd0_b,
                                                   eltadd1_b,
                                                   eltadd2_b,
                                                   eltadd0_out,
@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
                                                   mul0_out,
                                                   mul1_out,
                                                   mul2_out,
-                                                  mul0_w,
                                                   mul1_w,
                                                   mul2_w,
                                                   reshape2_qkv,
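Note: consistent with the node reuse above, mul0_w and eltadd0_b are dropped from the marked_nodes deletion set; they survive as the fused multihead_matmul op's W and Bias inputs, while mul1_w, mul2_w, eltadd1_b, and eltadd2_b are still erased from both the scope and the graph.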