@@ -141,11 +141,13 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
 class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
  private:
   const bool extra_add_;
+  const bool trans_extra_add_;
 
  public:
-  explicit AddRmsNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+  AddRmsNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
 
-  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
 
   std::string name() const override { return "AddRmsNormFusePattern"; }
 
@@ -176,7 +178,9 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
     if (extra_add_) {
       const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
       pat.Tensor("add_out1") =
-          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
     }
     paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &res_rms_norm =
@@ -207,11 +211,13 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
 class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
  private:
   const bool extra_add_;
+  const bool trans_extra_add_;
 
  public:
-  explicit AddLayerNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+  AddLayerNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
 
-  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
   std::string name() const override { return "AddLayerNormFusePattern"; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -231,22 +237,20 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
     if (extra_add_) {
       const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
       pat.Tensor("add_out1") =
-          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
     }
 
     paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &cast_op_dtype = res.ComputeAttr(
         [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType {
-          auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
-          return paddle::dialect::TransToPhiDataType(x_dtype);
+          return phi::DataType::FLOAT32;
         });
-    const auto &cast_op_1 =
+    const auto cast_1_op =
         res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
-    res.Tensor("casted_bias") = cast_op_1(res.Tensor("bias"));
-    const auto &cast_op_2 =
+    const auto cast_2_op =
         res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
-    res.Tensor("casted_w") = cast_op_2(res.Tensor("w"));
-
     const auto &fuse_layer_norm =
         res.Op(paddle::dialect::FusedBiasResidualLayernormOp::name(),
                {{"epsilon", pat.Attr("epsilon")},
@@ -256,14 +260,15 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
                 {"quant_round_type", res.Int32Attr(0)},
                 {"quant_max_bound", res.Float32Attr(0.0)},
                 {"quant_min_bound", res.Float32Attr(0.0)}});
-
+    res.Tensor("w_cast") = cast_1_op(res.Tensor("w"));
+    res.Tensor("bias_cast") = cast_2_op(res.Tensor("bias"));
     fuse_layer_norm(
         {
             &res.Tensor("x"),
-            &res.Tensor("casted_bias"),
-            &res.Tensor("residual"),
-            &res.Tensor("casted_w"),
             &res.InputNoneTensor(),
+            &res.Tensor("residual"),
+            &res.Tensor("w_cast"),
+            &res.Tensor("bias_cast"),
         },
         {&res.Tensor("layer_norm_out"),
          &res.Tensor("add_out"),
@@ -272,6 +277,163 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
   }
 };
 
+class AddGroupNormFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  const bool extra_add_;
+  const bool trans_extra_add_;
+
+ public:
+  AddGroupNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
+
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
+  std::string name() const override { return "AddGroupNormFusePattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &add = pat.Op(paddle::dialect::AddOp::name());
+    const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
+                                    {{"epsilon", pat.Attr("epsilon")},
+                                     {"groups", pat.Attr("groups")},
+                                     {"data_format", pat.Attr("data_format")}});
+    pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
+    group_norm(
+        {&pat.Tensor("add_out"), &pat.Tensor("scale"), &pat.Tensor("bias")},
+        {&pat.Tensor("group_out"),
+         &pat.Tensor("mean_out_0"),
+         &pat.Tensor("variance_out_0")});
+    // TODO(bukejiyu): once DRR supports matching placeholder ops,
+    // the following needs to be deleted.
+    if (extra_add_) {
+      const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
+      pat.Tensor("add_out1") =
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+    }
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("")}});
+
+    add_group_norm_silu_op({&res.Tensor("x"),
+                            &res.Tensor("residual"),
+                            &res.Tensor("scale"),
+                            &res.Tensor("bias")},
+                           {&res.Tensor("group_out"),
+                            &res.Tensor("add_out"),
+                            &res.Tensor("mean_out"),
+                            &res.Tensor("variance_out")});
+  }
+};
+
+class AddGroupNormWithActPattern : public paddle::drr::DrrPatternBase {
+ public:
+  uint32_t benefit() const override { return 2; }
+  std::string name() const override { return "AddGroupNormWithActPattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &add_group_norm_silu_op =
+        pat.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", pat.Attr("activation")}});
+    const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
+    add_group_norm_silu_op({&pat.Tensor("x"),
+                            &pat.Tensor("residual"),
+                            &pat.Tensor("scale"),
+                            &pat.Tensor("bias")},
+                           {&pat.Tensor("group_out"),
+                            &pat.Tensor("add_out"),
+                            &pat.Tensor("mean_out_0"),
+                            &pat.Tensor("variance_out_0")});
+    pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      auto activation = match_ctx.Attr<std::string>("activation");
+      if (activation != "") {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &res_add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("silu")}});
+    res_add_group_norm_silu_op({&res.Tensor("x"),
+                                &res.Tensor("residual"),
+                                &res.Tensor("scale"),
+                                &res.Tensor("bias")},
+                               {&res.Tensor("silu_out"),
+                                &res.Tensor("add_out"),
+                                &res.Tensor("mean_out"),
+                                &res.Tensor("variance_out")});
+  }
+};
+
+class GroupNormWithActPattern : public paddle::drr::DrrPatternBase {
+ public:
+  uint32_t benefit() const override { return 1; }
+  std::string name() const override { return "GroupNormWithActPattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
+                                    {{"epsilon", pat.Attr("epsilon")},
+                                     {"groups", pat.Attr("groups")},
+                                     {"data_format", pat.Attr("data_format")}});
+    const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
+    group_norm({&pat.Tensor("x"), &pat.Tensor("scale"), &pat.Tensor("bias")},
+               {&pat.Tensor("group_out"),
+                &pat.Tensor("mean_out_0"),
+                &pat.Tensor("variance_out_0")});
+    pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("silu")}});
+    add_group_norm_silu_op({&res.Tensor("x"),
+                            &res.InputNoneTensor(),
+                            &res.Tensor("scale"),
+                            &res.Tensor("bias")},
+                           {&res.Tensor("silu_out"),
+                            &res.OutputNoneTensor(),
+                            &res.Tensor("mean_out"),
+                            &res.Tensor("variance_out")});
+  }
+};
+
 class AddNormFusePass : public pir::PatternRewritePass {
  public:
   AddNormFusePass() : pir::PatternRewritePass("add_norm_fuse_pass", 2) {}
@@ -290,13 +452,37 @@ class AddNormFusePass : public pir::PatternRewritePass {
     // x--------
     //          add-rms_norm ---> rms_norm
     // residual-
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add));
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add, false));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, true));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, false));
+
     // x--------
     //          add-layer_norm ----> fused_bias_residual_layernorm
     // residual-
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, !extra_add));
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
+        context, !extra_add, false));
+    ps.Add(
+        paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add, true));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
+        context, extra_add, false));
+
+    // x--------
+    //          add-group_norm ----> add_group_norm_silu
+    // residual-
+    ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
+        context, !extra_add, true));
+    ps.Add(
+        paddle::drr::Create<AddGroupNormFusePattern>(context, extra_add, true));
+    ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
+        context, extra_add, false));
+
+    // add_group_norm_silu-silu ---> add_group_norm_silu
+    ps.Add(paddle::drr::Create<AddGroupNormWithActPattern>(context));
+    // group_norm-silu ---> add_group_norm_silu
+    ps.Add(paddle::drr::Create<GroupNormWithActPattern>(context));
     return ps;
   }
 };
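Note on the diff: `add` is commutative at the graph level, so the trailing add that consumes `add_out` can appear with either operand order. The new `trans_extra_add` flag picks which order a given pattern instance matches, and one instance per order is registered in `CreatePatternSet`. The revised `benefit()` values simply rank the patterns so the variants that also match the trailing add (benefit 4) are tried before the plain variants (benefit 3), with the group-norm activation cleanups (2 and 1) applied last.

Below is a minimal sketch of how the pass might be applied, assuming the usual `pir::PassManager` workflow; the `pir::CreateAddNormFusePass()` factory and the include paths are assumptions, not shown in this diff, and may differ by Paddle version.

```cpp
// Sketch only: factory name and header paths are assumptions, not part of this diff.
#include "paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.h"  // assumed path
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass_manager.h"

void ApplyAddNormFuse(pir::Program *program) {
  pir::IrContext *ctx = pir::IrContext::Instance();
  pir::PassManager pm(ctx);
  // Rewrites add+rms_norm, add+layer_norm, and add+group_norm(+silu)
  // subgraphs into their fused counterparts.
  pm.AddPass(pir::CreateAddNormFusePass());  // assumed factory name
  pm.Run(program);
}
```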