Skip to content

Commit ba4fbe7

Browse files
authored
[cherry pick] fix memory copy in prepare_data of FusedMultiTransformer pass (#47308)
* fix memory copy in prepare_data. test=develop
* add cache_kv fp16 support. test=develop
* fit for simplify_with_basic_ops_pass. test=develop
1 parent 7a1cf27 commit ba4fbe7

7 files changed: 154 lines added, 580 lines deleted.

paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc

Lines changed: 51 additions & 227 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
8888
PATTERN_DECL_NODE(eltadd_qk_out);
8989
PATTERN_DECL_NODE(softmax_qk);
9090
PATTERN_DECL_NODE(softmax_qk_out);
91-
PATTERN_DECL_NODE(dropout_qk);
92-
PATTERN_DECL_NODE(dropout_qk_out);
9391

9492
// QK, V matmul
9593
PATTERN_DECL_NODE(matmul_qkv);
@@ -106,8 +104,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
106104
PATTERN_DECL_NODE(eltadd_linear);
107105
PATTERN_DECL_NODE(eltadd_linear_b);
108106
PATTERN_DECL_NODE(eltadd_linear_out);
109-
PATTERN_DECL_NODE(dropout_linear);
110-
PATTERN_DECL_NODE(dropout_linear_out);
111107

112108
// output elementwise_add
113109
PATTERN_DECL_NODE(eltadd_out)
@@ -137,8 +133,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
137133
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
138134
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
139135
PATTERN_DECL_NODE(ffn_eltadd1_out);
140-
PATTERN_DECL_NODE(ffn_dropout);
141-
PATTERN_DECL_NODE(ffn_dropout_out);
142136

143137
// output elementwise_add
144138
PATTERN_DECL_NODE(ffn_eltadd_out)
@@ -193,8 +187,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
193187
PATTERN_DECL_NODE(eltadd_qk_out);
194188
PATTERN_DECL_NODE(softmax_qk);
195189
PATTERN_DECL_NODE(softmax_qk_out);
196-
PATTERN_DECL_NODE(dropout_qk);
197-
PATTERN_DECL_NODE(dropout_qk_out);
198190

199191
// QK, V matmul
200192
PATTERN_DECL_NODE(matmul_qkv);
@@ -211,8 +203,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
211203
PATTERN_DECL_NODE(eltadd_linear);
212204
PATTERN_DECL_NODE(eltadd_linear_b);
213205
PATTERN_DECL_NODE(eltadd_linear_out);
214-
PATTERN_DECL_NODE(dropout_linear);
215-
PATTERN_DECL_NODE(dropout_linear_out);
216206

217207
// output elementwise_add
218208
PATTERN_DECL_NODE(eltadd_out)
@@ -239,8 +229,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
239229
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
240230
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
241231
PATTERN_DECL_NODE(ffn_eltadd1_out);
242-
PATTERN_DECL_NODE(ffn_dropout);
243-
PATTERN_DECL_NODE(ffn_dropout_out);
244232

245233
// output elementwise_add
246234
PATTERN_DECL_NODE(ffn_eltadd_out)
@@ -299,8 +287,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
299287
PATTERN_DECL_NODE(eltadd_qk_out);
300288
PATTERN_DECL_NODE(softmax_qk);
301289
PATTERN_DECL_NODE(softmax_qk_out);
302-
PATTERN_DECL_NODE(dropout_qk);
303-
PATTERN_DECL_NODE(dropout_qk_out);
304290

305291
// QK, V matmul
306292
PATTERN_DECL_NODE(matmul_qkv);
@@ -319,8 +305,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
319305
PATTERN_DECL_NODE(eltadd_linear);
320306
PATTERN_DECL_NODE(eltadd_linear_b);
321307
PATTERN_DECL_NODE(eltadd_linear_out);
322-
PATTERN_DECL_NODE(dropout_linear);
323-
PATTERN_DECL_NODE(dropout_linear_out);
324308

325309
// output elementwise_add
326310
PATTERN_DECL_NODE(eltadd_out)
@@ -351,8 +335,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
351335
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
352336
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
353337
PATTERN_DECL_NODE(ffn_eltadd1_out);
354-
PATTERN_DECL_NODE(ffn_dropout);
355-
PATTERN_DECL_NODE(ffn_dropout_out);
356338

357339
// output elementwise_add
358340
PATTERN_DECL_NODE(ffn_eltadd_out)

paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
8585
// (transpose_0, transpose_1) matmul -> matmul_qk
8686
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
8787
// (eltadd_qk) softmax -> softmax_qk
88-
// (softmax_qk) dropout -> dropout_qk
89-
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
88+
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
9089
// (matmul_qkv) transpose -> transpose_qkv
9190
// (transpose_qkv) reshape -> reshape_qkv
9291
// (reshape_qkv) matmul_v2 -> matmul_linear
9392
// (matmul_linear) elementwise_add -> eltadd_linear
94-
// (eltadd_linear) dropout -> dropout_linear
9593
// (eltadd_out) elementwise_add -> attention_out
9694
//
9795
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
10098
// (ffn_eltadd0) gelu -> ffn_gelu
10199
// (ffn_gelu) matmul_v2 -> ffn_matmul1
102100
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
103-
// (ffn_eltadd1) dropout -> ffn_dropout
104-
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
101+
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
105102

106103
Layers layers;
107104
// MHA: pre LayerNorm
@@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
154151
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
155152
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
156153
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
157-
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
158154

159155
// MHA: QKV matmul
160-
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
156+
auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
161157

162158
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
163159
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
@@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
170166
auto* linear_eltadd_out =
171167
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
172168

173-
auto* dropout_qkv =
174-
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
175-
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
169+
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
176170

177171
// FFN: pre LayerNorm
178172
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
@@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
195189
auto* ffn_eltadd1_out =
196190
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
197191

198-
// FFN: dropout -> elementwise_add
199-
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
200-
layers.elementwise_add(attention_out, ffn_dropout);
192+
layers.elementwise_add(attention_out, ffn_eltadd1_out);
201193

202194
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
203195
graph->Set("__param_scope__", CreateParamScope());
@@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
215207
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
216208

217209
PADDLE_ENFORCE_EQ(num_nodes_before,
218-
num_nodes_after + 72,
210+
num_nodes_after + 60,
219211
platform::errors::InvalidArgument(
220212
"After the fused_multi_transformer_decoder_pass, The "
221213
"node num in graph "
222214
"should be %d, but the result is %d",
223-
num_nodes_before - 72,
215+
num_nodes_before - 60,
224216
num_nodes_after));
225217
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
226218
1,
@@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
253245
// (split_q, split_k) matmul -> matmul_qk
254246
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
255247
// (eltadd_qk) softmax -> softmax_qk
256-
// (softmax_qk) dropout -> dropout_qk
257-
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
248+
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
258249
// (matmul_qkv) transpose -> transpose_qkv
259250
// (transpose_qkv) reshape -> reshape_qkv
260251
// (reshape_qkv) matmul_v2 -> matmul_linear
261252
// (matmul_linear) elementwise_add -> eltadd_linear
262-
// (eltadd_linear) dropout -> dropout_linear
263253
// (eltadd_out) elementwise_add -> attention_out
264254
//
265255
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
268258
// (ffn_eltadd0) gelu -> ffn_gelu
269259
// (ffn_gelu) matmul_v2 -> ffn_matmul1
270260
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
271-
// (ffn_eltadd1) dropout -> ffn_dropout
272-
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
261+
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
273262
//
274263
// (transpose_1, transpose_2) while -> decoder block
275264

@@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
313302
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
314303
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
315304
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
316-
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
317305

318306
// MHA: QKV matmul
319-
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
307+
auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
320308

321309
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
322310
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
@@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
329317
auto* linear_eltadd_out =
330318
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
331319

332-
auto* dropout_qkv =
333-
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
334-
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
320+
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
335321

336322
// FFN: pre LayerNorm
337323
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
@@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
354340
auto* ffn_eltadd1_out =
355341
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
356342

357-
// FFN: dropout -> elementwise_add
358-
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
359-
layers.elementwise_add(attention_out, ffn_dropout);
343+
layers.elementwise_add(attention_out, ffn_eltadd1_out);
360344

361345
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
362346
graph->Set("__param_scope__", CreateParamScope());
@@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
375359

376360
PADDLE_ENFORCE_EQ(
377361
num_nodes_before,
378-
num_nodes_after + 62,
362+
num_nodes_after + 50,
379363
platform::errors::InvalidArgument(
380364
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
381365
"The node num in graph should be %d, but the result is %d",
382-
num_nodes_before - 62,
366+
num_nodes_before - 50,
383367
num_nodes_after));
384368
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
385369
1,
@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
413397
// (split_q, split_k) matmul -> matmul_qk
414398
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
415399
// (eltadd_qk) softmax -> softmax_qk
416-
// (softmax_qk) dropout -> dropout_qk
417-
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
400+
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
418401
// (matmul_qkv) transpose -> transpose_qkv
419402
// (transpose_qkv) reshape -> reshape_qkv
420403
// (reshape_qkv) matmul_v2 -> matmul_linear
421404
// (matmul_linear) c_allreduce_sum -> c_all_reduce_out
422405
// (matmul_linear) elementwise_add -> eltadd_linear
423-
// (eltadd_linear) dropout -> dropout_linear
424406
// (eltadd_out) elementwise_add -> attention_out
425407
//
426408
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
431413
// (ffn_gelu) matmul_v2 -> ffn_matmul1
432414
// (ffn_matmul1) c_allreduce_sum -> c_allreduce_out
433415
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
434-
// (ffn_eltadd1) dropout -> ffn_dropout
435-
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
416+
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
436417
//
437418
// (transpose_1, transpose_2) while -> decoder block
438419

@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
477458
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
478459
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
479460
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
480-
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
481461

482462
// MHA: QKV matmul
483-
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
463+
auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
484464

485465
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
486466
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
494474
auto* linear_eltadd_out =
495475
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
496476

497-
auto* dropout_qkv =
498-
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
499-
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
477+
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
500478

501479
// FFN: pre LayerNorm
502480
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
521499
auto* ffn_eltadd1_out =
522500
layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2);
523501

524-
// FFN: dropout -> elementwise_add
525-
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
526-
layers.elementwise_add(attention_out, ffn_dropout);
502+
layers.elementwise_add(attention_out, ffn_eltadd1_out);
527503

528504
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
529505
graph->Set("__param_scope__", CreateParamScope());
@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
544520

545521
PADDLE_ENFORCE_EQ(
546522
num_nodes_before,
547-
num_nodes_after + 70,
523+
num_nodes_after + 58,
548524
platform::errors::InvalidArgument(
549525
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
550526
"The node num in graph should be %d, but the result is %d",
551-
num_nodes_before - 70,
527+
num_nodes_before - 58,
552528
num_nodes_after));
553529
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
554530
1,

Commit comments: 0