From 60d5258c3e23eafd2088cbe66a1c2f3d09ccb842 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 24 Jul 2025 15:02:41 -0400 Subject: [PATCH 1/5] [mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 1 + mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 57 +++++++++++++++++++ mlir/test/Dialect/AMDGPU/canonicalize.mlir | 14 +++++ 3 files changed, 72 insertions(+) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index b237f7b5749e7..92aacdaef4136 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp : $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst) }]; let hasVerifier = 1; + let hasCanonicalizer = 1; } def AMDGPU_TransposeLoadOp : diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 18e8270f5aa99..28eb99600f48b 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// GatherToLDSOp +//===----------------------------------------------------------------------===// + LogicalResult GatherToLDSOp::verify() { MemRefType srcType = cast(getSrc().getType()); MemRefType dstType = cast(getDst().getType()); @@ -546,6 +550,59 @@ LogicalResult GatherToLDSOp::verify() { return success(); } +namespace { +/// If the source/target of a CopyOp is a CastOp that does not modify the shape +/// and element type, the cast can be skipped. Such CastOps only cast the layout +/// of the type. +struct FoldGatherToLDSOfCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherToLDSOp gatherOp, + PatternRewriter &rewriter) const override { + bool modified = false; + + // Check source. + if (auto castOp = gatherOp.getSrc().getDefiningOp()) { + auto fromType = llvm::dyn_cast(castOp.getSource().getType()); + auto toType = llvm::dyn_cast(castOp.getSource().getType()); + + if (fromType && toType && + fromType.getElementType() == toType.getElementType()) { + rewriter.modifyOpInPlace(gatherOp, [&] { + gatherOp.getSrcMutable().assign(castOp.getSource()); + }); + modified = true; + } + } + + // Check target. + if (auto castOp = gatherOp.getDst().getDefiningOp()) { + auto fromType = llvm::dyn_cast(castOp.getSource().getType()); + auto toType = llvm::dyn_cast(castOp.getSource().getType()); + + if (fromType && toType && + fromType.getElementType() == toType.getElementType()) { + rewriter.modifyOpInPlace(gatherOp, [&] { + gatherOp.getDstMutable().assign(castOp.getSource()); + }); + modified = true; + } + } + + return success(modified); + } +}; +} // namespace + +void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//===----------------------------------------------------------------------===// +// TransposeLoadOp +//===----------------------------------------------------------------------===// + LogicalResult TransposeLoadOp::verify() { MemRefType srcType = cast(getSrc().getType()); diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir index 4559e39cf0569..19f258f439bbf 100644 --- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir +++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir @@ -130,3 +130,17 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) { amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32 func.return } + +// ----- + +// CHECK-LABEL: func @fold_gather_to_lds_of_cast +func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) { +// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1> + %c0 = arith.constant 0 : index + %0 = memref.cast %global : memref<128x72xf32, 1> to memref + // CHECK: amdgpu.gather_to_lds %[[GLOBAL]] + // CHECK-SAME: : f32, memref<128x72xf32, 1> + amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0] + : f32, memref, memref<64x64xf32, 3> + func.return +} From bc4c650f713f9bc22f18b3b7c701363ab4662813 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 24 Jul 2025 15:38:57 -0400 Subject: [PATCH 2/5] Update comment + address comments + add dest test --- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 19 +++++++++---------- mlir/test/Dialect/AMDGPU/canonicalize.mlir | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 28eb99600f48b..823f0c041231d 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -551,10 +551,9 @@ LogicalResult GatherToLDSOp::verify() { } namespace { -/// If the source/target of a CopyOp is a CastOp that does not modify the shape -/// and element type, the cast can be skipped. Such CastOps only cast the layout -/// of the type. -struct FoldGatherToLDSOfCast : public OpRewritePattern { +/// If the source/target of a GatherToLDSOp is a CastOp that only removes static +/// information or changes layout, the cast can be skipped. +struct FoldGatherToLDSOfCast final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GatherToLDSOp gatherOp, @@ -563,10 +562,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern { // Check source. if (auto castOp = gatherOp.getSrc().getDefiningOp()) { - auto fromType = llvm::dyn_cast(castOp.getSource().getType()); - auto toType = llvm::dyn_cast(castOp.getSource().getType()); + auto fromType = dyn_cast(castOp.getSource().getType()); + auto toType = dyn_cast(castOp.getSource().getType()); - if (fromType && toType && + if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType && fromType.getElementType() == toType.getElementType()) { rewriter.modifyOpInPlace(gatherOp, [&] { gatherOp.getSrcMutable().assign(castOp.getSource()); @@ -577,10 +576,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern { // Check target. if (auto castOp = gatherOp.getDst().getDefiningOp()) { - auto fromType = llvm::dyn_cast(castOp.getSource().getType()); - auto toType = llvm::dyn_cast(castOp.getSource().getType()); + auto fromType = dyn_cast(castOp.getSource().getType()); + auto toType = dyn_cast(castOp.getSource().getType()); - if (fromType && toType && + if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType && fromType.getElementType() == toType.getElementType()) { rewriter.modifyOpInPlace(gatherOp, [&] { gatherOp.getDstMutable().assign(castOp.getSource()); diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir index 19f258f439bbf..5501ad42dbd90 100644 --- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir +++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir @@ -144,3 +144,18 @@ func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memr : f32, memref, memref<64x64xf32, 3> func.return } + +// ----- + +// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest +func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) { +// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1> +// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3> + %c0 = arith.constant 0 : index + %0 = memref.cast %lds : memref<64x64xf32, 3> to memref + // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]] + // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3> + amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0] + : f32, memref<128x72xf32, 1>, memref + func.return +} From acda939fd53fa2354d8838a2fda1474debefd58d Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 24 Jul 2025 16:14:09 -0400 Subject: [PATCH 3/5] Drop redundant lines --- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 823f0c041231d..626808d8586f4 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -562,11 +562,7 @@ struct FoldGatherToLDSOfCast final : OpRewritePattern { // Check source. if (auto castOp = gatherOp.getSrc().getDefiningOp()) { - auto fromType = dyn_cast(castOp.getSource().getType()); - auto toType = dyn_cast(castOp.getSource().getType()); - - if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType && - fromType.getElementType() == toType.getElementType()) { + if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { rewriter.modifyOpInPlace(gatherOp, [&] { gatherOp.getSrcMutable().assign(castOp.getSource()); }); @@ -576,11 +572,7 @@ struct FoldGatherToLDSOfCast final : OpRewritePattern { // Check target. if (auto castOp = gatherOp.getDst().getDefiningOp()) { - auto fromType = dyn_cast(castOp.getSource().getType()); - auto toType = dyn_cast(castOp.getSource().getType()); - - if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType && - fromType.getElementType() == toType.getElementType()) { + if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { rewriter.modifyOpInPlace(gatherOp, [&] { gatherOp.getDstMutable().assign(castOp.getSource()); }); From 7f46167392524a284b0d78ac57c9ff2df9379784 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 24 Jul 2025 16:32:18 -0400 Subject: [PATCH 4/5] dedupe --- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 28 +++++++------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 626808d8586f4..728029e4b4fcf 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -559,26 +559,18 @@ struct FoldGatherToLDSOfCast final : OpRewritePattern { LogicalResult matchAndRewrite(GatherToLDSOp gatherOp, PatternRewriter &rewriter) const override { bool modified = false; - - // Check source. - if (auto castOp = gatherOp.getSrc().getDefiningOp()) { - if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { - rewriter.modifyOpInPlace(gatherOp, [&] { - gatherOp.getSrcMutable().assign(castOp.getSource()); - }); - modified = true; + auto foldCast = [&](OpOperand &operand) { + if (auto castOp = operand.get().getDefiningOp()) { + if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { + rewriter.modifyOpInPlace(gatherOp, + [&] { operand.assign(castOp.getSource()); }); + modified = true; + } } - } + }; - // Check target. - if (auto castOp = gatherOp.getDst().getDefiningOp()) { - if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { - rewriter.modifyOpInPlace(gatherOp, [&] { - gatherOp.getDstMutable().assign(castOp.getSource()); - }); - modified = true; - } - } + foldCast(gatherOp.getSrcMutable()); + foldCast(gatherOp.getDstMutable()); return success(modified); } From ef2618768c6734f95b63bb2251e70afb88b51820 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 24 Jul 2025 17:06:49 -0400 Subject: [PATCH 5/5] drop name --- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 728029e4b4fcf..9a0a230e8abca 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -554,7 +554,7 @@ namespace { /// If the source/target of a GatherToLDSOp is a CastOp that only removes static /// information or changes layout, the cast can be skipped. struct FoldGatherToLDSOfCast final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GatherToLDSOp gatherOp, PatternRewriter &rewriter) const override {