diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b237f7b5749e7..92aacdaef4136 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp :
     $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def AMDGPU_TransposeLoadOp :
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 18e8270f5aa99..9a0a230e8abca 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,42 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+namespace {
+/// If the source/target of a GatherToLDSOp is a CastOp that only removes
+/// static information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+    auto foldCast = [&](OpOperand &operand) {
+      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
+        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+          rewriter.modifyOpInPlace(gatherOp,
+                                   [&] { operand.assign(castOp.getSource()); });
+          modified = true;
+        }
+      }
+    };
+
+    foldCast(gatherOp.getSrcMutable());
+    foldCast(gatherOp.getDstMutable());
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldGatherToLDSOfCast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult TransposeLoadOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
 
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 4559e39cf0569..5501ad42dbd90 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -130,3 +130,32 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
   amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast
+func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+  %c0 = arith.constant 0 : index
+  %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
+  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
+  // CHECK-SAME: : f32, memref<128x72xf32, 1>
+  amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
+    : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest
+func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3>
+  %c0 = arith.constant 0 : index
+  %0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3>
+  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]]
+  // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
+  amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0]
+    : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
+  func.return
+}
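
For reference, a minimal sketch of the rewrite that FoldGatherToLDSOfCast performs, as exercised by the tests above when run through mlir-opt --canonicalize (SSA names are illustrative, taken from the first test case):

// Before: the source operand goes through a shape-erasing memref.cast.
%0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
  : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>

// After: the pattern reassigns the operand to the cast's source, so the op
// consumes the statically shaped memref directly; the now-dead cast is then
// cleaned up by the canonicalizer.
amdgpu.gather_to_lds %global[%c0, %c0], %lds[%c0, %c0]
  : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>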