-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds #150503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu Author: Quinn Dawkins (qedawkins). Changes — full diff: https://github.com/llvm/llvm-project/pull/150503.diff (3 files affected):
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b237f7b5749e7..92aacdaef4136 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp :
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
}];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def AMDGPU_TransposeLoadOp :
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 18e8270f5aa99..28eb99600f48b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,59 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+namespace {
+/// If the source/destination of a GatherToLDSOp is a memref.cast that does
+/// not modify the element type, the cast can be skipped: such casts only
+/// change the layout/static-ness of the memref type (e.g. static -> dynamic
+/// shapes), which gather_to_lds does not depend on.
+struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Check source.
+    if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
+      // Compare the type *before* the cast with the type *after* it; both
+      // must be memrefs with the same element type. Note that `toType` is
+      // the cast's result type, not its source type, otherwise the check
+      // would compare a type against itself and be vacuously true.
+      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
+
+      if (fromType && toType &&
+          fromType.getElementType() == toType.getElementType()) {
+        rewriter.modifyOpInPlace(gatherOp, [&] {
+          gatherOp.getSrcMutable().assign(castOp.getSource());
+        });
+        modified = true;
+      }
+    }
+
+    // Check target (the LDS destination).
+    if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
+      // Same pre-cast vs. post-cast element-type comparison as above.
+      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+      auto toType = llvm::dyn_cast<MemRefType>(castOp.getType());
+
+      if (fromType && toType &&
+          fromType.getElementType() == toType.getElementType()) {
+        rewriter.modifyOpInPlace(gatherOp, [&] {
+          gatherOp.getDstMutable().assign(castOp.getSource());
+        });
+        modified = true;
+      }
+    }
+
+    // Report success only if at least one operand was rewritten, so the
+    // driver does not loop on a no-op pattern.
+    return success(modified);
+  }
+};
+} // namespace
+
+/// Registers canonicalization patterns for gather_to_lds; currently only the
+/// pattern folding away element-type-preserving memref.cast producers.
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldGatherToLDSOfCast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
LogicalResult TransposeLoadOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 4559e39cf0569..19f258f439bbf 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -130,3 +130,17 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
func.return
}
+
+// -----
+
// CHECK-LABEL: func @fold_gather_to_lds_of_cast
func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
  %c0 = arith.constant 0 : index
  // Static-to-dynamic cast of the global-memory source; the canonicalizer is
  // expected to fold it away so gather_to_lds reads %global directly with its
  // original static type (checked via the CHECK/CHECK-SAME lines below).
  %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
  // CHECK-SAME: : f32, memref<128x72xf32, 1>
  amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
    : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
  func.return
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
No description provided.