Skip to content

Commit 2ae5b13

Browse files
committed
Fix formatting and removed unnecessary include file
1 parent 2af05fa commit 2ae5b13

File tree

1 file changed

+52
-51
lines changed

1 file changed

+52
-51
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1313
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1414
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
15-
#include "mlir/IR/ValueRange.h"
1615
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
1716
#include "llvm/ADT/TypeSwitch.h"
1817

@@ -29,49 +28,50 @@ struct AmdgpuFoldMemRefOpsPass final
2928
}
3029
};
3130

32-
33-
static LogicalResult foldMemrefViewOp(
34-
PatternRewriter &rewriter, Location loc,
35-
Value view, mlir::OperandRange indices,
36-
SmallVectorImpl<Value> &resolvedIndices,
37-
Value &memrefBase, StringRef role)
38-
{
39-
Operation *defOp = view.getDefiningOp();
40-
if (!defOp) {
41-
return failure();
42-
}
43-
return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
44-
.Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
45-
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
46-
rewriter, loc, subviewOp.getMixedOffsets(),
47-
subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
48-
indices, resolvedIndices);
49-
memrefBase = subviewOp.getSource();
50-
return success();
51-
})
52-
.Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
53-
if (failed(mlir::memref::resolveSourceIndicesExpandShape(
54-
loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
55-
return failure();
56-
}
57-
memrefBase = expandShapeOp.getViewSource();
58-
return success();
59-
})
60-
.Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
31+
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
32+
Value view, mlir::OperandRange indices,
33+
SmallVectorImpl<Value> &resolvedIndices,
34+
Value &memrefBase, StringRef role) {
35+
Operation *defOp = view.getDefiningOp();
36+
if (!defOp) {
37+
return failure();
38+
}
39+
return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
40+
.Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
41+
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
42+
rewriter, loc, subviewOp.getMixedOffsets(),
43+
subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
44+
resolvedIndices);
45+
memrefBase = subviewOp.getSource();
46+
return success();
47+
})
48+
.Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
49+
if (failed(mlir::memref::resolveSourceIndicesExpandShape(
50+
loc, rewriter, expandShapeOp, indices, resolvedIndices,
51+
false))) {
52+
return failure();
53+
}
54+
memrefBase = expandShapeOp.getViewSource();
55+
return success();
56+
})
57+
.Case<memref::CollapseShapeOp>(
58+
[&](memref::CollapseShapeOp collapseShapeOp) {
6159
if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
62-
loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
63-
return failure();
60+
loc, rewriter, collapseShapeOp, indices,
61+
resolvedIndices))) {
62+
return failure();
6463
}
6564
memrefBase = collapseShapeOp.getViewSource();
6665
return success();
67-
})
68-
.Default([&](Operation *op) {
69-
return rewriter.notifyMatchFailure(
70-
op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
71-
});
66+
})
67+
.Default([&](Operation *op) {
68+
return rewriter.notifyMatchFailure(
69+
op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
70+
"CollapseShapeOp")
71+
.str());
72+
});
7273
}
7374

74-
7575
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
7676
using OpRewritePattern::OpRewritePattern;
7777
LogicalResult matchAndRewrite(GatherToLDSOp op,
@@ -81,26 +81,27 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
8181
SmallVector<Value> sourceIndices, destIndices;
8282
Value memrefSource, memrefDest;
8383

84-
auto foldSrcResult = foldMemrefViewOp(
85-
rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
86-
84+
auto foldSrcResult =
85+
foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
86+
sourceIndices, memrefSource, "source");
87+
8788
if (failed(foldSrcResult)) {
88-
memrefSource = op.getSrc();
89-
sourceIndices = op.getSrcIndices();
89+
memrefSource = op.getSrc();
90+
sourceIndices = op.getSrcIndices();
9091
}
9192

92-
auto foldDstResult = foldMemrefViewOp(
93-
rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
93+
auto foldDstResult =
94+
foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
95+
destIndices, memrefDest, "destination");
9496

9597
if (failed(foldDstResult)) {
96-
memrefDest = op.getDst();
97-
destIndices = op.getDstIndices();
98-
}
99-
98+
memrefDest = op.getDst();
99+
destIndices = op.getDstIndices();
100+
}
100101

101102
rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
102-
memrefDest, destIndices,
103-
op.getTransferType());
103+
memrefDest, destIndices,
104+
op.getTransferType());
104105

105106
return success();
106107
}

0 commit comments

Comments
 (0)