Skip to content

Commit bc4c650

Browse files
committed
Update comment + address comments + add dest test
1 parent 60d5258 commit bc4c650

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,9 @@ LogicalResult GatherToLDSOp::verify() {
551551
}
552552

553553
namespace {
554-
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
555-
/// and element type, the cast can be skipped. Such CastOps only cast the layout
556-
/// of the type.
557-
struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
554+
/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
555+
/// information or changes layout, the cast can be skipped.
556+
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
558557
using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
559558

560559
LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
@@ -563,10 +562,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
563562

564563
// Check source.
565564
if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
566-
auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
567-
auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
565+
auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
566+
auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
568567

569-
if (fromType && toType &&
568+
if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
570569
fromType.getElementType() == toType.getElementType()) {
571570
rewriter.modifyOpInPlace(gatherOp, [&] {
572571
gatherOp.getSrcMutable().assign(castOp.getSource());
@@ -577,10 +576,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
577576

578577
// Check target.
579578
if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
580-
auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
581-
auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
579+
auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
580+
auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
582581

583-
if (fromType && toType &&
582+
if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
584583
fromType.getElementType() == toType.getElementType()) {
585584
rewriter.modifyOpInPlace(gatherOp, [&] {
586585
gatherOp.getDstMutable().assign(castOp.getSource());

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,18 @@ func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memr
144144
: f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
145145
func.return
146146
}
147+
148+
// -----
149+
150+
// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest
151+
func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
152+
// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
153+
// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3>
154+
%c0 = arith.constant 0 : index
155+
%0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3>
156+
// CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]]
157+
// CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
158+
amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0]
159+
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
160+
func.return
161+
}

0 commit comments

Comments
 (0)