Skip to content

Conversation

qedawkins
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jul 24, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-backend-amdgpu

Author: Quinn Dawkins (qedawkins)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/150503.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+1)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+57)
  • (modified) mlir/test/Dialect/AMDGPU/canonicalize.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b237f7b5749e7..92aacdaef4136 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp :
     $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def AMDGPU_TransposeLoadOp :
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 18e8270f5aa99..28eb99600f48b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,59 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+namespace {
+/// If the source/target of a CopyOp is a CastOp that does not modify the shape
+/// and element type, the cast can be skipped. Such CastOps only cast the layout
+/// of the type.
+struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Check source.
+    if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
+      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+      auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+
+      if (fromType && toType &&
+          fromType.getElementType() == toType.getElementType()) {
+        rewriter.modifyOpInPlace(gatherOp, [&] {
+          gatherOp.getSrcMutable().assign(castOp.getSource());
+        });
+        modified = true;
+      }
+    }
+
+    // Check target.
+    if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
+      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+      auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+
+      if (fromType && toType &&
+          fromType.getElementType() == toType.getElementType()) {
+        rewriter.modifyOpInPlace(gatherOp, [&] {
+          gatherOp.getDstMutable().assign(castOp.getSource());
+        });
+        modified = true;
+      }
+    }
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldGatherToLDSOfCast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult TransposeLoadOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
 
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 4559e39cf0569..19f258f439bbf 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -130,3 +130,17 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
   amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast
+func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+  %c0 = arith.constant 0 : index
+  %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
+  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
+  // CHECK-SAME: : f32, memref<128x72xf32, 1>
+  amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
+    : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
+  func.return
+}

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved

@qedawkins qedawkins requested review from kuhar and lialan July 24, 2025 20:23
Copy link
Member

@lialan lialan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@qedawkins qedawkins merged commit b7f889a into llvm:main Jul 24, 2025
65 of 66 checks passed
@qedawkins qedawkins deleted the gather_to_lds_fold_cast branch July 24, 2025 23:58
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants