Skip to content

Commit 4d8ae5f

Browse files
committed
Allow the subview folding to consider expressions of dims and symbols
and not just plain dims and symbols. This makes it possible to fold memref.subview ops that use an affine expression of valid symbol and dims as an offset, even if that expression is computed by arith ops like muli and addi.
1 parent 2b28223 commit 4d8ae5f

File tree

4 files changed

+177
-124
lines changed

4 files changed

+177
-124
lines changed

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1818
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1919
#include "mlir/IR/OpDefinition.h"
20+
#include "llvm/ADT/ArrayRef.h"
21+
#include "llvm/Support/LogicalResult.h"
2022
#include <optional>
2123

2224
namespace mlir {
@@ -105,7 +107,7 @@ struct VectorizationStrategy {
105107
/// Replace affine store and load accesses by scalars by forwarding stores to
106108
/// loads and eliminate invariant affine loads; consequently, eliminate dead
107109
/// allocs.
108-
void affineScalarReplace(Operation* parentOp, DominanceInfo &domInfo,
110+
void affineScalarReplace(Operation *parentOp, DominanceInfo &domInfo,
109111
PostDominanceInfo &postDomInfo,
110112
AliasAnalysis &analysis);
111113

@@ -338,6 +340,20 @@ OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
338340
ArrayRef<OpFoldResult> multiIndex,
339341
ArrayRef<OpFoldResult> basis);
340342

343+
/// Given a set of indices into a memref which may be computed using
344+
/// arith ops, try to compute each value to an affine expr. This is
345+
/// only possible if the indices are an expression of valid dims and
346+
/// args. If this succeeds, the affine map is populated, along with
347+
/// the map arguments (concrete bindings for dims and symbols).
348+
LogicalResult
349+
convertValuesToAffineMapAndArgs(MLIRContext *ctx, ValueRange indices,
350+
AffineMap &map,
351+
llvm::SmallVectorImpl<Value> &mapArgs);
352+
LogicalResult
353+
convertValuesToAffineMapAndArgs(MLIRContext *ctx,
354+
ArrayRef<OpFoldResult> indices, AffineMap &map,
355+
llvm::SmallVectorImpl<OpFoldResult> &mapArgs);
356+
341357
/// Ensure that all operations that could be executed after `start`
342358
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
343359
/// between the operations) do not have the potential memory effect

mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp

Lines changed: 4 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,12 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/Affine/Analysis/Utils.h"
1514
#include "mlir/Dialect/Affine/Passes.h"
1615
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
1716
#include "mlir/Dialect/Affine/Utils.h"
18-
#include "mlir/Dialect/Func/IR/FuncOps.h"
1917
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2018
#include "mlir/IR/AffineExpr.h"
21-
#include "mlir/IR/Matchers.h"
2219
#include "mlir/IR/Operation.h"
23-
#include "mlir/Pass/Pass.h"
2420
#include "llvm/Support/Debug.h"
2521

2622
namespace mlir {
@@ -37,96 +33,6 @@ using namespace mlir::affine;
3733

3834
namespace {
3935

40-
/// Find the index of the given value in the `dims` list,
41-
/// and append it if it was not already in the list. The
42-
/// dims list is a list of symbols or dimensions of the
43-
/// affine map. Within the results of an affine map, they
44-
/// are identified by their index, which is why we need
45-
/// this function.
46-
static std::optional<size_t>
47-
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
48-
function_ref<bool(Value)> isValidElement) {
49-
50-
Value *loopIV = std::find(dims.begin(), dims.end(), value);
51-
if (loopIV != dims.end()) {
52-
// We found an IV that already has an index, return that index.
53-
return {std::distance(dims.begin(), loopIV)};
54-
}
55-
if (isValidElement(value)) {
56-
// This is a valid element for the dim/symbol list, push this as a
57-
// parameter.
58-
size_t idx = dims.size();
59-
dims.push_back(value);
60-
return idx;
61-
}
62-
return std::nullopt;
63-
}
64-
65-
/// Convert a value to an affine expr if possible. Adds dims and symbols
66-
/// if needed.
67-
static AffineExpr toAffineExpr(Value value,
68-
llvm::SmallVectorImpl<Value> &affineDims,
69-
llvm::SmallVectorImpl<Value> &affineSymbols) {
70-
using namespace matchers;
71-
IntegerAttr::ValueType cst;
72-
if (matchPattern(value, m_ConstantInt(&cst))) {
73-
return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
74-
}
75-
Value lhs;
76-
Value rhs;
77-
if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
78-
matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
79-
AffineExpr lhsE;
80-
AffineExpr rhsE;
81-
if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) &&
82-
(rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) {
83-
AffineExprKind kind;
84-
if (isa<arith::AddIOp>(value.getDefiningOp())) {
85-
kind = mlir::AffineExprKind::Add;
86-
} else {
87-
kind = mlir::AffineExprKind::Mul;
88-
}
89-
return getAffineBinaryOpExpr(kind, lhsE, rhsE);
90-
}
91-
}
92-
93-
if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
94-
return affine::isValidSymbol(v);
95-
})) {
96-
return getAffineSymbolExpr(*dimIx, value.getContext());
97-
}
98-
99-
if (auto dimIx = findInListOrAdd(
100-
value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
101-
102-
return getAffineDimExpr(*dimIx, value.getContext());
103-
}
104-
105-
return {};
106-
}
107-
108-
static LogicalResult
109-
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
110-
llvm::SmallVectorImpl<Value> &mapArgs) {
111-
SmallVector<AffineExpr> results;
112-
SmallVector<Value> symbols;
113-
SmallVector<Value> dims;
114-
115-
for (Value indexExpr : indices) {
116-
AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
117-
if (!res) {
118-
return failure();
119-
}
120-
results.push_back(res);
121-
}
122-
123-
map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
124-
125-
dims.append(symbols);
126-
mapArgs.swap(dims);
127-
return success();
128-
}
129-
13036
struct RaiseMemrefDialect
13137
: public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
13238

@@ -140,8 +46,8 @@ struct RaiseMemrefDialect
14046
rewriter.setInsertionPoint(op);
14147
if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
14248

143-
if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
144-
mapArgs))) {
49+
if (succeeded(affine::convertValuesToAffineMapAndArgs(
50+
ctx, store.getIndices(), map, mapArgs))) {
14551
rewriter.replaceOpWithNewOp<AffineStoreOp>(
14652
op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
14753
return;
@@ -151,8 +57,8 @@ struct RaiseMemrefDialect
15157
<< "[affine] Cannot raise memref op: " << op << "\n");
15258

15359
} else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
154-
if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
155-
mapArgs))) {
60+
if (succeeded(affine::convertValuesToAffineMapAndArgs(
61+
ctx, load.getIndices(), map, mapArgs))) {
15662
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
15763
mapArgs);
15864
return;

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,19 @@
2121
#include "mlir/Dialect/Func/IR/FuncOps.h"
2222
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2323
#include "mlir/Dialect/Utils/IndexingUtils.h"
24+
#include "mlir/IR/AffineExpr.h"
2425
#include "mlir/IR/AffineExprVisitor.h"
2526
#include "mlir/IR/Dominance.h"
2627
#include "mlir/IR/IRMapping.h"
2728
#include "mlir/IR/ImplicitLocOpBuilder.h"
2829
#include "mlir/IR/IntegerSet.h"
30+
#include "mlir/IR/OpDefinition.h"
2931
#include "mlir/IR/Operation.h"
3032
#include "mlir/IR/PatternMatch.h"
3133
#include "mlir/Interfaces/SideEffectInterfaces.h"
3234
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35+
#include "llvm/ADT/ArrayRef.h"
36+
#include "llvm/Support/Casting.h"
3337
#include "llvm/Support/LogicalResult.h"
3438
#include <optional>
3539
#include <tuple>
@@ -2203,3 +2207,136 @@ OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
22032207
return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
22042208
multiIndexAndStrides);
22052209
}
2210+
2211+
namespace {
2212+
2213+
/// Find the index of the given value in the `dims` list,
2214+
/// and append it if it was not already in the list. The
2215+
/// dims list is a list of symbols or dimensions of the
2216+
/// affine map. Within the results of an affine map, they
2217+
/// are identified by their index, which is why we need
2218+
/// this function.
2219+
static std::optional<size_t>
2220+
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
2221+
function_ref<bool(Value)> isValidElement) {
2222+
2223+
Value *loopIV = std::find(dims.begin(), dims.end(), value);
2224+
if (loopIV != dims.end()) {
2225+
// We found an IV that already has an index, return that index.
2226+
return {std::distance(dims.begin(), loopIV)};
2227+
}
2228+
if (isValidElement(value)) {
2229+
// This is a valid element for the dim/symbol list, push this as a
2230+
// parameter.
2231+
size_t idx = dims.size();
2232+
dims.push_back(value);
2233+
return idx;
2234+
}
2235+
return std::nullopt;
2236+
}
2237+
2238+
/// Convert a value to an affine expr if possible. Adds dims and symbols
2239+
/// if needed.
2240+
static AffineExpr toAffineExpr(Value value,
2241+
llvm::SmallVectorImpl<Value> &affineDims,
2242+
llvm::SmallVectorImpl<Value> &affineSymbols) {
2243+
using namespace matchers;
2244+
IntegerAttr::ValueType cst;
2245+
if (matchPattern(value, m_ConstantInt(&cst))) {
2246+
return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
2247+
}
2248+
Value lhs;
2249+
Value rhs;
2250+
if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
2251+
matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
2252+
AffineExpr lhsE;
2253+
AffineExpr rhsE;
2254+
if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) &&
2255+
(rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) {
2256+
AffineExprKind kind;
2257+
if (isa<arith::AddIOp>(value.getDefiningOp())) {
2258+
kind = mlir::AffineExprKind::Add;
2259+
} else {
2260+
kind = mlir::AffineExprKind::Mul;
2261+
}
2262+
return getAffineBinaryOpExpr(kind, lhsE, rhsE);
2263+
}
2264+
}
2265+
2266+
if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
2267+
return affine::isValidSymbol(v);
2268+
})) {
2269+
return getAffineSymbolExpr(*dimIx, value.getContext());
2270+
}
2271+
2272+
if (auto dimIx = findInListOrAdd(
2273+
value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
2274+
2275+
return getAffineDimExpr(*dimIx, value.getContext());
2276+
}
2277+
2278+
return {};
2279+
}
2280+
2281+
} // namespace
2282+
2283+
LogicalResult mlir::affine::convertValuesToAffineMapAndArgs(
2284+
MLIRContext *ctx, ValueRange indices, AffineMap &map,
2285+
llvm::SmallVectorImpl<Value> &mapArgs) {
2286+
SmallVector<AffineExpr> results;
2287+
SmallVector<Value> symbols;
2288+
SmallVector<Value> dims;
2289+
2290+
for (Value indexExpr : indices) {
2291+
AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
2292+
if (!res) {
2293+
return failure();
2294+
}
2295+
results.push_back(res);
2296+
}
2297+
2298+
map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
2299+
2300+
dims.append(symbols);
2301+
mapArgs.swap(dims);
2302+
return success();
2303+
}
2304+
2305+
LogicalResult mlir::affine::convertValuesToAffineMapAndArgs(
2306+
MLIRContext *ctx, ArrayRef<OpFoldResult> indices, AffineMap &map,
2307+
llvm::SmallVectorImpl<OpFoldResult> &mapArgs) {
2308+
SmallVector<AffineExpr> results;
2309+
SmallVector<Value> symbols;
2310+
SmallVector<Value> dims;
2311+
SmallVector<OpFoldResult> constantSymbols;
2312+
2313+
for (OpFoldResult indexExpr : indices) {
2314+
if (auto asValue = llvm::dyn_cast_or_null<Value>(indexExpr)) {
2315+
AffineExpr res = toAffineExpr(asValue, dims, symbols);
2316+
if (!res) {
2317+
return failure();
2318+
}
2319+
results.push_back(res);
2320+
} else {
2321+
constantSymbols.push_back(indexExpr);
2322+
results.push_back(getAffineSymbolExpr(symbols.size(), ctx));
2323+
// add a null symbol here to increment the next symbol id.
2324+
symbols.emplace_back();
2325+
}
2326+
}
2327+
2328+
map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
2329+
2330+
for (auto dim : dims) {
2331+
mapArgs.push_back(dim);
2332+
}
2333+
unsigned nextConstSymbol = 0;
2334+
for (auto symbol : symbols) {
2335+
if (symbol) {
2336+
mapArgs.push_back(symbol);
2337+
} else {
2338+
mapArgs.push_back(constantSymbols[nextConstSymbol++]);
2339+
}
2340+
}
2341+
return success();
2342+
}

mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1010
#include "mlir/Dialect/Affine/IR/AffineOps.h"
11+
#include "mlir/Dialect/Affine/Utils.h"
1112
#include "mlir/Dialect/Arith/Utils/Utils.h"
13+
#include "mlir/IR/AffineExpr.h"
14+
#include "mlir/IR/AffineMap.h"
1215
#include "mlir/IR/PatternMatch.h"
1316

1417
using namespace mlir;
@@ -78,26 +81,6 @@ LogicalResult mlir::affine::mergeOffsetsSizesAndStrides(
7881
combinedOffsets, combinedSizes, combinedStrides);
7982
}
8083

81-
static AffineMap bindSymbolsOrDims(
82-
MLIRContext *ctx, llvm::ArrayRef<OpFoldResult> operands,
83-
function_ref<AffineExpr(llvm::SmallVectorImpl<AffineExpr> &)> makeExpr) {
84-
SmallVector<AffineExpr, 4> affineExprs(operands.size());
85-
unsigned symbolCount = 0;
86-
unsigned dimCount = 0;
87-
for (auto [e, value] : llvm::zip_equal(affineExprs, operands)) {
88-
auto asValue = llvm::dyn_cast_or_null<Value>(value);
89-
if (asValue && !affine::isValidSymbol(asValue) &&
90-
affine::isValidDim(asValue)) {
91-
e = getAffineDimExpr(dimCount++, ctx);
92-
} else {
93-
// This is also done if it is not a valid symbol but
94-
// we don't care, we need a fallback.
95-
e = getAffineSymbolExpr(symbolCount++, ctx);
96-
}
97-
}
98-
return AffineMap::get(dimCount, symbolCount, makeExpr(affineExprs));
99-
}
100-
10184
void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
10285
RewriterBase &rewriter, Location loc,
10386
ArrayRef<OpFoldResult> mixedSourceOffsets,
@@ -120,12 +103,23 @@ void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
120103
resolvedIndices.clear();
121104
for (auto [offset, index, stride] :
122105
llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
123-
auto affineMap =
124-
bindSymbolsOrDims(rewriter.getContext(), {offset, index, stride},
125-
[](auto &e) { return e[0] + e[1] * e[2]; });
106+
AffineMap map;
107+
SmallVector<OpFoldResult> mapArgs;
108+
auto *ctx = rewriter.getContext();
109+
if (failed(affine::convertValuesToAffineMapAndArgs(
110+
ctx, {offset, index, stride}, map, mapArgs))) {
111+
// todo
112+
resolvedIndices.push_back(Value{});
113+
continue;
114+
}
115+
AffineExpr off, ix, str;
116+
bindDims(ctx, off, ix, str);
117+
auto nextMap = AffineMap::get(3, 0, off + ix * str);
118+
auto composedMap = nextMap.compose(map);
119+
120+
OpFoldResult ofr =
121+
makeComposedFoldedAffineApply(rewriter, loc, composedMap, mapArgs);
126122

127-
OpFoldResult ofr = makeComposedFoldedAffineApply(rewriter, loc, affineMap,
128-
{offset, index, stride});
129123
resolvedIndices.push_back(
130124
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
131125
}

0 commit comments

Comments
 (0)