 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -38,141 +39,6 @@ namespace memref {
 
 using namespace mlir;
 
-static void setInsertionPointToStart(OpBuilder &builder, Value val) {
-  if (auto *parentOp = val.getDefiningOp()) {
-    builder.setInsertionPointAfter(parentOp);
-  } else {
-    builder.setInsertionPointToStart(val.getParentBlock());
-  }
-}
-
-OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) {
-  Location loc = memref.getLoc();
-  MemRefType type = cast<MemRefType>(memref.getType());
-  ArrayRef<int64_t> shape = type.getShape();
-
-  // Check for empty memref
-  if (type.hasStaticShape() &&
-      llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) {
-    return builder.getIndexAttr(0);
-  }
-
-  // Get strides of the memref
-  SmallVector<int64_t, 4> strides;
-  int64_t offset;
-  if (failed(type.getStridesAndOffset(strides, offset))) {
-    // Cannot extract strides, return a dynamic value
-    return Value();
-  }
-
-  // Static case: compute at compile time if possible
-  if (type.hasStaticShape()) {
-    int64_t span = 0;
-    for (unsigned i = 0; i < type.getRank(); ++i) {
-      span += (shape[i] - 1) * strides[i];
-    }
-    return builder.getIndexAttr(span);
-  }
-
-  // Dynamic case: emit IR to compute at runtime
-  Value result = builder.create<arith::ConstantIndexOp>(loc, 0);
-
-  for (unsigned i = 0; i < type.getRank(); ++i) {
-    // Get dimension size
-    Value dimSize;
-    if (shape[i] == ShapedType::kDynamic) {
-      dimSize = builder.create<memref::DimOp>(loc, memref, i);
-    } else {
-      dimSize = builder.create<arith::ConstantIndexOp>(loc, shape[i]);
-    }
-
-    // Compute (dim - 1)
-    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
-    Value dimMinusOne = builder.create<arith::SubIOp>(loc, dimSize, one);
-
-    // Get stride
-    Value stride;
-    if (strides[i] == ShapedType::kDynamicStrideOrOffset) {
-      // For dynamic strides, need to extract from memref descriptor
-      // This would require runtime support, possibly using extractStride
-      // As a placeholder, return a dynamic value
-      return Value();
-    } else {
-      stride = builder.create<arith::ConstantIndexOp>(loc, strides[i]);
-    }
-
-    // Add (dim - 1) * stride to result
-    Value term = builder.create<arith::MulIOp>(loc, dimMinusOne, stride);
-    result = builder.create<arith::AddIOp>(loc, result, term);
-  }
-
-  return result;
-}
-
-static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
-                  OpFoldResult>
-getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
-                        ArrayRef<OpFoldResult> subOffsets,
-                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
-  auto sourceType = cast<MemRefType>(source.getType());
-  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
-
-  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
-  {
-    OpBuilder::InsertionGuard g(rewriter);
-    setInsertionPointToStart(rewriter, source);
-    newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
-  }
-
-  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
-
-  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
-    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
-                                      : rewriter.getIndexAttr(dim);
-  };
-
-  OpFoldResult origOffset =
-      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
-  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
-  OpFoldResult outmostDim =
-      getDim(sourceType.getShape().front(),
-             newExtractStridedMetadata.getSizes().front());
-
-  SmallVector<OpFoldResult> origStrides;
-  origStrides.reserve(sourceRank);
-
-  SmallVector<OpFoldResult> strides;
-  strides.reserve(sourceRank);
-
-  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
-  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
-  for (auto i : llvm::seq(0u, sourceRank)) {
-    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
-
-    if (!subStrides.empty()) {
-      strides.push_back(affine::makeComposedFoldedAffineApply(
-          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
-    }
-
-    origStrides.emplace_back(origStride);
-  }
-
-  // Compute linearized index:
-  auto &&[expr, values] =
-      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
-  OpFoldResult linearizedIndex =
-      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
-
-  // Compute collapsed size: (the outmost stride * outmost dimension).
-  // SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
-  // OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
-  OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter);
-
-  return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
-          origStrides, origOffset, collapsedSize};
-}
-
 static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                       OpFoldResult in) {
   if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
@@ -188,17 +54,36 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                          Location loc,
                                                          Value source,
                                                          ValueRange indices) {
-  auto &&[base, index, strides, offset, collapsedShape] =
-      getFlatOffsetAndStrides(rewriter, loc, source,
-                              getAsOpFoldResult(indices));
+  int64_t sourceOffset;
+  SmallVector<int64_t, 4> sourceStrides;
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
+
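+    // getStridesAndOffset only fails for layouts that cannot be expressed as
+    // strides + offset; this flattening path handles strided memrefs only.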
+    assert(false);
+  }
+
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+
+  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
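+  // Collapse the multi-dimensional indices into a single linear index using
+  // the source memref's offset, sizes, and strides.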
+  OpFoldResult linearizedIndices;
+  memref::LinearizedMemRefInfo linearizedInfo;
+  std::tie(linearizedInfo, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          rewriter, loc, typeBit, typeBit,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(),
+          getAsOpFoldResult(indices));
 
   return std::make_pair(
       rewriter.create<memref::ReinterpretCastOp>(
           loc, source,
-          /*offset=*/ offset,
-          /*shapes=*/ ArrayRef<OpFoldResult>{collapsedShape},
-          /*strides=*/ ArrayRef<OpFoldResult>{strides.back()}),
-      getValueFromOpFoldResult(rewriter, loc, index));
+          /*offset=*/ linearizedInfo.linearizedOffset,
+          /*shapes=*/ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
+          /*strides=*/
+          ArrayRef<OpFoldResult>{
+              stridedMetadata.getConstifiedMixedStrides().back()}),
+      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
 }
 
 static bool needFlattening(Value val) {
@@ -313,8 +198,23 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
     SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
     SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
-    auto &&[base, finalOffset, strides, _, __] =
-        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+    // base, finalOffset, strides
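+    // Linearize the subview's offsets against the source memref's sizes and
+    // strides to obtain the flattened offset of the subview's origin.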
+    memref::ExtractStridedMetadataOp stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, memref);
+
+    auto sourceType = cast<MemRefType>(memref.getType());
+    auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+    OpFoldResult linearizedIndices;
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, typeBit, typeBit,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets());
+    auto finalOffset = linearizedInfo.linearizedOffset;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
 
     auto srcType = cast<MemRefType>(memref.getType());
     auto resultType = cast<MemRefType>(op.getType());
@@ -337,7 +237,7 @@ struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
     }
 
     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
-        op, resultType, base, finalOffset, finalSizes, finalStrides);
+        op, resultType, memref, finalOffset, finalSizes, finalStrides);
     return success();
   }
 };
@@ -364,12 +264,13 @@ struct FlattenMemrefsPass
 } // namespace
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns
-      .insert<MemRefRewritePattern<memref::LoadOp>,
-              MemRefRewritePattern<memref::StoreOp>,
-              MemRefRewritePattern<vector::LoadOp>,
-              MemRefRewritePattern<vector::StoreOp>,
-              MemRefRewritePattern<vector::TransferReadOp>,
-              MemRefRewritePattern<vector::TransferWriteOp>, FlattenSubview>(
-          patterns.getContext());
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>, FlattenSubview>(
+      patterns.getContext());
 }