@@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
85
85
// (transpose_0, transpose_1) matmul -> matmul_qk
86
86
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
87
87
// (eltadd_qk) softmax -> softmax_qk
88
- // (softmax_qk) dropout -> dropout_qk
89
- // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
88
+ // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
90
89
// (matmul_qkv) transpose -> transpose_qkv
91
90
// (transpose_qkv) reshape -> reshape_qkv
92
91
// (reshape_qkv) matmul_v2 -> matmul_linear
93
92
// (matmul_linear) elementwise_add -> eltadd_linear
94
- // (eltadd_linear) dropout -> dropout_linear
95
93
// (x, eltadd_linear) elementwise_add -> attention_out
96
94
//
97
95
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
100
98
// (ffn_eltadd0) gelu -> ffn_gelu
101
99
// (ffn_gelu) matmul_v2 -> ffn_matmul1
102
100
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
103
- // (ffn_eltadd1) dropout -> ffn_dropout
104
- // (attention_out, ffn_dropout) elementwise_add -> ffn_output
101
+ // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
105
102
106
103
Layers layers;
107
104
// MHA: pre LayerNorm
@@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
154
151
auto * bqk = layers.data (" biasqk" , {1 , 12 , 128 , 128 }, true );
155
152
auto * elementwise_qk = layers.elementwise_add (matmul_qk, bqk);
156
153
auto * softmax_qk = layers.softmax (elementwise_qk, -1 );
157
- auto * dropout_qk = layers.dropout (softmax_qk, 0.1 , " upscale_in_train" );
158
154
159
155
// MHA: QKV matmul
160
- auto * matmul_qkv = layers.matmul_v2 (dropout_qk , concat_v);
156
+ auto * matmul_qkv = layers.matmul_v2 (softmax_qk , concat_v);
161
157
162
158
auto * transpose_qkv = layers.transpose2 (matmul_qkv, {0 , 2 , 1 , 3 }, true );
163
159
auto * reshape_qkv_out = layers.reshape2 (transpose_qkv, {1 , 128 , 1024 }, true );
@@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
170
166
auto * linear_eltadd_out =
171
167
layers.elementwise_add (linear_matmut_out, bias_l, nullptr , 2 );
172
168
173
- auto * dropout_qkv =
174
- layers.dropout (linear_eltadd_out, 0.1 , " upscale_in_train" );
175
- auto * attention_out = layers.elementwise_add (x, dropout_qkv);
169
+ auto * attention_out = layers.elementwise_add (x, linear_eltadd_out);
176
170
177
171
// FFN: pre LayerNorm
178
172
auto * ffn_ln_scale = layers.data (" ffn_ln_scale" , {1024 }, true );
@@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
195
189
auto * ffn_eltadd1_out =
196
190
layers.elementwise_add (ffn_matmul1_out, ffn_bias1, nullptr , 2 );
197
191
198
- // FFN: dropout -> elementwise_add
199
- auto * ffn_dropout = layers.dropout (ffn_eltadd1_out, 0.1 , " upscale_in_train" );
200
- layers.elementwise_add (attention_out, ffn_dropout);
192
+ layers.elementwise_add (attention_out, ffn_eltadd1_out);
201
193
202
194
std::unique_ptr<ir::Graph> graph (new ir::Graph (layers.main_program ()));
203
195
graph->Set (" __param_scope__" , CreateParamScope ());
@@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
215
207
int num_fused_nodes_after = GetNumOpNodes (graph, " fused_multi_transformer" );
216
208
217
209
PADDLE_ENFORCE_EQ (num_nodes_before,
218
- num_nodes_after + 72 ,
210
+ num_nodes_after + 60 ,
219
211
platform::errors::InvalidArgument (
220
212
" After the fused_multi_transformer_decoder_pass, The "
221
213
" node num in graph "
222
214
" should be %d, but the result is %d" ,
223
- num_nodes_before - 72 ,
215
+ num_nodes_before - 60 ,
224
216
num_nodes_after));
225
217
PADDLE_ENFORCE_EQ (num_fused_nodes_after,
226
218
1 ,
@@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
253
245
// (split_q, split_k) matmul -> matmul_qk
254
246
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
255
247
// (eltadd_qk) softmax -> softmax_qk
256
- // (softmax_qk) dropout -> dropout_qk
257
- // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
248
+ // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
258
249
// (matmul_qkv) transpose -> transpose_qkv
259
250
// (transpose_qkv) reshape -> reshape_qkv
260
251
// (reshape_qkv) matmul_v2 -> matmul_linear
261
252
// (matmul_linear) elementwise_add -> eltadd_linear
262
- // (eltadd_linear) dropout -> dropout_linear
263
253
// (x, eltadd_linear) elementwise_add -> attention_out
264
254
//
265
255
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
268
258
// (ffn_eltadd0) gelu -> ffn_gelu
269
259
// (ffn_gelu) matmul_v2 -> ffn_matmul1
270
260
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
271
- // (ffn_eltadd1) dropout -> ffn_dropout
272
- // (attention_out, ffn_dropout) elementwise_add -> ffn_output
261
+ // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
273
262
//
274
263
// (transpose_1, transpose_2) while -> decoder block
275
264
@@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
313
302
auto * bqk = layers.data (" biasqk" , {1 , 12 , 128 , 128 }, true );
314
303
auto * elementwise_qk = layers.elementwise_add (matmul_qk, bqk);
315
304
auto * softmax_qk = layers.softmax (elementwise_qk, -1 );
316
- auto * dropout_qk = layers.dropout (softmax_qk, 0.1 , " upscale_in_train" );
317
305
318
306
// MHA: QKV matmul
319
- auto * matmul_qkv = layers.matmul_v2 (dropout_qk , concat_v);
307
+ auto * matmul_qkv = layers.matmul_v2 (softmax_qk , concat_v);
320
308
321
309
auto * transpose_qkv = layers.transpose2 (matmul_qkv, {0 , 2 , 1 , 3 }, true );
322
310
auto * reshape_qkv_out = layers.reshape2 (transpose_qkv, {1 , 128 , 1024 }, true );
@@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
329
317
auto * linear_eltadd_out =
330
318
layers.elementwise_add (linear_matmut_out, bias_l, nullptr , 2 );
331
319
332
- auto * dropout_qkv =
333
- layers.dropout (linear_eltadd_out, 0.1 , " upscale_in_train" );
334
- auto * attention_out = layers.elementwise_add (x, dropout_qkv);
320
+ auto * attention_out = layers.elementwise_add (x, linear_eltadd_out);
335
321
336
322
// FFN: pre LayerNorm
337
323
auto * ffn_ln_scale = layers.data (" ffn_ln_scale" , {1024 }, true );
@@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
354
340
auto * ffn_eltadd1_out =
355
341
layers.elementwise_add (ffn_matmul1_out, ffn_bias1, nullptr , 2 );
356
342
357
- // FFN: dropout -> elementwise_add
358
- auto * ffn_dropout = layers.dropout (ffn_eltadd1_out, 0.1 , " upscale_in_train" );
359
- layers.elementwise_add (attention_out, ffn_dropout);
343
+ layers.elementwise_add (attention_out, ffn_eltadd1_out);
360
344
361
345
std::unique_ptr<ir::Graph> graph (new ir::Graph (layers.main_program ()));
362
346
graph->Set (" __param_scope__" , CreateParamScope ());
@@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
375
359
376
360
PADDLE_ENFORCE_EQ (
377
361
num_nodes_before,
378
- num_nodes_after + 62 ,
362
+ num_nodes_after + 50 ,
379
363
platform::errors::InvalidArgument (
380
364
" After the fused_multi_transformer_decoder_fuse_qkv_pass, "
381
365
" The node num in graph should be %d, but the result is %d" ,
382
- num_nodes_before - 62 ,
366
+ num_nodes_before - 50 ,
383
367
num_nodes_after));
384
368
PADDLE_ENFORCE_EQ (num_fused_nodes_after,
385
369
1 ,
@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
413
397
// (split_q, split_k) matmul -> matmul_qk
414
398
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
415
399
// (eltadd_qk) softmax -> softmax_qk
416
- // (softmax_qk) dropout -> dropout_qk
417
- // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
400
+ // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
418
401
// (matmul_qkv) transpose -> transpose_qkv
419
402
// (transpose_qkv) reshape -> reshape_qkv
420
403
// (reshape_qkv) matmul_v2 -> matmul_linear
421
404
// (matmul_linear) c_allreduce_sum -> c_all_reduce_out
422
405
// (c_all_reduce_out) elementwise_add -> eltadd_linear
423
- // (eltadd_linear) dropout -> dropout_linear
424
406
// (x, eltadd_linear) elementwise_add -> attention_out
425
407
//
426
408
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
431
413
// (ffn_gelu) matmul_v2 -> ffn_matmul1
432
414
// (ffn_matmul1) c_allreduce_sum -> c_allreduce_out
433
415
// (c_allreduce_out, ffn_bias1) elementwise_add -> ffn_eltadd1
434
- // (ffn_eltadd1) dropout -> ffn_dropout
435
- // (attention_out, ffn_dropout) elementwise_add -> ffn_output
416
+ // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
436
417
//
437
418
// (transpose_1, transpose_2) while -> decoder block
438
419
@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
477
458
auto * bqk = layers.data (" biasqk" , {1 , 12 , 128 , 128 }, true );
478
459
auto * elementwise_qk = layers.elementwise_add (matmul_qk, bqk);
479
460
auto * softmax_qk = layers.softmax (elementwise_qk, -1 );
480
- auto * dropout_qk = layers.dropout (softmax_qk, 0.1 , " upscale_in_train" );
481
461
482
462
// MHA: QKV matmul
483
- auto * matmul_qkv = layers.matmul_v2 (dropout_qk , concat_v);
463
+ auto * matmul_qkv = layers.matmul_v2 (softmax_qk , concat_v);
484
464
485
465
auto * transpose_qkv = layers.transpose2 (matmul_qkv, {0 , 2 , 1 , 3 }, true );
486
466
auto * reshape_qkv_out = layers.reshape2 (transpose_qkv, {1 , 128 , 1024 }, true );
@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
494
474
auto * linear_eltadd_out =
495
475
layers.elementwise_add (c_allreduce_out, bias_l, nullptr , 2 );
496
476
497
- auto * dropout_qkv =
498
- layers.dropout (linear_eltadd_out, 0.1 , " upscale_in_train" );
499
- auto * attention_out = layers.elementwise_add (x, dropout_qkv);
477
+ auto * attention_out = layers.elementwise_add (x, linear_eltadd_out);
500
478
501
479
// FFN: pre LayerNorm
502
480
auto * ffn_ln_scale = layers.data (" ffn_ln_scale" , {1024 }, true );
@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
521
499
auto * ffn_eltadd1_out =
522
500
layers.elementwise_add (ffn_c_allreduce_out, ffn_bias1, nullptr , 2 );
523
501
524
- // FFN: dropout -> elementwise_add
525
- auto * ffn_dropout = layers.dropout (ffn_eltadd1_out, 0.1 , " upscale_in_train" );
526
- layers.elementwise_add (attention_out, ffn_dropout);
502
+ layers.elementwise_add (attention_out, ffn_eltadd1_out);
527
503
528
504
std::unique_ptr<ir::Graph> graph (new ir::Graph (layers.main_program ()));
529
505
graph->Set (" __param_scope__" , CreateParamScope ());
@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
544
520
545
521
PADDLE_ENFORCE_EQ (
546
522
num_nodes_before,
547
- num_nodes_after + 70 ,
523
+ num_nodes_after + 58 ,
548
524
platform::errors::InvalidArgument (
549
525
" After the fused_multi_transformer_decoder_fuse_qkv_pass, "
550
526
" The node num in graph should be %d, but the result is %d" ,
551
- num_nodes_before - 70 ,
527
+ num_nodes_before - 58 ,
552
528
num_nodes_after));
553
529
PADDLE_ENFORCE_EQ (num_fused_nodes_after,
554
530
1 ,
0 commit comments