1212#include " mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1313#include " mlir/Dialect/MemRef/IR/MemRef.h"
1414#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
15- #include " mlir/IR/ValueRange.h"
1615#include " mlir/Transforms/WalkPatternRewriteDriver.h"
1716#include " llvm/ADT/TypeSwitch.h"
1817
@@ -29,49 +28,50 @@ struct AmdgpuFoldMemRefOpsPass final
2928 }
3029};
3130
32-
33- static LogicalResult foldMemrefViewOp (
34- PatternRewriter &rewriter, Location loc,
35- Value view, mlir::OperandRange indices,
36- SmallVectorImpl<Value> &resolvedIndices,
37- Value &memrefBase, StringRef role)
38- {
39- Operation *defOp = view.getDefiningOp ();
40- if (!defOp) {
41- return failure ();
42- }
43- return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
44- .Case <memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
45- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides (
46- rewriter, loc, subviewOp.getMixedOffsets (),
47- subviewOp.getMixedStrides (), subviewOp.getDroppedDims (),
48- indices, resolvedIndices);
49- memrefBase = subviewOp.getSource ();
50- return success ();
51- })
52- .Case <memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
53- if (failed (mlir::memref::resolveSourceIndicesExpandShape (
54- loc, rewriter, expandShapeOp, indices, resolvedIndices, false ))) {
55- return failure ();
56- }
57- memrefBase = expandShapeOp.getViewSource ();
58- return success ();
59- })
60- .Case <memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
31+ static LogicalResult foldMemrefViewOp (PatternRewriter &rewriter, Location loc,
32+ Value view, mlir::OperandRange indices,
33+ SmallVectorImpl<Value> &resolvedIndices,
34+ Value &memrefBase, StringRef role) {
35+ Operation *defOp = view.getDefiningOp ();
36+ if (!defOp) {
37+ return failure ();
38+ }
39+ return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
40+ .Case <memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
41+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides (
42+ rewriter, loc, subviewOp.getMixedOffsets (),
43+ subviewOp.getMixedStrides (), subviewOp.getDroppedDims (), indices,
44+ resolvedIndices);
45+ memrefBase = subviewOp.getSource ();
46+ return success ();
47+ })
48+ .Case <memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
49+ if (failed (mlir::memref::resolveSourceIndicesExpandShape (
50+ loc, rewriter, expandShapeOp, indices, resolvedIndices,
51+ false ))) {
52+ return failure ();
53+ }
54+ memrefBase = expandShapeOp.getViewSource ();
55+ return success ();
56+ })
57+ .Case <memref::CollapseShapeOp>(
58+ [&](memref::CollapseShapeOp collapseShapeOp) {
6159 if (failed (mlir::memref::resolveSourceIndicesCollapseShape (
62- loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
63- return failure ();
60+ loc, rewriter, collapseShapeOp, indices,
61+ resolvedIndices))) {
62+ return failure ();
6463 }
6564 memrefBase = collapseShapeOp.getViewSource ();
6665 return success ();
67- })
68- .Default ([&](Operation *op) {
69- return rewriter.notifyMatchFailure (
70- op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp" ).str ());
71- });
66+ })
67+ .Default ([&](Operation *op) {
68+ return rewriter.notifyMatchFailure (
69+ op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
70+ " CollapseShapeOp" )
71+ .str ());
72+ });
7273}
7374
74-
7575struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
7676 using OpRewritePattern::OpRewritePattern;
7777 LogicalResult matchAndRewrite (GatherToLDSOp op,
@@ -81,26 +81,27 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
8181 SmallVector<Value> sourceIndices, destIndices;
8282 Value memrefSource, memrefDest;
8383
84- auto foldSrcResult = foldMemrefViewOp (
85- rewriter, loc, op.getSrc (), op.getSrcIndices (), sourceIndices, memrefSource, " source" );
86-
84+ auto foldSrcResult =
85+ foldMemrefViewOp (rewriter, loc, op.getSrc (), op.getSrcIndices (),
86+ sourceIndices, memrefSource, " source" );
87+
8788 if (failed (foldSrcResult)) {
88- memrefSource = op.getSrc ();
89- sourceIndices = op.getSrcIndices ();
89+ memrefSource = op.getSrc ();
90+ sourceIndices = op.getSrcIndices ();
9091 }
9192
92- auto foldDstResult = foldMemrefViewOp (
93- rewriter, loc, op.getDst (), op.getDstIndices (), destIndices, memrefDest, " destination" );
93+ auto foldDstResult =
94+ foldMemrefViewOp (rewriter, loc, op.getDst (), op.getDstIndices (),
95+ destIndices, memrefDest, " destination" );
9496
9597 if (failed (foldDstResult)) {
96- memrefDest = op.getDst ();
97- destIndices = op.getDstIndices ();
98- }
99-
98+ memrefDest = op.getDst ();
99+ destIndices = op.getDstIndices ();
100+ }
100101
101102 rewriter.replaceOpWithNewOp <GatherToLDSOp>(op, memrefSource, sourceIndices,
102- memrefDest, destIndices,
103- op.getTransferType ());
103+ memrefDest, destIndices,
104+ op.getTransferType ());
104105
105106 return success ();
106107 }
0 commit comments