Skip to content

Commit f0a0285

Browse files
bukejiyuyuanlehome
andauthored
[cherry-pick][bug fix] cp pr64879 and pr64859 (#64897)
* Fix match constraint of matmul_add_act_fuse_pass (#64879) * cp pr64859 --------- Co-authored-by: Yuanle Liu <[email protected]>
1 parent ed208aa commit f0a0285

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,7 @@ paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen_tmp/*
106106
paddle/fluid/pybind/static_op_function.*
107107
paddle/fluid/pybind/ops_api.cc
108108
python/paddle/tensor/tensor.pyi
109+
paddle/phi/kernels/fusion/cutlass/conv2d/build
110+
paddle/phi/kernels/fusion/cutlass/conv2d/cutlass
111+
paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build
112+
paddle/phi/kernels/fusion/cutlass/gemm_epilogue/cutlass

paddle/fluid/pir/transforms/gpu/matmul_add_act_fuse_pass.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,17 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase {
6060

6161
pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
6262
auto w_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("w"));
63-
if (!w_dtype.isa<pir::Float16Type>() &&
64-
!w_dtype.isa<pir::BFloat16Type>() &&
65-
!w_dtype.isa<pir::Float32Type>() &&
66-
!w_dtype.isa<pir::Float64Type>()) {
67-
return false;
63+
if (fused_op_name_ == paddle::dialect::GemmEpilogueOp::name()) {
64+
if (!w_dtype.isa<pir::Float16Type>() &&
65+
!w_dtype.isa<pir::BFloat16Type>()) {
66+
return false;
67+
}
68+
} else {
69+
if (!w_dtype.isa<pir::Float16Type>() &&
70+
!w_dtype.isa<pir::Float32Type>() &&
71+
!w_dtype.isa<pir::Float64Type>()) {
72+
return false;
73+
}
6874
}
6975
auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w"));
7076
auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x"));

paddle/phi/kernels/fusion/cutlass/gemm_epilogue_kernel.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,5 @@ PD_REGISTER_KERNEL(gemm_epilogue,
204204
GPU,
205205
ALL_LAYOUT,
206206
phi::fusion::cutlass_internal::GemmEpilogueKernel,
207-
float,
208207
phi::dtype::bfloat16,
209208
phi::dtype::float16) {}

0 commit comments

Comments
 (0)