Skip to content

Commit 2b28223

Browse files
committed
Fix affine fold memref alias pattern
in the case where the subview op uses an index of an affine loop as eg, the offset. The previous implementation always generated symbols, and the verifier failed, although the transformation is valid if you just generate dims for the variables that are not symbols, but are valid dims.
1 parent 1f483c9 commit 2b28223

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@ LogicalResult mlir::affine::mergeOffsetsSizesAndStrides(
7878
combinedOffsets, combinedSizes, combinedStrides);
7979
}
8080

81+
static AffineMap bindSymbolsOrDims(
82+
MLIRContext *ctx, llvm::ArrayRef<OpFoldResult> operands,
83+
function_ref<AffineExpr(llvm::SmallVectorImpl<AffineExpr> &)> makeExpr) {
84+
SmallVector<AffineExpr, 4> affineExprs(operands.size());
85+
unsigned symbolCount = 0;
86+
unsigned dimCount = 0;
87+
for (auto [e, value] : llvm::zip_equal(affineExprs, operands)) {
88+
auto asValue = llvm::dyn_cast_or_null<Value>(value);
89+
if (asValue && !affine::isValidSymbol(asValue) &&
90+
affine::isValidDim(asValue)) {
91+
e = getAffineDimExpr(dimCount++, ctx);
92+
} else {
93+
// This is also done if it is not a valid symbol but
94+
// we don't care, we need a fallback.
95+
e = getAffineSymbolExpr(symbolCount++, ctx);
96+
}
97+
}
98+
return AffineMap::get(dimCount, symbolCount, makeExpr(affineExprs));
99+
}
100+
81101
void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
82102
RewriterBase &rewriter, Location loc,
83103
ArrayRef<OpFoldResult> mixedSourceOffsets,
@@ -100,11 +120,12 @@ void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
100120
resolvedIndices.clear();
101121
for (auto [offset, index, stride] :
102122
llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
103-
AffineExpr off, idx, str;
104-
bindSymbols(rewriter.getContext(), off, idx, str);
105-
OpFoldResult ofr = makeComposedFoldedAffineApply(
106-
rewriter, loc, AffineMap::get(0, 3, off + idx * str),
107-
{offset, index, stride});
123+
auto affineMap =
124+
bindSymbolsOrDims(rewriter.getContext(), {offset, index, stride},
125+
[](auto &e) { return e[0] + e[1] * e[2]; });
126+
127+
OpFoldResult ofr = makeComposedFoldedAffineApply(rewriter, loc, affineMap,
128+
{offset, index, stride});
108129
resolvedIndices.push_back(
109130
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
110131
}

0 commit comments

Comments
 (0)