1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp :
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
}];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def AMDGPU_TransposeLoadOp :
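Note (not part of the diff): setting hasCanonicalizer makes ODS declare a canonicalization hook on the generated GatherToLDSOp class, which the AMDGPUDialect.cpp change below then defines. A rough sketch of that declaration, assuming the standard ODS expansion rather than quoting the actual generated AMDGPU.h.inc:

// Assumed shape of the hook added by `let hasCanonicalizer = 1;`; the real
// generated header spells the types as ::mlir::RewritePatternSet etc.
static void getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                        mlir::MLIRContext *context);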
56 changes: 56 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//

LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,58 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}

namespace {
/// If the source/target of a GatherToLDSOp is a CastOp that only removes
/// static information or changes layout, the cast can be skipped.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
      auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = dyn_cast<MemRefType>(castOp.getResult().getType());

      if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType &&
          toType && fromType.getElementType() == toType.getElementType()) {
        rewriter.modifyOpInPlace(gatherOp, [&] {
          gatherOp.getSrcMutable().assign(castOp.getSource());
        });
        modified = true;
      }
    }

    // Check target.
    if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
      auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
      auto toType = dyn_cast<MemRefType>(castOp.getResult().getType());

      if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType &&
          toType && fromType.getElementType() == toType.getElementType()) {
        rewriter.modifyOpInPlace(gatherOp, [&] {
          gatherOp.getDstMutable().assign(castOp.getSource());
        });
        modified = true;
      }
    }

    return success(modified);
  }
};
} // namespace

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}

//===----------------------------------------------------------------------===//
// TransposeLoadOp
//===----------------------------------------------------------------------===//

LogicalResult TransposeLoadOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());

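Note (not part of the diff): the -canonicalize pass picks these patterns up automatically for every registered op, but they can also be applied in isolation. A minimal sketch, assuming a module is already loaded and using the greedy rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h (on older MLIR revisions the entry point is spelled applyPatternsAndFoldGreedily); the helper name is made up for illustration:

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Apply only the GatherToLDSOp canonicalization patterns added above to
// `module`; the -canonicalize pass does the equivalent for all ops at once.
static mlir::LogicalResult foldGatherToLDSCasts(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::amdgpu::GatherToLDSOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsGreedily(module.getOperation(),
                                     std::move(patterns));
}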
29 changes: 29 additions & 0 deletions mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -130,3 +130,32 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
  amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
  func.return
}

// -----

// CHECK-LABEL: func @fold_gather_to_lds_of_cast
func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
  // CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
  %c0 = arith.constant 0 : index
  %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
  // CHECK-SAME: : f32, memref<128x72xf32, 1>
  amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
    : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
  func.return
}

// -----

// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest
func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
  // CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
  // CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3>
  %c0 = arith.constant 0 : index
  %0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3>
  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]]
  // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
  amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0]
    : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
  func.return
}
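Note (not part of the diff): both new cases rely on the file's existing RUN line, which sits above this hunk; canonicalization tests like this one are typically driven by mlir-opt with --split-input-file and --canonicalize piped into FileCheck, which is what gives the // ----- separators and CHECK lines above their meaning.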