
Commit ed208aa

[Inference] Add add_group_norm_silu kernel and group_norm related pattern (#64199) (#64876)
* add group_kernel and add_norm_fuse pass
* update
* fix ci
1 parent 827bbce commit ed208aa

File tree

16 files changed: +1227 −194 lines changed
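For orientation, the fused op introduced by this commit computes y = activation(group_norm(x + residual)) over NHWC data, with SiLU as the optional activation (the patterns and layout hooks below assume NHWC and an "activation" attribute that is either "" or "silu"). The following is a minimal scalar reference sketch of that semantics, not the CUDA kernel added by the commit; the function name, argument list, and loop structure are illustrative assumptions.

// Scalar reference for what add_group_norm_silu computes (illustrative only).
// Inputs are NHWC; scale/bias have C elements; groups must divide C.
#include <cmath>
#include <cstddef>
#include <vector>

void AddGroupNormSiluRef(const std::vector<float>& x,
                         const std::vector<float>& residual,
                         const std::vector<float>& scale,
                         const std::vector<float>& bias,
                         int N, int H, int W, int C, int groups, float epsilon,
                         bool apply_silu,                   // activation == "silu"
                         std::vector<float>* residual_out,  // add_out: x + residual
                         std::vector<float>* y) {           // normalized (+ silu) output
  const int c_per_g = C / groups;
  residual_out->resize(x.size());
  y->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*residual_out)[i] = x[i] + residual[i];
  }
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < groups; ++g) {
      // Mean/variance per (sample, group) over H * W * c_per_g elements.
      double sum = 0.0, sq_sum = 0.0;
      for (int hw = 0; hw < H * W; ++hw) {
        for (int c = g * c_per_g; c < (g + 1) * c_per_g; ++c) {
          const double v =
              (*residual_out)[(static_cast<std::size_t>(n) * H * W + hw) * C + c];
          sum += v;
          sq_sum += v * v;
        }
      }
      const double cnt = static_cast<double>(H) * W * c_per_g;
      const double mean = sum / cnt;
      const double inv_std =
          1.0 / std::sqrt(sq_sum / cnt - mean * mean + epsilon);
      for (int hw = 0; hw < H * W; ++hw) {
        for (int c = g * c_per_g; c < (g + 1) * c_per_g; ++c) {
          const std::size_t idx =
              (static_cast<std::size_t>(n) * H * W + hw) * C + c;
          const double v =
              ((*residual_out)[idx] - mean) * inv_std * scale[c] + bias[c];
          // silu(v) = v * sigmoid(v); skipped when activation is "".
          (*y)[idx] =
              static_cast<float>(apply_silu ? v / (1.0 + std::exp(-v)) : v);
        }
      }
    }
  }
}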

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions
@@ -613,6 +613,7 @@ const std::vector<std::string> kPirGpuPasses{
     "fused_weight_only_linear_pass",
     "matmul_add_act_fuse_pass",
     "fc_elementwise_layernorm_fuse_pass",
+    "add_norm_fuse_pass",
     "matmul_scale_fuse_pass",
     "matmul_transpose_fuse_pass",
     "transpose_flatten_concat_fuse_pass",

paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc

Lines changed: 47 additions & 0 deletions
@@ -38,6 +38,45 @@ void RewriteByInfermeta(pir::Operation* op, common::DataLayout new_layout) {
   }
 }
 
+template <>
+std::vector<pir::Value> RelevantInputsImpl<AddGroupNormSiluOp>(
+    pir::Operation* op) {
+  auto concrete_op = op->dyn_cast<AddGroupNormSiluOp>();
+  return {concrete_op.x(), concrete_op.residual()};
+}
+
+template <>
+std::vector<pir::Value> RelevantOutputsImpl<AddGroupNormSiluOp>(
+    pir::Operation* op) {
+  auto concrete_op = op->dyn_cast<AddGroupNormSiluOp>();
+  return {concrete_op.y(), concrete_op.residual_out()};
+}
+
+template <>
+common::DataLayout PreferLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op) {
+  // Note(bukejiyu): add_group_norm_silu only supports NHWC layout now.
+  return common::DataLayout::NHWC;
+}
+
+template <>
+void RewriteByLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op,
+                                             common::DataLayout new_layout) {
+  op->set_attribute(
+      "data_format",
+      pir::StrAttribute::get(pir::IrContext::Instance(),
+                             common::DataLayoutToString(new_layout)));
+
+  std::vector<pir::Type> new_outputs = AddGroupNormSiluOp::InferMeta(
+      op->operands_source(), const_cast<pir::AttributeMap*>(&op->attributes()));
+  for (size_t i = 0; i < new_outputs.size(); ++i) {
+    op->result(i).set_type(new_outputs[i]);
+  }
+
+  for (auto value : RelevantOutputsImpl<AddGroupNormSiluOp>(op)) {
+    SetNewLayoutForValue(value, new_layout);
+  }
+}
+
 template <>
 common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
   auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
@@ -97,6 +136,14 @@ common::DataLayout PreferLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op) {
   auto original_layout =
       common::StringToDataLayout(data_format_attr.AsString());
 
+  if (op->HasAttribute(kForceBackendAttr) &&
+      op->attributes()
+              .at(kForceBackendAttr)
+              .dyn_cast<pir::StrAttribute>()
+              .AsString() == "gpu") {
+    return common::DataLayout::NHWC;
+  }
+
   auto concrete_op = op->dyn_cast<FusedConv2dAddActOp>();
   if (auto in = concrete_op.input()) {
     if (auto in_type = in.type()) {
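The PreferLayoutImpl specialization above declares NHWC as the preferred layout for add_group_norm_silu, so the surrounding transfer-layout machinery is expected to present the op's operands channels-last. As a self-contained illustration of the NCHW-to-NHWC data permutation that implies (purely a sketch, not Paddle code):

#include <cstddef>
#include <vector>

// Illustrative index mapping only: dst[n][h][w][c] = src[n][c][h][w].
std::vector<float> NchwToNhwc(const std::vector<float>& src,
                              int N, int C, int H, int W) {
  std::vector<float> dst(src.size());
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int h = 0; h < H; ++h) {
        for (int w = 0; w < W; ++w) {
          dst[((static_cast<std::size_t>(n) * H + h) * W + w) * C + c] =
              src[((static_cast<std::size_t>(n) * C + c) * H + h) * W + w];
        }
      }
    }
  }
  return dst;
}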

paddle/fluid/pir/dialect/operator/interface/layout_transformation.hpp

Lines changed: 6 additions & 0 deletions
@@ -117,6 +117,12 @@ OVERLOAD_REWRITE_BY_LAYOUT(GroupNormOp);
 OVERLOAD_RELEVANT_INPUTS(GroupNormOp);
 OVERLOAD_RELEVANT_OUTPUTS(GroupNormOp);
 
+class AddGroupNormSiluOp;
+OVERLOAD_REWRITE_BY_LAYOUT(AddGroupNormSiluOp);
+OVERLOAD_PREFER_LAYOUT(AddGroupNormSiluOp);
+OVERLOAD_RELEVANT_INPUTS(AddGroupNormSiluOp);
+OVERLOAD_RELEVANT_OUTPUTS(AddGroupNormSiluOp);
+
 class ReshapeOp;
 OVERLOAD_RELEVANT_INPUTS(ReshapeOp);
 OVERLOAD_RELEVANT_OUTPUTS(ReshapeOp);

paddle/fluid/pir/drr/src/rewrite_pattern.cc

Lines changed: 1 addition & 1 deletion
@@ -391,7 +391,7 @@ bool DrrRewritePattern::MatchFromOutputToInput(
           ir_input_values[i].use_count()) {
         matched = false;
         VLOG(8) << drr_node->name() << " Match failed: consumers of drr intput["
-                << i << "] { " << drr_node->outputs().size()
+                << i << "] { " << drr_input_tensors[i]->consumers().size()
                 << " } != consumers of pir intput[" << i << "] { "
                 << ir_input_values[i].use_count() << " }.";
         break;

paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.cc

Lines changed: 207 additions & 21 deletions
@@ -141,11 +141,13 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
 class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
  private:
   const bool extra_add_;
+  const bool trans_extra_add_;
 
  public:
-  explicit AddRmsNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+  AddRmsNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
 
-  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
 
   std::string name() const override { return "AddRmsNormFusePattern"; }
 
@@ -176,7 +178,9 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
     if (extra_add_) {
       const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
       pat.Tensor("add_out1") =
-          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
     }
     paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &res_rms_norm =
@@ -207,11 +211,13 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
 class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
  private:
   const bool extra_add_;
+  const bool trans_extra_add_;
 
  public:
-  explicit AddLayerNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+  AddLayerNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
 
-  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
   std::string name() const override { return "AddLayerNormFusePattern"; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -231,22 +237,20 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
     if (extra_add_) {
       const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
       pat.Tensor("add_out1") =
-          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
     }
 
     paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &cast_op_dtype = res.ComputeAttr(
         [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType {
-          auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
-          return paddle::dialect::TransToPhiDataType(x_dtype);
+          return phi::DataType::FLOAT32;
         });
-    const auto &cast_op_1 =
+    const auto cast_1_op =
         res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
-    res.Tensor("casted_bias") = cast_op_1(res.Tensor("bias"));
-    const auto &cast_op_2 =
+    const auto cast_2_op =
         res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
-    res.Tensor("casted_w") = cast_op_2(res.Tensor("w"));
-
     const auto &fuse_layer_norm =
         res.Op(paddle::dialect::FusedBiasResidualLayernormOp::name(),
                {{"epsilon", pat.Attr("epsilon")},
@@ -256,14 +260,15 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
                 {"quant_round_type", res.Int32Attr(0)},
                 {"quant_max_bound", res.Float32Attr(0.0)},
                 {"quant_min_bound", res.Float32Attr(0.0)}});
-
+    res.Tensor("w_cast") = cast_1_op(res.Tensor("w"));
+    res.Tensor("bias_cast") = cast_1_op(res.Tensor("bias"));
     fuse_layer_norm(
         {
             &res.Tensor("x"),
-            &res.Tensor("casted_bias"),
-            &res.Tensor("residual"),
-            &res.Tensor("casted_w"),
             &res.InputNoneTensor(),
+            &res.Tensor("residual"),
+            &res.Tensor("w_cast"),
+            &res.Tensor("bias_cast"),
         },
         {&res.Tensor("layer_norm_out"),
          &res.Tensor("add_out"),
@@ -272,6 +277,163 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
   }
 };
 
+class AddGroupNormFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  const bool extra_add_;
+  const bool trans_extra_add_;
+
+ public:
+  AddGroupNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
+
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
+  std::string name() const override { return "AddGroupNormFusePattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &add = pat.Op(paddle::dialect::AddOp::name());
+    const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
+                                    {{"epsilon", pat.Attr("epsilon")},
+                                     {"groups", pat.Attr("groups")},
+                                     {"data_format", pat.Attr("data_format")}});
+    pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
+    group_norm(
+        {&pat.Tensor("add_out"), &pat.Tensor("scale"), &pat.Tensor("bias")},
+        {&pat.Tensor("group_out"),
+         &pat.Tensor("mean_out_0"),
+         &pat.Tensor("variance_out_0")});
+    // TODO(bukejiyu) :DRR support matching placeholder op,
+    // the following needs to be deleted
+    if (extra_add_) {
+      const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
+      pat.Tensor("add_out1") =
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+    }
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("")}});
+
+    add_group_norm_silu_op({&res.Tensor("x"),
+                            &res.Tensor("residual"),
+                            &res.Tensor("scale"),
+                            &res.Tensor("bias")},
+                           {&res.Tensor("group_out"),
+                            &res.Tensor("add_out"),
+                            &res.Tensor("mean_out"),
+                            &res.Tensor("variance_out")});
+  }
+};
+
+class AddGroupNormWithActPattern : public paddle::drr::DrrPatternBase {
+ public:
+  uint32_t benefit() const override { return 2; }
+  std::string name() const override { return "AddGroupNormWithActPattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &add_group_norm_silu_op =
+        pat.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", pat.Attr("activation")}});
+    const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
+    add_group_norm_silu_op({&pat.Tensor("x"),
+                            &pat.Tensor("residual"),
+                            &pat.Tensor("scale"),
+                            &pat.Tensor("bias")},
+                           {&pat.Tensor("group_out"),
+                            &pat.Tensor("add_out"),
+                            &pat.Tensor("mean_out_0"),
+                            &pat.Tensor("variance_out_0")});
+    pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      auto activation = match_ctx.Attr<std::string>("activation");
+      if (activation != "") {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &res_add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("silu")}});
+    res_add_group_norm_silu_op({&res.Tensor("x"),
+                                &res.Tensor("residual"),
+                                &res.Tensor("scale"),
+                                &res.Tensor("bias")},
+                               {&res.Tensor("silu_out"),
+                                &res.Tensor("add_out"),
+                                &res.Tensor("mean_out"),
+                                &res.Tensor("variance_out")});
+  }
+};
+
+class GroupNormWithActPattern : public paddle::drr::DrrPatternBase {
+ public:
+  uint32_t benefit() const override { return 1; }
+  std::string name() const override { return "GroupNormWithActPattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
+                                    {{"epsilon", pat.Attr("epsilon")},
+                                     {"groups", pat.Attr("groups")},
+                                     {"data_format", pat.Attr("data_format")}});
+    const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
+    group_norm({&pat.Tensor("x"), &pat.Tensor("scale"), &pat.Tensor("bias")},
+               {&pat.Tensor("group_out"),
+                &pat.Tensor("mean_out_0"),
+                &pat.Tensor("variance_out_0")});
+    pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("silu")}});
+    add_group_norm_silu_op({&res.Tensor("x"),
+                            &res.InputNoneTensor(),
+                            &res.Tensor("scale"),
+                            &res.Tensor("bias")},
+                           {&res.Tensor("silu_out"),
+                            &res.OutputNoneTensor(),
+                            &res.Tensor("mean_out"),
+                            &res.Tensor("variance_out")});
+  }
+};
+
 class AddNormFusePass : public pir::PatternRewritePass {
  public:
   AddNormFusePass() : pir::PatternRewritePass("add_norm_fuse_pass", 2) {}
@@ -290,13 +452,37 @@ class AddNormFusePass : public pir::PatternRewritePass {
     // x--------
     // add-rms_norm ---> rms_norm
     // residual-
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add));
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add, false));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, true));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, false));
+
     // x--------
     // add-layer_norm ----> fused_bias_residual_layernorm
     // residual-
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, !extra_add));
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
+        context, !extra_add, false));
+    ps.Add(
+        paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add, true));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
+        context, extra_add, false));
+
+    // x--------
+    // add-group_norm ----> add_group_norm_silu
+    // residual-
+    ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
+        context, !extra_add, true));
+    ps.Add(
+        paddle::drr::Create<AddGroupNormFusePattern>(context, extra_add, true));
+    ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
+        context, extra_add, false));
+
+    // add_group_norm_silu-silu --->add_group_norm_silu
+    ps.Add(paddle::drr::Create<AddGroupNormWithActPattern>(context));
+    // group-silu->add_group_norm_silu
+    ps.Add(paddle::drr::Create<GroupNormWithActPattern>(context));
     return ps;
   }
 };
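Since "add_norm_fuse_pass" is registered in kPirGpuPasses (first file above), the new patterns run automatically in the PIR GPU inference pipeline. For standalone experiments the pass could also be driven through a PIR pass manager; the sketch below is hypothetical — the header paths, the CreateAddNormFusePass() factory, and the PassManager signatures are assumptions based on how other Paddle PIR passes are typically exposed, not something shown in this diff.

// Hypothetical usage sketch; API names and headers are assumptions, see note above.
#include "paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.h"  // assumed header
#include "paddle/pir/include/pass/pass_manager.h"                // assumed header

void ApplyAddNormFuse(pir::Program* program) {
  pir::PassManager pm(pir::IrContext::Instance());
  // Assumed factory for the AddNormFusePass defined in this commit.
  pm.AddPass(pir::CreateAddNormFusePass());
  // After Run, add + group_norm (+ silu) subgraphs should be rewritten into a
  // single add_group_norm_silu op, per the DRR patterns above.
  pm.Run(program);
}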
