1010// fundamental operations.
1111// ===----------------------------------------------------------------------===//
1212
13+ #include " mlir/Dialect/Affine/LoopUtils.h"
1314#include " mlir/Dialect/Affine/Passes.h"
1415
1516#include " mlir/Dialect/Affine/IR/AffineOps.h"
@@ -28,6 +29,50 @@ namespace affine {
2829using namespace mlir ;
2930using namespace mlir ::affine;
3031
32+ // / Given a basis (in static and dynamic components), return the sequence of
33+ // / suffix products of the basis, including the product of the entire basis,
34+ // / which must **not** contain an outer bound.
35+ // /
36+ // / If excess dynamic values are provided, the values at the beginning
37+ // / will be ignored. This allows for dropping the outer bound without
38+ // / needing to manipulate the dynamic value array.
39+ static SmallVector<Value> computeStrides (Location loc, RewriterBase &rewriter,
40+ ValueRange dynamicBasis,
41+ ArrayRef<int64_t > staticBasis) {
42+ if (staticBasis.empty ())
43+ return {};
44+
45+ SmallVector<Value> result;
46+ result.reserve (staticBasis.size ());
47+ size_t dynamicIndex = dynamicBasis.size ();
48+ Value dynamicPart = nullptr ;
49+ int64_t staticPart = 1 ;
50+ for (int64_t elem : llvm::reverse (staticBasis)) {
51+ if (ShapedType::isDynamic (elem)) {
52+ if (dynamicPart)
53+ dynamicPart = rewriter.create <arith::MulIOp>(
54+ loc, dynamicPart, dynamicBasis[dynamicIndex - 1 ]);
55+ else
56+ dynamicPart = dynamicBasis[dynamicIndex - 1 ];
57+ --dynamicIndex;
58+ } else {
59+ staticPart *= elem;
60+ }
61+
62+ if (dynamicPart && staticPart == 1 ) {
63+ result.push_back (dynamicPart);
64+ } else {
65+ Value stride =
66+ rewriter.createOrFold <arith::ConstantIndexOp>(loc, staticPart);
67+ if (dynamicPart)
68+ stride = rewriter.create <arith::MulIOp>(loc, dynamicPart, stride);
69+ result.push_back (stride);
70+ }
71+ }
72+ std::reverse (result.begin (), result.end ());
73+ return result;
74+ }
75+
3176namespace {
3277// / Lowers `affine.delinearize_index` into a sequence of division and remainder
3378// / operations.
@@ -36,18 +81,62 @@ struct LowerDelinearizeIndexOps
3681 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
3782 LogicalResult matchAndRewrite (AffineDelinearizeIndexOp op,
3883 PatternRewriter &rewriter) const override {
39- FailureOr<SmallVector<Value>> multiIndex =
40- delinearizeIndex (rewriter, op->getLoc (), op.getLinearIndex (),
41- op.getEffectiveBasis (), /* hasOuterBound=*/ false );
42- if (failed (multiIndex))
43- return failure ();
44- rewriter.replaceOp (op, *multiIndex);
84+ Location loc = op.getLoc ();
85+ Value linearIdx = op.getLinearIndex ();
86+ unsigned numResults = op.getNumResults ();
87+ ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
88+ if (numResults == staticBasis.size ())
89+ staticBasis = staticBasis.drop_front ();
90+
91+ if (numResults == 1 ) {
92+ rewriter.replaceOp (op, linearIdx);
93+ return success ();
94+ }
95+
96+ SmallVector<Value> results;
97+ results.reserve (numResults);
98+ SmallVector<Value> strides =
99+ computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis);
100+
101+ Value zero = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
102+
103+ Value initialPart =
104+ rewriter.create <arith::FloorDivSIOp>(loc, linearIdx, strides.front ());
105+ results.push_back (initialPart);
106+
107+ auto emitModTerm = [&](Value stride) -> Value {
108+ Value remainder = rewriter.create <arith::RemSIOp>(loc, linearIdx, stride);
109+ Value remainderNegative = rewriter.create <arith::CmpIOp>(
110+ loc, arith::CmpIPredicate::slt, remainder, zero);
111+ Value corrected = rewriter.create <arith::AddIOp>(loc, remainder, stride);
112+ Value mod = rewriter.create <arith::SelectOp>(loc, remainderNegative,
113+ corrected, remainder);
114+ return mod;
115+ };
116+
117+ // Generate all the intermediate parts
118+ for (size_t i = 0 , e = strides.size () - 1 ; i < e; ++i) {
119+ Value thisStride = strides[i];
120+ Value nextStride = strides[i + 1 ];
121+ Value modulus = emitModTerm (thisStride);
122+ // We know both inputs are positive, so floorDiv == div.
123+ // This could potentially be a divui, but it's not clear if that would
124+ // cause issues.
125+ Value divided = rewriter.create <arith::DivSIOp>(loc, modulus, nextStride);
126+ results.push_back (divided);
127+ }
128+
129+ results.push_back (emitModTerm (strides.back ()));
130+
131+ rewriter.replaceOp (op, results);
45132 return success ();
46133 }
47134};
48135
49136// / Lowers `affine.linearize_index` into a sequence of multiplications and
50- // / additions.
137+ // / additions. Make a best effort to sort the input indices so that
138+ // / the most loop-invariant terms are at the left of the additions
139+ // / to enable loop-invariant code motion.
51140struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
52141 using OpRewritePattern::OpRewritePattern;
53142 LogicalResult matchAndRewrite (AffineLinearizeIndexOp op,
@@ -58,13 +147,44 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
58147 return success ();
59148 }
60149
61- SmallVector<OpFoldResult> multiIndex =
62- getAsOpFoldResult (op.getMultiIndex ());
63- OpFoldResult linearIndex =
64- linearizeIndex (rewriter, op.getLoc (), multiIndex, op.getMixedBasis ());
65- Value linearIndexValue =
66- getValueOrCreateConstantIntOp (rewriter, op.getLoc (), linearIndex);
67- rewriter.replaceOp (op, linearIndexValue);
// Overview of the lowering below: every index except the last is multiplied
// by its stride, the last index is used unscaled, and all terms are summed
// with a chain of arith.addi ops.
150+ Location loc = op.getLoc ();
151+ ValueRange multiIndex = op.getMultiIndex ();
152+ size_t numIndexes = multiIndex.size ();
153+ ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
// When the static basis has one entry per index, its first entry is dropped
// before the strides are computed.
154+ if (numIndexes == staticBasis.size ())
155+ staticBasis = staticBasis.drop_front ();
156+
157+ SmallVector<Value> strides =
158+ computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis);
// Each entry pairs a term of the sum with the value returned by
// numEnclosingInvariantLoops for the index operand that produced it.
159+ SmallVector<std::pair<Value, int64_t >> scaledValues;
160+ scaledValues.reserve (numIndexes);
161+
162+ // Note: strides doesn't contain a value for the final element (stride 1)
163+ // and everything else lines up. We use the "mutable" accessor so we can get
164+ // our hands on an `OpOperand&` for the loop invariant counting function.
165+ for (auto [stride, idxOp] :
166+ llvm::zip_equal (strides, llvm::drop_end (op.getMultiIndexMutable ()))) {
167+ Value scaledIdx =
168+ rewriter.create <arith::MulIOp>(loc, idxOp.get (), stride);
169+ int64_t numHoistableLoops = numEnclosingInvariantLoops (idxOp);
170+ scaledValues.emplace_back (scaledIdx, numHoistableLoops);
171+ }
// The last index is appended as-is, with no multiplication emitted for it.
172+ scaledValues.emplace_back (
173+ multiIndex.back (),
174+ numEnclosingInvariantLoops (op.getMultiIndexMutable ()[numIndexes - 1 ]));
175+
176+ // Sort by how many enclosing loops there are, ties implicitly broken by
177+ // size of the stride.
// The comparator orders terms by descending count, so terms with the largest
// count come first; stable_sort keeps equal-count terms in operand order.
178+ llvm::stable_sort (scaledValues,
179+ [&](auto l, auto r) { return l.second > r.second ; });
180+
// Accumulate left to right, starting from the first term after sorting.
181+ Value result = scaledValues.front ().first ;
182+ for (auto [scaledValue, numHoistableLoops] :
183+ llvm::drop_begin (scaledValues)) {
// The count was only needed for sorting; assigning it to std::ignore marks
// the binding as deliberately unused.
184+ std::ignore = numHoistableLoops;
185+ result = rewriter.create <arith::AddIOp>(loc, result, scaledValue);
186+ }
187+ rewriter.replaceOp (op, result);
68188 return success ();
69189 }
70190};
0 commit comments