Skip to content

Commit d63ecf5

Browse files
committed
[AMDGPU] fold dst operand of gather to lds
1 parent 4077e66 commit d63ecf5

File tree

1 file changed

+52
-46
lines changed

1 file changed

+52
-46
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1313
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1414
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
15+
#include "mlir/IR/ValueRange.h"
1516
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
1617
#include "llvm/ADT/TypeSwitch.h"
1718

@@ -28,63 +29,68 @@ struct AmdgpuFoldMemRefOpsPass final
2829
}
2930
};
3031

32+
33+
/// Tries to look through the view-like operation that produces `view`.
///
/// When `view` is the result of a memref::SubViewOp, memref::ExpandShapeOp or
/// memref::CollapseShapeOp, the access `indices` are handed to the matching
/// index-resolution helper, which fills `resolvedIndices`, and `memrefBase` is
/// set to the memref the view was taken from.
///
/// \param rewriter        Rewriter passed on to the index-resolution helpers.
/// \param loc             Location passed on to the index-resolution helpers.
/// \param view            The memref value whose producer is inspected.
/// \param indices         Indices used to access `view`.
/// \param resolvedIndices Output: filled by the index-resolution helpers.
/// \param memrefBase      Output: source of the view; only set on success.
/// \param role            Short label for `view` ("source", "destination")
///                        used as a prefix of the match-failure message.
/// \returns success() if the producer was folded; failure() if `view` has no
///          defining operation, the producer is not one of the supported view
///          ops, or index resolution failed.
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
                                      Value view, mlir::OperandRange indices,
                                      SmallVectorImpl<Value> &resolvedIndices,
                                      Value &memrefBase, StringRef role) {
  Operation *defOp = view.getDefiningOp();
  // A value without a defining operation (e.g. a block argument) yields a null
  // pointer here. Bail out before it reaches TypeSwitch, whose cases cast the
  // pointer, and before the default case hands it to notifyMatchFailure.
  if (!defOp)
    return failure();

  return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
      .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
        mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
            rewriter, loc, subviewOp.getMixedOffsets(),
            subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
            resolvedIndices);
        memrefBase = subviewOp.getSource();
        return success();
      })
      .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
        if (failed(mlir::memref::resolveSourceIndicesExpandShape(
                loc, rewriter, expandShapeOp, indices, resolvedIndices,
                false))) {
          return failure();
        }
        memrefBase = expandShapeOp.getViewSource();
        return success();
      })
      .Case<memref::CollapseShapeOp>(
          [&](memref::CollapseShapeOp collapseShapeOp) {
            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
                    loc, rewriter, collapseShapeOp, indices,
                    resolvedIndices))) {
              return failure();
            }
            memrefBase = collapseShapeOp.getViewSource();
            return success();
          })
      .Default([&](Operation *op) {
        // Any other producer cannot be folded; report which operand it was.
        return rewriter.notifyMatchFailure(
            op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
      });
}
70+
71+
3172
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
3273
using OpRewritePattern::OpRewritePattern;
3374
/// Rewrites a GatherToLDSOp so that it accesses the memrefs underlying its
/// source and/or destination views directly.
///
/// Each operand is folded independently through foldMemrefViewOp. An operand
/// that cannot be folded keeps its original memref and indices, so a fold of
/// only the source (or only the destination) is still applied. The pattern
/// fails only when neither operand can be folded.
LogicalResult matchAndRewrite(GatherToLDSOp op,
                              PatternRewriter &rewriter) const override {
  Location loc = op.getLoc();

  SmallVector<Value> sourceIndices, destIndices;
  Value memrefSource, memrefDest;

  // Folds one operand. On failure the outputs are reset to the operand's
  // original memref and indices, so they are always usable for building the
  // replacement op. Returns true if the operand was folded.
  auto foldOperand = [&](Value view, mlir::OperandRange indices,
                         SmallVectorImpl<Value> &resolvedIndices,
                         Value &memrefBase, StringRef role) -> bool {
    // Operands without a defining op (e.g. block arguments) have nothing to
    // fold; skip the helper for them.
    if (view.getDefiningOp() &&
        succeeded(foldMemrefViewOp(rewriter, loc, view, indices,
                                   resolvedIndices, memrefBase, role)))
      return true;
    memrefBase = view;
    resolvedIndices.assign(indices.begin(), indices.end());
    return false;
  };

  bool foldedSrc = foldOperand(op.getSrc(), op.getSrcIndices(), sourceIndices,
                               memrefSource, "source");
  bool foldedDst = foldOperand(op.getDst(), op.getDstIndices(), destIndices,
                               memrefDest, "destination");

  // Nothing was folded: replacing the op would only recreate it unchanged.
  if (!foldedSrc && !foldedDst) {
    return failure();
  }

  rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
                                             memrefDest, destIndices,
                                             op.getTransferType());

  return success();
}

0 commit comments

Comments
 (0)