Skip to content

Commit 8e62bf5

Browse files
committed
simplify folds
1 parent e5846ab commit 8e62bf5

File tree

1 file changed

+11
-50
lines changed

1 file changed

+11
-50
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 11 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -51,68 +51,29 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
5151
return cast<Value>(in);
5252
}
5353

54-
static bool hasDynamicDim(ArrayRef<OpFoldResult> dims) {
55-
for (auto &&dim : dims) {
56-
auto constant = getConstantIntValue(dim);
57-
if (!constant || *constant < 0) {
58-
return true;
59-
}
60-
}
61-
return false;
62-
}
63-
64-
static OpFoldResult computeStaticShape(OpBuilder &builder, Location loc,
65-
ArrayRef<OpFoldResult> dims,
66-
ArrayRef<OpFoldResult> strides) {
67-
// max(dims[i] * strides[i]) for i = 0, 1, ..., n-1
68-
int64_t maxSize = 1;
69-
for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
70-
AffineExpr s0, s1;
71-
bindSymbols(builder.getContext(), s0, s1);
72-
OpFoldResult size = affine::makeComposedFoldedAffineApply(
73-
builder, loc, s0 * s1, ArrayRef<OpFoldResult>{dim, stride});
74-
auto constant = getConstantIntValue(size);
75-
assert(constant && "expected constant value");
76-
maxSize = std::max(maxSize, *constant);
77-
}
78-
return builder.getIndexAttr(maxSize);
79-
}
80-
81-
static OpFoldResult computeDynamicShape(OpBuilder &builder, Location loc,
82-
ArrayRef<OpFoldResult> dims,
83-
ArrayRef<OpFoldResult> strides) {
84-
54+
/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
55+
/// span of the memref.
56+
static OpFoldResult computeSize(OpBuilder &builder, Location loc,
57+
ArrayRef<OpFoldResult> dims,
58+
ArrayRef<OpFoldResult> strides) {
59+
assert(dims.size() == strides.size() &&
60+
"number of dimensions and strides should be equal");
8561
SmallVector<AffineExpr> symbols(2 * dims.size());
8662
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
8763
SmallVector<AffineExpr> productExpressions;
88-
SmallVector<Value> values;
64+
SmallVector<OpFoldResult> values;
8965
size_t symbolIndex = 0;
9066
for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
9167
AffineExpr dimExpr = symbols[symbolIndex++];
9268
AffineExpr strideExpr = symbols[symbolIndex++];
9369
productExpressions.push_back(dimExpr * strideExpr);
94-
values.push_back(getValueFromOpFoldResult(builder, loc, dim));
95-
values.push_back(getValueFromOpFoldResult(builder, loc, stride));
70+
values.push_back(dim);
71+
values.push_back(stride);
9672
}
9773

9874
AffineMap maxMap = AffineMap::get(0, symbols.size(), productExpressions,
9975
builder.getContext());
100-
Value maxValue =
101-
builder.create<affine::AffineMaxOp>(loc, maxMap, values).getResult();
102-
return maxValue;
103-
}
104-
105-
/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
106-
/// span of the memref.
107-
static OpFoldResult computeSize(OpBuilder &builder, Location loc,
108-
ArrayRef<OpFoldResult> dims,
109-
ArrayRef<OpFoldResult> strides) {
110-
assert(dims.size() == strides.size() &&
111-
"number of dimensions and strides should be equal");
112-
if (hasDynamicDim(dims) || hasDynamicDim(strides)) {
113-
return computeDynamicShape(builder, loc, dims, strides);
114-
}
115-
return computeStaticShape(builder, loc, dims, strides);
76+
return affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
11677
}
11778

11879
/// Returns a collapsed memref and the linearized index to access the element

0 commit comments

Comments
 (0)