Skip to content

Commit 9cb5c00

Browse files
authored
Revert "[AMDGPU] fold memref.subview/expand_shape/collapse_shape into `amdgpu.gather_to_lds` (#149851)" (#150256)
This reverts commit dbc63f1 because it has build dependency issues.
1 parent ce9d515 commit 9cb5c00

File tree

8 files changed

+93
-313
lines changed

8 files changed

+93
-313
lines changed

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ class ConversionTarget;
2222
namespace amdgpu {
2323

2424
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
25-
#define GEN_PASS_DECL_AMDGPUFOLDMEMREFOPSPASS
26-
#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
2725
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
26+
#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
2827
#define GEN_PASS_REGISTRATION
2928
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
3029

@@ -39,9 +38,6 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
3938
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
4039
PatternBenefit benefit = 1);
4140

42-
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
43-
PatternBenefit benefit = 1);
44-
4541
} // namespace amdgpu
4642
} // namespace mlir
4743

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,4 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
7070
"memref::MemRefDialect"
7171
];
7272
}
73-
74-
def AmdgpuFoldMemRefOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
75-
let summary = "Fold memref operations into their parent operations";
76-
let description = [{
77-
This pass identifies memref operations (subview, expand_shape, collapse_shape)
78-
that are sources of `GatherToLDSOp` and attempts to fold the source ops,
79-
potentially simplifying the overall operation and improving performance.
80-
}];
81-
let dependentDialects = [
82-
"memref::MemRefDialect"
83-
];
84-
}
8573
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -116,43 +116,6 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
116116
/// the source memref (i.e. implements ViewLikeOpInterface).
117117
MemrefValue skipViewLikeOps(MemrefValue source);
118118

119-
/// Given the 'indices' of a load/store operation where the memref is a result
120-
/// of a expand_shape op, returns the indices w.r.t to the source memref of the
121-
/// expand_shape op. For example
122-
///
123-
/// %0 = ... : memref<12x42xf32>
124-
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
125-
/// : memref<12x42xf32> into memref<2x6x42xf32>
126-
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
127-
///
128-
/// could be folded into
129-
///
130-
/// %2 = load %0[6 * i1 + i2, %i3] :
131-
/// memref<12x42xf32>
132-
LogicalResult resolveSourceIndicesExpandShape(
133-
Location loc, PatternRewriter &rewriter,
134-
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
135-
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
136-
137-
/// Given the 'indices' of a load/store operation where the memref is a result
138-
/// of a collapse_shape op, returns the indices w.r.t to the source memref of
139-
/// the collapse_shape op. For example
140-
///
141-
/// %0 = ... : memref<2x6x42xf32>
142-
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
143-
/// : memref<2x6x42xf32> into memref<12x42xf32>
144-
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
145-
///
146-
/// could be folded into
147-
///
148-
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
149-
/// memref<2x6x42xf32>
150-
LogicalResult
151-
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
152-
memref::CollapseShapeOp collapseShapeOp,
153-
ValueRange indices,
154-
SmallVectorImpl<Value> &sourceIndices);
155-
156119
} // namespace memref
157120
} // namespace mlir
158121

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
add_mlir_dialect_library(MLIRAMDGPUTransforms
22
EmulateAtomics.cpp
3-
FoldMemRefsOps.cpp
4-
MaskedloadToLoad.cpp
53
ResolveStridedMetadata.cpp
4+
MaskedloadToLoad.cpp
65

76
ADDITIONAL_HEADER_DIRS
87
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms

mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp

Lines changed: 0 additions & 97 deletions
This file was deleted.

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,97 @@ using namespace mlir;
4444
// Utility functions
4545
//===----------------------------------------------------------------------===//
4646

47+
/// Given the 'indices' of a load/store operation where the memref is a result
48+
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
49+
/// expand_shape op. For example
50+
///
51+
/// %0 = ... : memref<12x42xf32>
52+
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
53+
/// : memref<12x42xf32> into memref<2x6x42xf32>
54+
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
55+
///
56+
/// could be folded into
57+
///
58+
/// %2 = load %0[6 * %i1 + %i2, %i3] :
59+
/// memref<12x42xf32>
60+
static LogicalResult resolveSourceIndicesExpandShape(
61+
Location loc, PatternRewriter &rewriter,
62+
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
63+
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
64+
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
65+
66+
// Traverse all reassociation groups to determine the appropriate indices
67+
// corresponding to each one of them post op folding.
68+
for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
69+
assert(!group.empty() && "association indices groups cannot be empty");
70+
int64_t groupSize = group.size();
71+
if (groupSize == 1) {
72+
sourceIndices.push_back(indices[group[0]]);
73+
continue;
74+
}
75+
SmallVector<OpFoldResult> groupBasis =
76+
llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
77+
SmallVector<Value> groupIndices =
78+
llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
79+
Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
80+
loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
81+
sourceIndices.push_back(collapsedIndex);
82+
}
83+
return success();
84+
}
85+
86+
/// Given the 'indices' of a load/store operation where the memref is a result
87+
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
88+
/// the collapse_shape op. For example
89+
///
90+
/// %0 = ... : memref<2x6x42xf32>
91+
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
92+
/// : memref<2x6x42xf32> into memref<12x42xf32>
93+
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
94+
///
95+
/// could be folded into
96+
///
97+
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
98+
/// memref<2x6x42xf32>
99+
static LogicalResult
100+
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
101+
memref::CollapseShapeOp collapseShapeOp,
102+
ValueRange indices,
103+
SmallVectorImpl<Value> &sourceIndices) {
104+
// Note: collapse_shape requires a strided memref, so it is valid to extract
// strided metadata from its source here.
105+
auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
106+
loc, collapseShapeOp.getSrc());
107+
SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
108+
for (auto [index, group] :
109+
llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
110+
assert(!group.empty() && "association indices groups cannot be empty");
111+
int64_t groupSize = group.size();
112+
113+
if (groupSize == 1) {
114+
sourceIndices.push_back(index);
115+
continue;
116+
}
117+
118+
SmallVector<OpFoldResult> basis =
119+
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
120+
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
121+
loc, index, basis, /*hasOuterBound=*/true);
122+
llvm::append_range(sourceIndices, delinearize.getResults());
123+
}
124+
if (collapseShapeOp.getReassociationIndices().empty()) {
125+
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
126+
int64_t srcRank =
127+
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
128+
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
129+
rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
130+
for (int64_t i = 0; i < srcRank; i++) {
131+
sourceIndices.push_back(
132+
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
133+
}
134+
}
135+
return success();
136+
}
137+
47138
/// Helpers to access the memref operand for each op.
48139
template <typename LoadOrStoreOpTy>
49140
static Value getMemRefOperand(LoadOrStoreOpTy op) {

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1414
#include "mlir/Dialect/Affine/IR/AffineOps.h"
15-
#include "mlir/Dialect/Arith/Utils/Utils.h"
1615
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1716
#include "mlir/Interfaces/ViewLikeInterface.h"
1817
#include "llvm/ADT/STLExtras.h"
@@ -218,70 +217,5 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
218217
return source;
219218
}
220219

221-
LogicalResult resolveSourceIndicesExpandShape(
222-
Location loc, PatternRewriter &rewriter,
223-
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
224-
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
225-
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
226-
227-
// Traverse all reassociation groups to determine the appropriate indices
228-
// corresponding to each one of them post op folding.
229-
for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
230-
assert(!group.empty() && "association indices groups cannot be empty");
231-
int64_t groupSize = group.size();
232-
if (groupSize == 1) {
233-
sourceIndices.push_back(indices[group[0]]);
234-
continue;
235-
}
236-
SmallVector<OpFoldResult> groupBasis =
237-
llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
238-
SmallVector<Value> groupIndices =
239-
llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
240-
Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
241-
loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
242-
sourceIndices.push_back(collapsedIndex);
243-
}
244-
return success();
245-
}
246-
247-
LogicalResult
248-
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
249-
memref::CollapseShapeOp collapseShapeOp,
250-
ValueRange indices,
251-
SmallVectorImpl<Value> &sourceIndices) {
252-
// Note: collapse_shape requires a strided memref, we can do this.
253-
auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
254-
loc, collapseShapeOp.getSrc());
255-
SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
256-
for (auto [index, group] :
257-
llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
258-
assert(!group.empty() && "association indices groups cannot be empty");
259-
int64_t groupSize = group.size();
260-
261-
if (groupSize == 1) {
262-
sourceIndices.push_back(index);
263-
continue;
264-
}
265-
266-
SmallVector<OpFoldResult> basis =
267-
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
268-
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
269-
loc, index, basis, /*hasOuterBound=*/true);
270-
llvm::append_range(sourceIndices, delinearize.getResults());
271-
}
272-
if (collapseShapeOp.getReassociationIndices().empty()) {
273-
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
274-
int64_t srcRank =
275-
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
276-
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
277-
rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
278-
for (int64_t i = 0; i < srcRank; i++) {
279-
sourceIndices.push_back(
280-
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
281-
}
282-
}
283-
return success();
284-
}
285-
286220
} // namespace memref
287221
} // namespace mlir

0 commit comments

Comments
 (0)