1212#include " mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1313#include " mlir/Dialect/MemRef/IR/MemRef.h"
1414#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
15+ #include " mlir/IR/ValueRange.h"
1516#include " mlir/Transforms/WalkPatternRewriteDriver.h"
1617#include " llvm/ADT/TypeSwitch.h"
1718
@@ -28,63 +29,68 @@ struct AmdgpuFoldMemRefOpsPass final
2829 }
2930};
3031
32+
33+ static LogicalResult foldMemrefViewOp (
34+ PatternRewriter &rewriter, Location loc,
35+ Value view, mlir::OperandRange indices,
36+ SmallVectorImpl<Value> &resolvedIndices,
37+ Value &memrefBase, StringRef role)
38+ {
39+ Operation *defOp = view.getDefiningOp ();
40+ return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
41+ .Case <memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
42+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides (
43+ rewriter, loc, subviewOp.getMixedOffsets (),
44+ subviewOp.getMixedStrides (), subviewOp.getDroppedDims (),
45+ indices, resolvedIndices);
46+ memrefBase = subviewOp.getSource ();
47+ return success ();
48+ })
49+ .Case <memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
50+ if (failed (mlir::memref::resolveSourceIndicesExpandShape (
51+ loc, rewriter, expandShapeOp, indices, resolvedIndices, false ))) {
52+ return failure ();
53+ }
54+ memrefBase = expandShapeOp.getViewSource ();
55+ return success ();
56+ })
57+ .Case <memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
58+ if (failed (mlir::memref::resolveSourceIndicesCollapseShape (
59+ loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
60+ return failure ();
61+ }
62+ memrefBase = collapseShapeOp.getViewSource ();
63+ return success ();
64+ })
65+ .Default ([&](Operation *op) {
66+ return rewriter.notifyMatchFailure (
67+ op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp" ).str ());
68+ });
69+ }
70+
71+
3172struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
3273 using OpRewritePattern::OpRewritePattern;
3374 LogicalResult matchAndRewrite (GatherToLDSOp op,
3475 PatternRewriter &rewriter) const override {
3576 Location loc = op.getLoc ();
3677
37- Value memrefSource;
38- SmallVector<Value> sourceIndices;
39- auto foldResult =
40- llvm::TypeSwitch<Operation *, LogicalResult>(
41- op.getSrc ().getDefiningOp ())
42- .Case <memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
43- // If the source is a SubViewOp, we can directly rewrite the
44- // GatherToLDSOp.
45- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides (
46- rewriter, loc, subviewOp.getMixedOffsets (),
47- subviewOp.getMixedStrides (), subviewOp.getDroppedDims (),
48- op.getSrcIndices (), sourceIndices);
49- memrefSource = subviewOp.getSource ();
50- return success ();
51- })
52- .Case <memref::ExpandShapeOp>(
53- [&](memref::ExpandShapeOp expandShapeOp) {
54- if (failed (mlir::memref::resolveSourceIndicesExpandShape (
55- loc, rewriter, expandShapeOp, op.getSrcIndices (),
56- sourceIndices, false ))) {
57- return failure ();
58- }
59- memrefSource = expandShapeOp.getViewSource ();
60- return success ();
61- })
62- .Case <memref::CollapseShapeOp>(
63- [&](memref::CollapseShapeOp collapseShapeOp) {
64- if (failed (mlir::memref::resolveSourceIndicesCollapseShape (
65- loc, rewriter, collapseShapeOp, op.getSrcIndices (),
66- sourceIndices))) {
67- return failure ();
68- }
69- memrefSource = collapseShapeOp.getViewSource ();
70- return success ();
71- })
72- .Default ([&](Operation *op) {
73- // If the source is not a SubViewOp, ExpandShapeOp, or
74- // CollapseShapeOp, we cannot fold the GatherToLDSOp.
75- return rewriter.notifyMatchFailure (
76- op,
77- " source producer is not one of SubViewOp, ExpandShapeOp, or "
78- " CollapseShapeOp" );
79- });
78+ SmallVector<Value> sourceIndices, destIndices;
79+ Value memrefSource, memrefDest;
80+
81+ auto foldSrcResult = foldMemrefViewOp (
82+ rewriter, loc, op.getSrc (), op.getSrcIndices (), sourceIndices, memrefSource, " source" );
83+
84+ auto foldDstResult = foldMemrefViewOp (
85+ rewriter, loc, op.getDst (), op.getDstIndices (), destIndices, memrefDest, " destination" );
8086
81- if (failed (foldResult )) {
87+ if (failed (foldSrcResult) || failed (foldDstResult )) {
8288 return failure ();
8389 }
8490
8591 rewriter.replaceOpWithNewOp <GatherToLDSOp>(op, memrefSource, sourceIndices,
86- op. getDst (), op. getDstIndices () ,
87- op.getTransferType ());
92+ memrefDest, destIndices ,
93+ op.getTransferType ());
8894
8995 return success ();
9096 }
0 commit comments