@@ -84,126 +84,130 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
8484 return result;
8585}
8686
87+ LogicalResult
88+ affine::lowerAffineDelinearizeIndexOp (RewriterBase &rewriter,
89+ AffineDelinearizeIndexOp op) {
90+ Location loc = op.getLoc ();
91+ Value linearIdx = op.getLinearIndex ();
92+ unsigned numResults = op.getNumResults ();
93+ ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
94+ if (numResults == staticBasis.size ())
95+ staticBasis = staticBasis.drop_front ();
96+
97+ if (numResults == 1 ) {
98+ rewriter.replaceOp (op, linearIdx);
99+ return success ();
100+ }
101+
102+ SmallVector<Value> results;
103+ results.reserve (numResults);
104+ SmallVector<Value> strides =
105+ computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis,
106+ /* knownNonNegative=*/ true );
107+
108+ Value zero = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
109+
110+ Value initialPart =
111+ rewriter.create <arith::FloorDivSIOp>(loc, linearIdx, strides.front ());
112+ results.push_back (initialPart);
113+
114+ auto emitModTerm = [&](Value stride) -> Value {
115+ Value remainder = rewriter.create <arith::RemSIOp>(loc, linearIdx, stride);
116+ Value remainderNegative = rewriter.create <arith::CmpIOp>(
117+ loc, arith::CmpIPredicate::slt, remainder, zero);
118+ // If the correction is relevant, this term is <= stride, which is known
119+ // to be positive in `index`. Otherwise, while 2 * stride might overflow,
120+ // this branch won't be taken, so the risk of `poison` is fine.
121+ Value corrected = rewriter.create <arith::AddIOp>(
122+ loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
123+ Value mod = rewriter.create <arith::SelectOp>(loc, remainderNegative,
124+ corrected, remainder);
125+ return mod;
126+ };
127+
128+ // Generate all the intermediate parts
129+ for (size_t i = 0 , e = strides.size () - 1 ; i < e; ++i) {
130+ Value thisStride = strides[i];
131+ Value nextStride = strides[i + 1 ];
132+ Value modulus = emitModTerm (thisStride);
133+ // We know both inputs are positive, so floorDiv == div.
134+ // This could potentially be a divui, but it's not clear if that would
135+ // cause issues.
136+ Value divided = rewriter.create <arith::DivSIOp>(loc, modulus, nextStride);
137+ results.push_back (divided);
138+ }
139+
140+ results.push_back (emitModTerm (strides.back ()));
141+
142+ rewriter.replaceOp (op, results);
143+ return success ();
144+ }
145+
146+ LogicalResult affine::lowerAffineLinearizeIndexOp (RewriterBase &rewriter,
147+ AffineLinearizeIndexOp op) {
148+ // Should be folded away, included here for safety.
149+ if (op.getMultiIndex ().empty ()) {
150+ rewriter.replaceOpWithNewOp <arith::ConstantIndexOp>(op, 0 );
151+ return success ();
152+ }
153+
154+ Location loc = op.getLoc ();
155+ ValueRange multiIndex = op.getMultiIndex ();
156+ size_t numIndexes = multiIndex.size ();
157+ ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
158+ if (numIndexes == staticBasis.size ())
159+ staticBasis = staticBasis.drop_front ();
160+
161+ SmallVector<Value> strides =
162+ computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis,
163+ /* knownNonNegative=*/ op.getDisjoint ());
164+ SmallVector<std::pair<Value, int64_t >> scaledValues;
165+ scaledValues.reserve (numIndexes);
166+
167+ // Note: strides doesn't contain a value for the final element (stride 1)
168+ // and everything else lines up. We use the "mutable" accessor so we can get
169+ // our hands on an `OpOperand&` for the loop invariant counting function.
170+ for (auto [stride, idxOp] :
171+ llvm::zip_equal (strides, llvm::drop_end (op.getMultiIndexMutable ()))) {
172+ Value scaledIdx = rewriter.create <arith::MulIOp>(
173+ loc, idxOp.get (), stride, arith::IntegerOverflowFlags::nsw);
174+ int64_t numHoistableLoops = numEnclosingInvariantLoops (idxOp);
175+ scaledValues.emplace_back (scaledIdx, numHoistableLoops);
176+ }
177+ scaledValues.emplace_back (
178+ multiIndex.back (),
179+ numEnclosingInvariantLoops (op.getMultiIndexMutable ()[numIndexes - 1 ]));
180+
181+ // Sort by how many enclosing loops there are, ties implicitly broken by
182+ // size of the stride.
183+ llvm::stable_sort (scaledValues,
184+ [&](auto l, auto r) { return l.second > r.second ; });
185+
186+ Value result = scaledValues.front ().first ;
187+ for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin (scaledValues)) {
188+ std::ignore = numHoistableLoops;
189+ result = rewriter.create <arith::AddIOp>(loc, result, scaledValue,
190+ arith::IntegerOverflowFlags::nsw);
191+ }
192+ rewriter.replaceOp (op, result);
193+ return success ();
194+ }
195+
87196namespace {
88- // / Lowers `affine.delinearize_index` into a sequence of division and remainder
89- // / operations.
90197struct LowerDelinearizeIndexOps
91198 : public OpRewritePattern<AffineDelinearizeIndexOp> {
92199 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
93200 LogicalResult matchAndRewrite (AffineDelinearizeIndexOp op,
94201 PatternRewriter &rewriter) const override {
95- Location loc = op.getLoc ();
96- Value linearIdx = op.getLinearIndex ();
97- unsigned numResults = op.getNumResults ();
98- ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
99- if (numResults == staticBasis.size ())
100- staticBasis = staticBasis.drop_front ();
101-
102- if (numResults == 1 ) {
103- rewriter.replaceOp (op, linearIdx);
104- return success ();
105- }
106-
107- SmallVector<Value> results;
108- results.reserve (numResults);
109- SmallVector<Value> strides =
110- computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis,
111- /* knownNonNegative=*/ true );
112-
113- Value zero = rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 );
114-
115- Value initialPart =
116- rewriter.create <arith::FloorDivSIOp>(loc, linearIdx, strides.front ());
117- results.push_back (initialPart);
118-
119- auto emitModTerm = [&](Value stride) -> Value {
120- Value remainder = rewriter.create <arith::RemSIOp>(loc, linearIdx, stride);
121- Value remainderNegative = rewriter.create <arith::CmpIOp>(
122- loc, arith::CmpIPredicate::slt, remainder, zero);
123- // If the correction is relevant, this term is <= stride, which is known
124- // to be positive in `index`. Otherwise, while 2 * stride might overflow,
125- // this branch won't be taken, so the risk of `poison` is fine.
126- Value corrected = rewriter.create <arith::AddIOp>(
127- loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
128- Value mod = rewriter.create <arith::SelectOp>(loc, remainderNegative,
129- corrected, remainder);
130- return mod;
131- };
132-
133- // Generate all the intermediate parts
134- for (size_t i = 0 , e = strides.size () - 1 ; i < e; ++i) {
135- Value thisStride = strides[i];
136- Value nextStride = strides[i + 1 ];
137- Value modulus = emitModTerm (thisStride);
138- // We know both inputs are positive, so floorDiv == div.
139- // This could potentially be a divui, but it's not clear if that would
140- // cause issues.
141- Value divided = rewriter.create <arith::DivSIOp>(loc, modulus, nextStride);
142- results.push_back (divided);
143- }
144-
145- results.push_back (emitModTerm (strides.back ()));
146-
147- rewriter.replaceOp (op, results);
148- return success ();
202+ return affine::lowerAffineDelinearizeIndexOp (rewriter, op);
149203 }
150204};
151205
152- // / Lowers `affine.linearize_index` into a sequence of multiplications and
153- // / additions. Make a best effort to sort the input indices so that
154- // / the most loop-invariant terms are at the left of the additions
155- // / to enable loop-invariant code motion.
156206struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
157207 using OpRewritePattern::OpRewritePattern;
158208 LogicalResult matchAndRewrite (AffineLinearizeIndexOp op,
159209 PatternRewriter &rewriter) const override {
160- // Should be folded away, included here for safety.
161- if (op.getMultiIndex ().empty ()) {
162- rewriter.replaceOpWithNewOp <arith::ConstantIndexOp>(op, 0 );
163- return success ();
164- }
165-
166- Location loc = op.getLoc ();
167- ValueRange multiIndex = op.getMultiIndex ();
168- size_t numIndexes = multiIndex.size ();
169- ArrayRef<int64_t > staticBasis = op.getStaticBasis ();
170- if (numIndexes == staticBasis.size ())
171- staticBasis = staticBasis.drop_front ();
172-
173- SmallVector<Value> strides =
174- computeStrides (loc, rewriter, op.getDynamicBasis (), staticBasis,
175- /* knownNonNegative=*/ op.getDisjoint ());
176- SmallVector<std::pair<Value, int64_t >> scaledValues;
177- scaledValues.reserve (numIndexes);
178-
179- // Note: strides doesn't contain a value for the final element (stride 1)
180- // and everything else lines up. We use the "mutable" accessor so we can get
181- // our hands on an `OpOperand&` for the loop invariant counting function.
182- for (auto [stride, idxOp] :
183- llvm::zip_equal (strides, llvm::drop_end (op.getMultiIndexMutable ()))) {
184- Value scaledIdx = rewriter.create <arith::MulIOp>(
185- loc, idxOp.get (), stride, arith::IntegerOverflowFlags::nsw);
186- int64_t numHoistableLoops = numEnclosingInvariantLoops (idxOp);
187- scaledValues.emplace_back (scaledIdx, numHoistableLoops);
188- }
189- scaledValues.emplace_back (
190- multiIndex.back (),
191- numEnclosingInvariantLoops (op.getMultiIndexMutable ()[numIndexes - 1 ]));
192-
193- // Sort by how many enclosing loops there are, ties implicitly broken by
194- // size of the stride.
195- llvm::stable_sort (scaledValues,
196- [&](auto l, auto r) { return l.second > r.second ; });
197-
198- Value result = scaledValues.front ().first ;
199- for (auto [scaledValue, numHoistableLoops] :
200- llvm::drop_begin (scaledValues)) {
201- std::ignore = numHoistableLoops;
202- result = rewriter.create <arith::AddIOp>(loc, result, scaledValue,
203- arith::IntegerOverflowFlags::nsw);
204- }
205- rewriter.replaceOp (op, result);
206- return success ();
210+ return affine::lowerAffineLinearizeIndexOp (rewriter, op);
207211 }
208212};
209213
0 commit comments