Skip to content

Commit 1fee833

Browse files
committed
Change the way how linearized sizes are computed.
1 parent 189cddf commit 1fee833

File tree

1 file changed

+69
-3
lines changed

1 file changed

+69
-3
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,70 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
5151
return cast<Value>(in);
5252
}
5353

54+
static bool hasDynamicDim(ArrayRef<OpFoldResult> dims) {
55+
for (auto &&dim : dims) {
56+
auto constant = getConstantIntValue(dim);
57+
if (!constant || *constant < 0) {
58+
return true;
59+
}
60+
}
61+
return false;
62+
}
63+
64+
static OpFoldResult computeStaticShape(OpBuilder &builder, Location loc,
65+
ArrayRef<OpFoldResult> dims,
66+
ArrayRef<OpFoldResult> strides) {
67+
// max(dims[i] * strides[i]) for i = 0, 1, ..., n-1
68+
int64_t maxSize = 1;
69+
for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
70+
AffineExpr s0, s1;
71+
bindSymbols(builder.getContext(), s0, s1);
72+
OpFoldResult size = affine::makeComposedFoldedAffineApply(
73+
builder, loc, s0 * s1, ArrayRef<OpFoldResult>{dim, stride});
74+
auto constant = getConstantIntValue(size);
75+
assert(constant && "expected constant value");
76+
maxSize = *constant;
77+
}
78+
return builder.getIndexAttr(maxSize);
79+
}
80+
81+
static OpFoldResult computeDynamicShape(OpBuilder &builder, Location loc,
82+
ArrayRef<OpFoldResult> dims,
83+
ArrayRef<OpFoldResult> strides) {
84+
85+
SmallVector<AffineExpr> symbols(2 * dims.size());
86+
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
87+
SmallVector<AffineExpr> productExpressions;
88+
SmallVector<Value> values;
89+
size_t symbolIndex = 0;
90+
for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
91+
AffineExpr dimExpr = symbols[symbolIndex++];
92+
AffineExpr strideExpr = symbols[symbolIndex++];
93+
productExpressions.push_back(dimExpr * strideExpr);
94+
values.push_back(getValueFromOpFoldResult(builder, loc, dim));
95+
values.push_back(getValueFromOpFoldResult(builder, loc, stride));
96+
}
97+
98+
AffineMap maxMap = AffineMap::get(0, symbols.size(), productExpressions,
99+
builder.getContext());
100+
Value maxValue =
101+
builder.create<affine::AffineMaxOp>(loc, maxMap, values).getResult();
102+
return maxValue;
103+
}
104+
105+
/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
106+
/// span of the memref.
107+
static OpFoldResult computeSpan(OpBuilder &builder, Location loc,
108+
ArrayRef<OpFoldResult> dims,
109+
ArrayRef<OpFoldResult> strides) {
110+
assert(dims.size() == strides.size() &&
111+
"number of dimensions and strides should be equal");
112+
if (hasDynamicDim(dims) || hasDynamicDim(strides)) {
113+
return computeDynamicShape(builder, loc, dims, strides);
114+
}
115+
return computeStaticShape(builder, loc, dims, strides);
116+
}
117+
54118
/// Returns a collapsed memref and the linearized index to access the element
55119
/// at the specified indices.
56120
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -82,10 +146,12 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
82146
rewriter.create<memref::ReinterpretCastOp>(
83147
loc, source,
84148
/* offset = */ linearizedInfo.linearizedOffset,
85-
/* shapes = */ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
149+
/* shapes = */
150+
ArrayRef<OpFoldResult>{computeSpan(
151+
rewriter, loc, stridedMetadata.getConstifiedMixedSizes(),
152+
stridedMetadata.getConstifiedMixedStrides())},
86153
/* strides = */
87-
ArrayRef<OpFoldResult>{
88-
stridedMetadata.getConstifiedMixedStrides().back()}),
154+
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
89155
getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
90156
}
91157

0 commit comments

Comments
 (0)