
Commit c206c37

lialan authored and mahesh-attarde committed
Reapply "[AMDGPU] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds" (llvm#150334)
This is a reapply of patch llvm#149851. The reapply also fixes a CMake/Bazel build issue, which was the reason for the revert. (Thanks @rupprecht!)

Original patch (llvm#149851) message:

-----

This PR adds a new optimization pass to fold `memref.subview/expand_shape/collapse_shape` ops into consumer `amdgpu.gather_to_lds` operations.

* Implements a new pass `AmdgpuFoldMemRefOpsPass` with pattern `FoldMemRefOpsIntoGatherToLDSOp`
* Adds corresponding folding tests
1 parent e6899ce commit c206c37
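
To make the folding concrete, here is a minimal before/after sketch in MLIR. This IR is illustrative rather than taken from the patch's test files: the function and value names are hypothetical, and the folded index arithmetic is shown as plain arith ops for readability (the pattern itself materializes the indices through affine utilities).

// Before: amdgpu.gather_to_lds reads through a memref.subview.
func.func @fold_subview(%i: index, %j: index) {
  %c0 = arith.constant 0 : index
  %lds = memref.alloc() : memref<64x64xf16, #gpu.address_space<workgroup>>
  %mem = memref.alloc() : memref<64x128xf16>
  %view = memref.subview %mem[32, 64] [32, 64] [1, 1]
      : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1], offset: 4160>>
  amdgpu.gather_to_lds %view[%i, %j], %lds[%c0, %c0]
      : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 4160>>,
        memref<64x64xf16, #gpu.address_space<workgroup>>
  func.return
}

// After: the subview's offsets are folded into the gather's source
// indices, and the op reads from the original memref directly.
func.func @fold_subview(%i: index, %j: index) {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c64 = arith.constant 64 : index
  %lds = memref.alloc() : memref<64x64xf16, #gpu.address_space<workgroup>>
  %mem = memref.alloc() : memref<64x128xf16>
  %ii = arith.addi %i, %c32 : index
  %jj = arith.addi %j, %c64 : index
  amdgpu.gather_to_lds %mem[%ii, %jj], %lds[%c0, %c0]
      : vector<8xf16>, memref<64x128xf16>,
        memref<64x64xf16, #gpu.address_space<workgroup>>
  func.return
}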

File tree: 9 files changed (+315, -93 lines)


mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

Lines changed: 5 additions & 1 deletion
@@ -22,8 +22,9 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUFOLDMEMREFOPSPASS
 #define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -38,6 +39,9 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
 void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
 } // namespace amdgpu
 } // namespace mlir

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
@@ -70,4 +70,16 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
     "memref::MemRefDialect"
   ];
 }
+
+def AmdgpuFoldMemRefOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
+  let summary = "Fold memref operations into their parent operations";
+  let description = [{
+    This pass identifies memref operations (subview, expand_shape, collapse_shape)
+    that are sources of `GatherToLDSOp` and attempts to fold the source ops,
+    potentially simplifying the overall operation and improving performance.
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect"
+  ];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
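
With this Tablegen definition in place, the pass is exposed under the command-line name given in the `Pass<"amdgpu-fold-memrefs-ops">` declaration, so it can be exercised in isolation with, e.g., `mlir-opt --amdgpu-fold-memrefs-ops input.mlir` (the input file name here is a placeholder).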

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 37 additions & 0 deletions
@@ -116,6 +116,43 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
 /// the source memref (i.e. implements ViewLikeOpInterface).
 MemrefValue skipViewLikeOps(MemrefValue source);
 
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of an expand_shape op, returns the indices w.r.t. the source memref of the
+/// expand_shape op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+///    : memref<12x42xf32> into memref<2x6x42xf32>
+/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[6 * i1 + i2, %i3] :
+///    memref<12x42xf32>
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t. the source memref of
+/// the collapse_shape op. For example
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+///    : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+///    memref<2x6x42xf32>
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices);
+
 } // namespace memref
 } // namespace mlir
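
The expand_shape resolution above is implemented (see MemRefUtils.cpp later in this diff) with `affine.linearize_index`, and the collapse_shape resolution with `affine.delinearize_index`. As a sketch of the IR the expand_shape helper emits for the documented example (value names hypothetical):

%lin = affine.linearize_index [%i1, %i2] by (2, 6) : index
%v = memref.load %src[%lin, %i3] : memref<12x42xf32>

The basis (2, 6) comes from the expanded group's output shape, so %lin computes %i1 * 6 + %i2, matching the `6 * i1 + i2` in the doc comment.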

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
-  ResolveStridedMetadata.cpp
+  FoldMemRefsOps.cpp
   MaskedloadToLoad.cpp
+  ResolveStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
@@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
   LINK_LIBS PUBLIC
   MLIRAMDGPUDialect
   MLIRAMDGPUUtils
+  MLIRAffineUtils
   MLIRArithDialect
   MLIRMemRefDialect
   MLIRSCFDialect

mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+//===- FoldMemRefsOps.cpp - AMDGPU fold memref ops ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+
+struct AmdgpuFoldMemRefOpsPass final
+    : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuFoldMemRefOpsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherToLDSOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value memrefSource;
+    SmallVector<Value> sourceIndices;
+    auto foldResult =
+        llvm::TypeSwitch<Operation *, LogicalResult>(
+            op.getSrc().getDefiningOp())
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+              // If the source is a SubViewOp, we can directly rewrite the
+              // GatherToLDSOp.
+              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                  rewriter, loc, subviewOp.getMixedOffsets(),
+                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                  op.getSrcIndices(), sourceIndices);
+              memrefSource = subviewOp.getSource();
+              return success();
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expandShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
+                          sourceIndices, false))) {
+                    return failure();
+                  }
+                  memrefSource = expandShapeOp.getViewSource();
+                  return success();
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapseShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                          sourceIndices))) {
+                    return failure();
+                  }
+                  memrefSource = collapseShapeOp.getViewSource();
+                  return success();
+                })
+            .Default([&](Operation *op) {
+              // If the source is not a SubViewOp, ExpandShapeOp, or
+              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+              return rewriter.notifyMatchFailure(
+                  op,
+                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
+                  "CollapseShapeOp");
+            });
+
+    if (failed(foldResult)) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
+                                               op.getDst(), op.getDstIndices(),
+                                               op.getTransferType());
+
+    return success();
+  }
+};
+
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit) {
+  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
+}
+} // namespace mlir::amdgpu
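
Applied to an expand_shape producer, the pattern rewrites the gather to index the flat source directly. A hedged sketch of the effect (illustrative IR with hypothetical value names; it assumes the `amdgpu.gather_to_lds` assembly shown here matches the dialect's printed form, and omits the surrounding function and allocations):

// Before: gather through an expanded view.
%expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128]
    : memref<8192xf16> into memref<64x128xf16>
amdgpu.gather_to_lds %expand[%i, %j], %lds[%c0, %c0]
    : vector<8xf16>, memref<64x128xf16>,
      memref<64x64xf16, #gpu.address_space<workgroup>>

// After: the (%i, %j) pair is linearized onto the flat source.
%lin = affine.linearize_index [%i, %j] by (64, 128) : index
amdgpu.gather_to_lds %mem[%lin], %lds[%c0, %c0]
    : vector<8xf16>, memref<8192xf16>,
      memref<64x64xf16, #gpu.address_space<workgroup>>

Note that the pattern passes `startsInbounds = false` to `resolveSourceIndicesExpandShape`, so the linearization carries no `disjoint` hint.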

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 0 additions & 91 deletions
@@ -44,97 +44,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = memref.expand_shape %0 [[0, 1], [2]]
-///    : memref<12x42xf32> into memref<2x6x42xf32>
-/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
-///
-/// could be folded into
-///
-/// %2 = load %0[6 * i1 + i2, %i3] :
-///    memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
-
-  // Traverse all reassociation groups to determine the appropriate indices
-  // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
-    }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
-  }
-  return success();
-}
-
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
-///
-/// %0 = ... : memref<2x6x42xf32>
-/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
-///    : memref<2x6x42xf32> into memref<12x42xf32>
-/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
-///
-/// could be folded into
-///
-/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
-///    memref<2x6x42xf32>
-static LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
-                                  memref::CollapseShapeOp collapseShapeOp,
-                                  ValueRange indices,
-                                  SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
-    }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
-  }
-  if (collapseShapeOp.getReassociationIndices().empty()) {
-    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
-    int64_t srcRank =
-        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
-    for (int64_t i = 0; i < srcRank; i++) {
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-    }
-  }
-  return success();
-}
-
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 66 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
   return source;
 }
 
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
+    }
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
+  }
+  return success();
+}
+
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices) {
+  // Note: collapse_shape requires a strided memref, so we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
+    }
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
+  }
+  if (collapseShapeOp.getReassociationIndices().empty()) {
+    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+    int64_t srcRank =
+        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
+    for (int64_t i = 0; i < srcRank; i++) {
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+  }
+  return success();
+}
+
 } // namespace memref
 } // namespace mlir
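
For the collapse_shape path, the code above emits one `affine.delinearize_index` per multi-index reassociation group. A sketch of the resulting IR for the documented example (value names hypothetical):

%parts:2 = affine.delinearize_index %i1 into (2, 6) : index, index
%v = memref.load %src[%parts#0, %parts#1, %i2] : memref<2x6x42xf32>

The basis (2, 6) is recovered from the source memref's sizes via memref.extract_strided_metadata, so %parts#0 is %i1 floordiv 6 and %parts#1 is %i1 mod 6, matching the `%i1 / 6, %i1 % 6` in the doc comment.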
