@@ -46,6 +46,17 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
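+/// Copy the inferred integer range of `oldVal` onto `newVal`, if one is
+/// available, by joining it into `newVal`'s lattice state. This keeps the
+/// analysis results usable for values created during rewrites.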
+static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
+                             Value newVal) {
+  assert(oldVal.getType() == newVal.getType() &&
+         "Can't copy integer ranges between different types");
+  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
+  if (!oldState)
+    return;
+  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
+      *oldState);
+}
+
 /// Patterned after SCCP
 static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                               PatternRewriter &rewriter,
@@ -80,6 +91,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
   if (!constOp)
     return failure();
 
+  copyIntegerRange(solver, value, constOp->getResult(0));
   rewriter.replaceAllUsesWith(value, constOp->getResult(0));
   return success();
 }
@@ -195,56 +207,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
-  Type elemType = getElementTypeOrSelf(type);
-  if (isa<IndexType>(elemType))
-    return success();
-
-  if (auto intType = dyn_cast<IntegerType>(elemType))
-    if (intType.getWidth() > targetBitwidth)
-      return success();
-
-  return failure();
-}
-
-/// Check if op have same type for all operands and results and this type
-/// is suitable for truncation.
-static LogicalResult checkElementwiseOpType(Operation *op,
-                                            unsigned targetBitwidth) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return failure();
-
-  Type type;
-  for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    if (!type) {
-      type = val.getType();
-      continue;
-    }
-
-    if (type != val.getType())
-      return failure();
-  }
-
-  return checkIntType(type, targetBitwidth);
-}
-
-/// Return union of all operands values ranges.
-static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
-                                                         ValueRange operands) {
-  std::optional<ConstantIntRanges> ret;
-  for (Value value : operands) {
+/// Gather ranges for all the values in `values`. Appends to the existing
+/// vector.
+static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
+                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
+  for (Value val : values) {
     auto *maybeInferredRange =
-        solver.lookupState<IntegerValueRangeLattice>(value);
+        solver.lookupState<IntegerValueRangeLattice>(val);
     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
-      return std::nullopt;
+      return failure();
 
     const ConstantIntRanges &inferredRange =
         maybeInferredRange->getValue().getValue();
-
-    ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
+    ranges.push_back(inferredRange);
   }
-  return ret;
+  return success();
 }
 
 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,56 +235,79 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
-/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
-                                APInt smax, APInt umin, APInt umax) {
-  auto sge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sge(val2);
-  };
-  auto sle = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sle(val2);
-  };
-  auto uge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.uge(val2);
-  };
-  auto ule = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.ule(val2);
-  };
-  return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
-                 uge(range.umin(), umin) && ule(range.umax(), umax));
+namespace {
+// Enum for tracking which type of truncation should be performed
+// to narrow an operation, if any.
+enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
+} // namespace
+
+/// If the values within `range` can be represented using only `width` bits,
+/// return the kind of truncation needed to preserve that property.
+///
+/// This check relies on the fact that the signed and unsigned ranges are both
+/// always correct, but that one might be an approximation of the other,
+/// so we want to use the correct truncation operation.
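+///
+/// For example, an i64 value known to be in [-100, 100] has at least 57 sign
+/// bits in both signed bounds, so it can be truncated to i32 (which needs
+/// 64 - 32 + 1 = 33 sign bits) using a signed cast, while its unsigned bounds
+/// rule out an unsigned truncation. A value known to be in [0, 100] passes
+/// both checks and yields CastKind::Both.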
+static CastKind checkTruncatability(const ConstantIntRanges &range,
+                                    unsigned targetWidth) {
+  unsigned srcWidth = range.smin().getBitWidth();
+  if (srcWidth <= targetWidth)
+    return CastKind::None;
+  unsigned removedWidth = srcWidth - targetWidth;
+  // The sign bits need to extend into the sign bit of the target width. For
+  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
+  // bits.
+  bool canTruncateSigned =
+      range.smin().getNumSignBits() >= (removedWidth + 1) &&
+      range.smax().getNumSignBits() >= (removedWidth + 1);
+  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
+                             range.umax().countLeadingZeros() >= removedWidth;
+  if (canTruncateSigned && canTruncateUnsigned)
+    return CastKind::Both;
+  if (canTruncateSigned)
+    return CastKind::Signed;
+  if (canTruncateUnsigned)
+    return CastKind::Unsigned;
+  return CastKind::None;
+}
+
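+/// Combine the cast kinds required by two values: `Both` acts as the identity,
+/// `None` is absorbing, and a `Signed`/`Unsigned` mismatch cannot be
+/// reconciled, so it collapses to `None`.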
+static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
+  if (lhs == CastKind::None || rhs == CastKind::None)
+    return CastKind::None;
+  if (lhs == CastKind::Both)
+    return rhs;
+  if (rhs == CastKind::Both)
+    return lhs;
+  if (lhs == rhs)
+    return lhs;
+  return CastKind::None;
 }
 
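+/// Cast `src` to `dstType`. Integer-to-integer truncations are sign-agnostic,
+/// while extensions and index casts use a signed or unsigned op depending on
+/// `castKind`.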
-static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
+                    CastKind castKind) {
   Type srcType = src.getType();
   assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
          "Mixing vector and non-vector types");
+  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
   Type srcElemType = getElementTypeOrSelf(srcType);
   Type dstElemType = getElementTypeOrSelf(dstType);
   assert(srcElemType.isIntOrIndex() && "Invalid src type");
   assert(dstElemType.isIntOrIndex() && "Invalid dst type");
   if (srcType == dstType)
     return src;
 
-  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
+    if (castKind == CastKind::Signed)
+      return builder.create<arith::IndexCastOp>(loc, dstType, src);
     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+  }
 
   auto srcInt = cast<IntegerType>(srcElemType);
   auto dstInt = cast<IntegerType>(dstElemType);
   if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
 
+  if (castKind == CastKind::Signed)
+    return builder.create<arith::ExtSIOp>(loc, dstType, src);
   return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
 
@@ -319,36 +319,47 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, op->getResults());
-    if (!range)
-      return failure();
+    if (op->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op->getOperands(), ranges)))
+      return rewriter.notifyMatchFailure(op, "input without specified range");
+    if (failed(collectRanges(solver, op->getResults(), ranges)))
+      return rewriter.notifyMatchFailure(op, "output without specified range");
+
+    Type srcType = op->getResult(0).getType();
+    if (!llvm::all_equal(op->getResultTypes()))
+      return rewriter.notifyMatchFailure(op, "mismatched result types");
+    if (op->getNumOperands() == 0 ||
+        !llvm::all_of(op->getOperandTypes(),
+                      [=](Type t) { return t == srcType; }))
+      return rewriter.notifyMatchFailure(
+          op, "no operands or operand types don't match result type");
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      if (failed(checkElementwiseOpType(op, targetBitwidth)))
-        continue;
-
-      Type srcType = op->getResult(0).getType();
-
-      // We are truncating op args to the desired bitwidth before the op and
-      // then extending op results back to the original width after. extui and
-      // exti will produce different results for negative values, so limit
-      // signed range to non-negative values.
-      auto smin = APInt::getZero(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind castKind = CastKind::Both;
+      for (const ConstantIntRanges &range : ranges) {
+        castKind = mergeCastKinds(castKind,
+                                  checkTruncatability(range, targetBitwidth));
+        if (castKind == CastKind::None)
+          break;
+      }
+      if (castKind == CastKind::None)
        continue;
-
      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;
 
      Location loc = op->getLoc();
      IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
+      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
+        CastKind argCastKind = castKind;
+        // When dealing with `index` values, preserve non-negativity in the
+        // index_casts since we can't recover this in unsigned when equivalent.
+        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
+          argCastKind = CastKind::Both;
+        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
         mapping.map(arg, newArg);
       }
 
@@ -359,8 +370,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
        }
      });
      SmallVector<Value> newResults;
-      for (Value res : newOp->getResults())
-        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+      for (auto [newRes, oldRes] :
+           llvm::zip_equal(newOp->getResults(), op->getResults())) {
+        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
+        copyIntegerRange(solver, oldRes, castBack);
+        newResults.push_back(castBack);
+      }
 
      rewriter.replaceOp(op, newResults);
      return success();
@@ -382,21 +397,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, {lhs, rhs});
-    if (!range)
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op.getOperands(), ranges)))
       return failure();
+    const ConstantIntRanges &lhsRange = ranges[0];
+    const ConstantIntRanges &rhsRange = ranges[1];
 
+    Type srcType = lhs.getType();
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = lhs.getType();
-      if (failed(checkIntType(srcType, targetBitwidth)))
-        continue;
-
-      auto smin = APInt::getSignedMinValue(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
+      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
+      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
+      // Note: this includes target width > src width.
+      if (castKind == CastKind::None)
        continue;
 
      Type targetType = getTargetType(srcType, targetBitwidth);
@@ -405,12 +418,13 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 
      Location loc = op->getLoc();
      IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
-        mapping.map(arg, newArg);
-      }
+      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
+      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
+      mapping.map(lhs, lhsCast);
+      mapping.map(rhs, rhsCast);
 
      Operation *newOp = rewriter.clone(*op, mapping);
+      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
      rewriter.replaceOp(op, newOp->getResults());
      return success();
    }
@@ -425,19 +439,23 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
 /// This pattern assumes all passed `targetBitwidths` are not wider than index
 /// type.
-struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+template <typename CastOp>
+struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
   FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern(context), targetBitwidths(target) {}
+      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
 
-  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+  LogicalResult matchAndRewrite(CastOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
     if (!srcOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
 
     Value src = srcOp.getIn();
     if (src.getType() != op.getType())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "outer types don't match");
+
+    if (!srcOp.getType().isIndex())
+      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
 
     auto intType = dyn_cast<IntegerType>(op.getType());
     if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +535,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
+               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
+                                                       bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {