Commit 9170789

missing the file

1 parent c003e70 commit 9170789

1 file changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for flattening multi-rank memref-related ops
// into 1-d memref ops.
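//
// For example (an illustrative sketch, not taken from this commit), a 2-d
// load such as
//
//   %value = memref.load %src[%i, %j] : memref<4x8xf32>
//
// becomes a reinterpret_cast to a 1-d memref plus a load at the linearized
// index %i * 8 + %j:
//
//   %flat = memref.reinterpret_cast %src to offset: [0], sizes: [32],
//       strides: [1] : memref<4x8xf32> to memref<32xf32>
//   %index = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
//   %value = memref.load %flat[%index] : memref<32xf32>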
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

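/// Sets the builder's insertion point right after the op defining `val`, or,
/// when `val` is a block argument, to the start of its parent block.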
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

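/// Emits an extract_strided_metadata op for `source` and returns its base
/// buffer, the index linearized from `subOffsets` and the source strides, the
/// source strides, the source offset, and the collapsed 1-d size (outermost
/// stride * outermost dimension).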
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
                  OpFoldResult>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
  OpFoldResult outmostDim =
      getDim(sourceType.getShape().front(),
             newExtractStridedMetadata.getSizes().front());

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  // Compute linearized index:
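  // (The result folds to subOffsets[0] * origStrides[0] + ... +
  // subOffsets[rank-1] * origStrides[rank-1]; the source offset is carried
  // separately in `origOffset`.)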
  auto &&[expr, values] =
      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
  OpFoldResult linearizedIndex =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);

  // Compute collapsed size: (outermost stride * outermost dimension).
  SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
  OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops);

  return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
          origStrides, origOffset, collapsedSize};
}

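/// Materializes `in` as a Value, building an arith.constant of index type
/// when `in` holds an attribute.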
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                      OpFoldResult in) {
  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
    return rewriter.create<arith::ConstantIndexOp>(
        loc, cast<IntegerAttr>(offsetAttr).getInt());
  }
  return cast<Value>(in);
}

/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                         Location loc,
                                                         Value source,
                                                         ValueRange indices) {
  auto &&[base, index, strides, offset, collapsedShape] =
      getFlatOffsetAndStrides(rewriter, loc, source,
                              getAsOpFoldResult(indices));

  return std::make_pair(
      rewriter.create<memref::ReinterpretCastOp>(
          loc, source,
          /* offset = */ offset,
          /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
          /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
      getValueFromOpFoldResult(rewriter, loc, index));
}

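/// Returns true if `val` is a memref of rank greater than 1, i.e. there is
/// something to flatten.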
static bool needFlattening(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() > 1;
}

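/// Returns true for the layouts these patterns support: identity or explicit
/// strided layouts.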
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
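/// Returns the memref operand that `op` reads from or writes to.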
template <typename T>
static Value getTargetMemref(T op) {
  if constexpr (std::is_same_v<T, memref::LoadOp>) {
    return op.getMemref();
  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
    return op.getBase();
  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
    return op.getMemref();
  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
    return op.getBase();
  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
    return op.getBase();
  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
    return op.getBase();
  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
    return op.getSource();
  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
    return op.getSource();
  }
  return {};
}

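/// Recreates `op` on `flatMemref` with the single linearized index `offset`
/// and replaces the original op; ops without a case below emit an error.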
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  if constexpr (std::is_same_v<T, memref::LoadOp>) {
    auto newLoad = rewriter.create<memref::LoadOp>(
        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
    newLoad->setAttrs(op->getAttrs());
    rewriter.replaceOp(op, newLoad.getResult());
  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
    auto newLoad = rewriter.create<vector::LoadOp>(
        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
    newLoad->setAttrs(op->getAttrs());
    rewriter.replaceOp(op, newLoad.getResult());
  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
    auto newStore = rewriter.create<memref::StoreOp>(
        op->getLoc(), op->getOperands().front(), flatMemref,
        ValueRange{offset});
    newStore->setAttrs(op->getAttrs());
    rewriter.replaceOp(op, newStore);
  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
    auto newStore = rewriter.create<vector::StoreOp>(
        op->getLoc(), op->getOperands().front(), flatMemref,
        ValueRange{offset});
    newStore->setAttrs(op->getAttrs());
    rewriter.replaceOp(op, newStore);
  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
    auto newTransferRead = rewriter.create<vector::TransferReadOp>(
        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
        op.getPadding());
    rewriter.replaceOp(op, newTransferRead.getResult());
  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
    auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
        op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
    rewriter.replaceOp(op, newTransferWrite);
  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
    auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
        op.getMask(), op.getPassThru());
    newMaskedLoad->setAttrs(op->getAttrs());
    rewriter.replaceOp(op, newMaskedLoad.getResult());
  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
    auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
        op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
        op.getValueToStore());
    newMaskedStore->setAttrs(op->getAttrs());
    rewriter.replaceOp(op, newMaskedStore);
  } else {
    op.emitOpError("unimplemented: do not know how to replace op.");
  }
}

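/// Common pattern: flattens the target memref of `op` to a 1-d view and
/// rebuilds the op on it; bails out on rank-1 memrefs and unsupported layouts.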
template <typename T>
struct MemRefRewritePatternBase : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    Value memref = getTargetMemref<T>(op);
    if (!needFlattening(memref) || !checkLayout(memref))
      return rewriter.notifyMatchFailure(op,
                                         "nothing to do or unsupported layout");
    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
        rewriter, op->getLoc(), memref, op.getIndices());
    replaceOp<T>(op, rewriter, flatMemref, offset);
    return success();
  }
};

struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
  using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
};

struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
  using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
};

struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
  using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
};

struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
  using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
};

struct FlattenVectorMaskedLoad
    : public MemRefRewritePatternBase<vector::MaskedLoadOp> {
  using MemRefRewritePatternBase<
      vector::MaskedLoadOp>::MemRefRewritePatternBase;
};

struct FlattenVectorMaskedStore
    : public MemRefRewritePatternBase<vector::MaskedStoreOp> {
  using MemRefRewritePatternBase<
      vector::MaskedStoreOp>::MemRefRewritePatternBase;
};

struct FlattenVectorTransferRead
    : public MemRefRewritePatternBase<vector::TransferReadOp> {
  using MemRefRewritePatternBase<
      vector::TransferReadOp>::MemRefRewritePatternBase;
};

struct FlattenVectorTransferWrite
    : public MemRefRewritePatternBase<vector::TransferWriteOp> {
  using MemRefRewritePatternBase<
      vector::TransferWriteOp>::MemRefRewritePatternBase;
};

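/// Rewrites memref.subview into a memref.reinterpret_cast of the flat base
/// buffer, keeping only the sizes and strides of dimensions the subview does
/// not drop.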
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    Value memref = op.getSource();
    if (!needFlattening(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides, unusedOffset, unusedSize] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

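/// Pass wrapper that greedily applies all flattening patterns to the
/// operation it runs on.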
struct FlattenMemrefsPass
    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    memref::MemRefDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    memref::populateFlattenMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
                  FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
                  FlattenVectorLoad, FlattenVectorStore,
                  FlattenVectorTransferRead, FlattenVectorTransferWrite>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
  return std::make_unique<FlattenMemrefsPass>();
}
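// Example invocation (assuming the tablegen entry registers the pass under a
// flag named `flatten-memrefs`; the flag definition is not in this file):
//
//   mlir-opt --flatten-memrefs input.mlir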
