diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index c7370b83fdb6c..15ea30ceca96d 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> { // Explicitly depend on "arith" because this pass could create operations in // `arith` out of thin air in some cases. let dependentDialects = [ - "::mlir::arith::ArithDialect" + "::mlir::arith::ArithDialect", + "::mlir::ub::UBDialect" ]; } diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h index 21de5cb0c182a..fc2dbad7a8aa7 100644 --- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h +++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h @@ -12,6 +12,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td index f3d5a26ef6f9b..db88838d15dfd 100644 --- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td +++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td @@ -9,8 +9,9 @@ #ifndef MLIR_DIALECT_UB_IR_UBOPS_TD #define MLIR_DIALECT_UB_IR_UBOPS_TD -include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/Interfaces/InferIntRangeInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "UBOpsInterfaces.td" @@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> { // PoisonOp //===----------------------------------------------------------------------===// -def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> { +def PoisonOp : UB_Op<"poison", [ConstantLike, Pure, + DeclareOpInterfaceMethods]> { let summary = "Poisoned constant operation."; let description = [{ The `poison` operation materializes a compile-time poisoned constant value diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h index 0e107e88f5232..4e4f3725a69fd 100644 --- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -51,6 +51,9 @@ class ConstantIntRanges { /// The maximum value of an integer when it is interpreted as signed. const APInt &smax() const; + /// Get the bitwidth of the ranges. + unsigned getBitWidth() const; + /// Return the bitwidth that should be used for integer ranges describing /// `type`. For concrete integer types, this is their bitwidth, for `index`, /// this is the internal storage bitwidth of `index` attributes, and for @@ -62,6 +65,10 @@ class ConstantIntRanges { /// sint_max(width)]. static ConstantIntRanges maxRange(unsigned bitwidth); + /// Create a poisoned range, i.e. a range that represents no valid integer + /// values. + static ConstantIntRanges poison(unsigned bitwidth); + /// Create a `ConstantIntRanges` with a constant value - that is, with the /// bounds [value, value] for both its signed interpretations. static ConstantIntRanges constant(const APInt &value); @@ -96,6 +103,14 @@ class ConstantIntRanges { /// value. std::optional getConstantValue() const; + /// Returns true if signed range is poisoned, i.e. no valid signed value + /// can be represented. + bool isSignedPoison() const; + + /// Returns true if unsigned range is poisoned, i.e. no valid unsigned value + /// can be represented. + bool isUnsignedPoison() const; + friend raw_ostream &operator<<(raw_ostream &os, const ConstantIntRanges &range); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 777ff0ecaa314..03da1e5327e39 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -14,6 +14,7 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" @@ -46,6 +47,16 @@ static std::optional getMaybeConstantValue(DataFlowSolver &solver, return inferredRange.getConstantValue(); } +static bool isPoison(DataFlowSolver &solver, Value value) { + auto *maybeInferredRange = + solver.lookupState(value); + if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) + return false; + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); + return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison(); +} + static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal) { assert(oldVal.getType() == newVal.getType() && @@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value) { if (value.use_empty()) return failure(); + + if (isPoison(solver, value)) { + Value poison = + ub::PoisonOp::create(rewriter, value.getLoc(), value.getType()); + if (solver.lookupState(poison)) + solver.eraseState(poison); + copyIntegerRange(solver, value, poison); + rewriter.replaceAllUsesWith(value, poison); + return success(); + } + std::optional maybeConstValue = getMaybeConstantValue(solver, value); if (!maybeConstValue.has_value()) return failure(); @@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern { return failure(); auto needsReplacing = [&](Value v) { - return getMaybeConstantValue(solver, v).has_value() && !v.use_empty(); + return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) && + !v.use_empty(); }; bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing); if (op->getNumRegions() == 0) diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp index ee523f9522953..4bb6f0979cfaa 100644 --- a/mlir/lib/Dialect/UB/IR/UBOps.cpp +++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp @@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value, OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); } +void PoisonOp::inferResultRanges(ArrayRef /*argRanges*/, + SetIntRangeFn setResultRange) { + unsigned width = ConstantIntRanges::getStorageBitwidth(getType()); + setResultRange(getResult(), ConstantIntRanges::poison(width)); +} + #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc" #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index 9f3e97d051c85..4f6ab0306229f 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; } const APInt &ConstantIntRanges::smax() const { return smaxVal; } +unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); } + unsigned ConstantIntRanges::getStorageBitwidth(Type type) { type = getElementTypeOrSelf(type); if (type.isIndex()) @@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) { return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth)); } +ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) { + if (bitwidth == 0) { + auto zero = APInt::getZero(0); + return {zero, zero, zero, zero}; + } + + // Poison is represented by an empty range. + auto zero = APInt::getZero(bitwidth); + auto one = zero + 1; + auto onem = zero - 1; + // For i1 the valid unsigned range is [0, 1] and the valid signed range + // is [-1, 0]. + return {one, zero, zero, onem}; +} + ConstantIntRanges ConstantIntRanges::constant(const APInt &value) { return {value, value, value, value}; } @@ -85,15 +102,37 @@ ConstantIntRanges ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const { // "Not an integer" poisons everything and also cannot be fed to comparison // operators. - if (umin().getBitWidth() == 0) + if (getBitWidth() == 0) return *this; - if (other.umin().getBitWidth() == 0) + if (other.getBitWidth() == 0) return other; - const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin(); - const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax(); - const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin(); - const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax(); + APInt uminUnion; + APInt umaxUnion; + APInt sminUnion; + APInt smaxUnion; + + if (isUnsignedPoison()) { + uminUnion = other.umin(); + umaxUnion = other.umax(); + } else if (other.isUnsignedPoison()) { + uminUnion = umin(); + umaxUnion = umax(); + } else { + uminUnion = umin().ult(other.umin()) ? umin() : other.umin(); + umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax(); + } + + if (isSignedPoison()) { + sminUnion = other.smin(); + smaxUnion = other.smax(); + } else if (other.isSignedPoison()) { + sminUnion = smin(); + smaxUnion = smax(); + } else { + sminUnion = smin().slt(other.smin()) ? smin() : other.smin(); + smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax(); + } return {uminUnion, umaxUnion, sminUnion, smaxUnion}; } @@ -102,15 +141,37 @@ ConstantIntRanges ConstantIntRanges::intersection(const ConstantIntRanges &other) const { // "Not an integer" poisons everything and also cannot be fed to comparison // operators. - if (umin().getBitWidth() == 0) + if (getBitWidth() == 0) return *this; - if (other.umin().getBitWidth() == 0) + if (other.getBitWidth() == 0) return other; - const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin(); - const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax(); - const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin(); - const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax(); + APInt uminIntersect; + APInt umaxIntersect; + APInt sminIntersect; + APInt smaxIntersect; + + if (isUnsignedPoison()) { + uminIntersect = umin(); + umaxIntersect = umax(); + } else if (other.isUnsignedPoison()) { + uminIntersect = other.umin(); + umaxIntersect = other.umax(); + } else { + uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin(); + umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax(); + } + + if (isSignedPoison()) { + sminIntersect = smin(); + smaxIntersect = smax(); + } else if (other.isSignedPoison()) { + sminIntersect = other.smin(); + smaxIntersect = other.smax(); + } else { + sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin(); + smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax(); + } return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect}; } @@ -124,6 +185,14 @@ std::optional ConstantIntRanges::getConstantValue() const { return std::nullopt; } +bool ConstantIntRanges::isSignedPoison() const { + return getBitWidth() > 0 && smin().sgt(smax()); +} + +bool ConstantIntRanges::isUnsignedPoison() const { + return getBitWidth() > 0 && umin().ugt(umax()); +} + raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) { os << "unsigned : ["; range.umin().print(os, /*isSigned*/ false); diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index 2f47939df5a02..36841e2f2cc9a 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -32,6 +32,29 @@ using namespace mlir; // General utilities //===----------------------------------------------------------------------===// +/// If any of the arguments are poison, return poison. +static ConstantIntRanges +propagatePoison(const ConstantIntRanges &newRange, + ArrayRef argRanges) { + APInt umin = newRange.umin(); + APInt umax = newRange.umax(); + APInt smin = newRange.smin(); + APInt smax = newRange.smax(); + + unsigned width = umin.getBitWidth(); + for (const ConstantIntRanges &argRange : argRanges) { + if (argRange.isSignedPoison()) { + smin = APInt::getZero(width); + smax = smin - 1; + } + if (argRange.isUnsignedPoison()) { + umax = APInt::getZero(width); + umin = umax + 1; + } + } + return {umin, umax, smin, smax}; +} + /// Function that evaluates the result of doing something on arithmetic /// constants and returns std::nullopt on overflow. using ConstArithFn = @@ -114,7 +137,7 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn, // Returing the 64-bit result preserves more information. return sixtyFour; ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); - return merged; + return propagatePoison(merged, argRanges); } ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, @@ -123,21 +146,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, APInt umax = range.umax().zext(destWidth); APInt smin = range.smin().sext(destWidth); APInt smax = range.smax().sext(destWidth); - return {umin, umax, smin, smax}; + return propagatePoison({umin, umax, smin, smax}, range); } ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, unsigned destWidth) { APInt umin = range.umin().zext(destWidth); APInt umax = range.umax().zext(destWidth); - return ConstantIntRanges::fromUnsigned(umin, umax); + return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range); } ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, unsigned destWidth) { APInt smin = range.smin().sext(destWidth); APInt smax = range.smax().sext(destWidth); - return ConstantIntRanges::fromSigned(smin, smax); + return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range); } ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, @@ -173,7 +196,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, : range.smin().trunc(destWidth); APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) : range.smax().trunc(destWidth); - return {umin, umax, smin, smax}; + return propagatePoison({umin, umax, smin, smax}, range); } //===----------------------------------------------------------------------===// @@ -206,7 +229,7 @@ mlir::intrange::inferAdd(ArrayRef argRanges, uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); ConstantIntRanges srange = computeBoundsBy( sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); - return urange.intersection(srange); + return propagatePoison(urange.intersection(srange), argRanges); } //===----------------------------------------------------------------------===// @@ -238,7 +261,7 @@ mlir::intrange::inferSub(ArrayRef argRanges, usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); ConstantIntRanges srange = computeBoundsBy( ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); - return urange.intersection(srange); + return propagatePoison(urange.intersection(srange), argRanges); } //===----------------------------------------------------------------------===// @@ -273,7 +296,7 @@ mlir::intrange::inferMul(ArrayRef argRanges, ConstantIntRanges srange = minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, /*isSigned=*/true); - return urange.intersection(srange); + return propagatePoison(urange.intersection(srange), argRanges); } //===----------------------------------------------------------------------===// @@ -306,7 +329,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, // X u/ Y u<= X. APInt umax = lhsMax; - return ConstantIntRanges::fromUnsigned(umin, umax); + return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), + {lhs, rhs}); } ConstantIntRanges @@ -351,10 +375,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, APInt result = a.sdiv_ov(b, overflowed); return overflowed ? std::optional() : fixup(a, b, result); }; - return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/true); + return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/true), + {lhs, rhs}); } - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); + return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()), + {lhs, rhs}); } ConstantIntRanges @@ -395,7 +421,7 @@ mlir::intrange::inferCeilDivS(ArrayRef argRanges) { auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax()); result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix)); } - return result; + return propagatePoison(result, argRanges); } ConstantIntRanges @@ -425,6 +451,9 @@ mlir::intrange::inferRemS(ArrayRef argRanges) { const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), &rhsMax = rhs.smax(); + if (lhs.isSignedPoison() || rhs.isSignedPoison()) + return ConstantIntRanges::poison(rhsMin.getBitWidth()); + unsigned width = rhsMax.getBitWidth(); APInt smin = APInt::getSignedMinValue(width); APInt smax = APInt::getSignedMaxValue(width); @@ -463,6 +492,9 @@ mlir::intrange::inferRemU(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); + if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison()) + return ConstantIntRanges::poison(rhsMin.getBitWidth()); + unsigned width = rhsMin.getBitWidth(); APInt umin = APInt::getZero(width); // Remainder can't be larger than either of its arguments. @@ -492,6 +524,8 @@ mlir::intrange::inferRemU(ArrayRef argRanges) { ConstantIntRanges mlir::intrange::inferMaxS(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + if (lhs.isSignedPoison() || rhs.isSignedPoison()) + return ConstantIntRanges::poison(lhs.smin().getBitWidth()); const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); @@ -501,6 +535,8 @@ mlir::intrange::inferMaxS(ArrayRef argRanges) { ConstantIntRanges mlir::intrange::inferMaxU(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison()) + return ConstantIntRanges::poison(lhs.umin().getBitWidth()); const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); @@ -510,6 +546,8 @@ mlir::intrange::inferMaxU(ArrayRef argRanges) { ConstantIntRanges mlir::intrange::inferMinS(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + if (lhs.isSignedPoison() || rhs.isSignedPoison()) + return ConstantIntRanges::poison(lhs.smin().getBitWidth()); const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); @@ -519,6 +557,8 @@ mlir::intrange::inferMinS(ArrayRef argRanges) { ConstantIntRanges mlir::intrange::inferMinU(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison()) + return ConstantIntRanges::poison(lhs.umin().getBitWidth()); const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); @@ -550,8 +590,10 @@ mlir::intrange::inferAnd(ArrayRef argRanges) { auto andi = [](const APInt &a, const APInt &b) -> std::optional { return a & b; }; - return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); + return propagatePoison(minMaxBy(andi, {lhsZeros, lhsOnes}, + {rhsZeros, rhsOnes}, + /*isSigned=*/false), + argRanges); } ConstantIntRanges @@ -561,8 +603,9 @@ mlir::intrange::inferOr(ArrayRef argRanges) { auto ori = [](const APInt &a, const APInt &b) -> std::optional { return a | b; }; - return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); + return propagatePoison(minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false), + argRanges); } /// Get bitmask of all bits which can change while iterating in @@ -579,6 +622,9 @@ mlir::intrange::inferXor(ArrayRef argRanges) { // Construct mask of varying bits for both ranges, xor values and then replace // masked bits with 0s and 1s to get min and max values respectively. ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1]; + if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison()) + return ConstantIntRanges::poison(lhs.umin().getBitWidth()); + APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs); APInt res = lhs.umin() ^ rhs.umin(); APInt min = res & ~mask; @@ -621,7 +667,7 @@ mlir::intrange::inferShl(ArrayRef argRanges, ConstantIntRanges srange = minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax}, /*isSigned=*/true); - return urange.intersection(srange); + return propagatePoison(urange.intersection(srange), argRanges); } ConstantIntRanges @@ -632,8 +678,10 @@ mlir::intrange::inferShrS(ArrayRef argRanges) { return r.uge(r.getBitWidth()) ? std::optional() : l.ashr(r); }; - return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); + return propagatePoison(minMaxBy(ashr, {lhs.smin(), lhs.smax()}, + {rhs.umin(), rhs.umax()}, + /*isSigned=*/true), + argRanges); } ConstantIntRanges @@ -643,8 +691,10 @@ mlir::intrange::inferShrU(ArrayRef argRanges) { auto lshr = [](const APInt &l, const APInt &r) -> std::optional { return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); }; - return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); + return propagatePoison(minMaxBy(lshr, {lhs.umin(), lhs.umax()}, + {rhs.umin(), rhs.umax()}, + /*isSigned=*/false), + argRanges); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir index 2128d36f1a28e..35e755e98b1f9 100644 --- a/mlir/test/Dialect/Arith/int-range-interface.mlir +++ b/mlir/test/Dialect/Arith/int-range-interface.mlir @@ -654,6 +654,26 @@ func.func @select_union(%arg0 : index, %arg1 : i1) -> i1 { func.return %5 : i1 } +// CHECK-LABEL: func @select_poison +// CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} +func.func @select_poison(%arg0: i1) -> index { + %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index + %1 = test.with_bounds { umin = 1 : index, umax = 0 : index, smin = 1 : index, smax = 0 : index } : index + %2 = arith.select %arg0, %0, %1 : index + %3 = test.reflect_bounds %2 : index + func.return %3 : index +} + +// CHECK-LABEL: func @add_posion +// CHECK: test.reflect_bounds {smax = -1 : index, smin = 0 : index, umax = 0 : index, umin = 1 : index} +func.func @add_posion() -> index { + %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index + %1 = test.with_bounds { umin = 1 : index, umax = 0 : index, smin = 1 : index, smax = 0 : index } : index + %2 = arith.addi %0, %1 : index + %3 = test.reflect_bounds %2 : index + func.return %3 : index +} + // CHECK-LABEL: func @if_union // CHECK: %[[true:.*]] = arith.constant true // CHECK: return %[[true]] diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir index ea5969a100258..d613f38a55f01 100644 --- a/mlir/test/Dialect/Arith/int-range-opts.mlir +++ b/mlir/test/Dialect/Arith/int-range-opts.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s +// RUN: mlir-opt --int-range-optimizations --split-input-file %s | FileCheck %s // CHECK-LABEL: func @test // CHECK: %[[C:.*]] = arith.constant false @@ -132,3 +132,13 @@ func.func @wraps() -> i8 { %mod = arith.remsi %val, %c64 : i8 return %mod : i8 } + +// ----- + +// CHECK-LABEL: func @create_poison_op +// CHECK: %[[RES:.*]] = ub.poison : i32 +// CHECK: return %[[RES]] +func.func @create_poison_op() -> i32 { + %val = test.with_bounds { umin = 1 : i32, umax = 0 : i32, smin = 1 : i32, smax = 0 : i32 } : i32 + return %val : i32 +} diff --git a/mlir/test/Dialect/UB/int-range-interface.mlir b/mlir/test/Dialect/UB/int-range-interface.mlir new file mode 100644 index 0000000000000..69f4923ffe6c7 --- /dev/null +++ b/mlir/test/Dialect/UB/int-range-interface.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s + +// CHECK-LABEL: func @poison +// CHECK: test.reflect_bounds {smax = -1 : si32, smin = 0 : si32, umax = 0 : ui32, umin = 1 : ui32} +func.func @poison() -> i32 { + %0 = ub.poison : i32 + %1 = test.reflect_bounds %0 : i32 + func.return %1 : i32 +} + +// CHECK-LABEL: func @poison_i1 +// CHECK: test.reflect_bounds {smax = -1 : si1, smin = 0 : si1, umax = 0 : ui1, umin = 1 : ui1} +func.func @poison_i1() -> i1 { + %0 = ub.poison : i1 + %1 = test.reflect_bounds %0 : i1 + func.return %1 : i1 +} + +// CHECK-LABEL: func @poison_non_int +// Check it doesn't crash. +func.func @poison_non_int() -> f32 { + %0 = ub.poison : f32 + func.return %0 : f32 +} diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index b2f16bb3dac9c..3182ac6bf8b4b 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -116,3 +116,20 @@ func.func @vector_step() -> vector<8xindex> { %1 = test.reflect_bounds %0 : vector<8xindex> func.return %1 : vector<8xindex> } + +// CHECK-LABEL: func @poison_vector_insert +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 1 : index, umax = 4 : index, umin = 1 : index} +func.func @poison_vector_insert() -> vector<4xindex> { + %0 = ub.poison : vector<4xindex> + %1 = test.with_bounds { umin = 1 : index, umax = 1 : index, smin = 1 : index, smax = 1 : index } : index + %2 = test.with_bounds { umin = 2 : index, umax = 2 : index, smin = 2 : index, smax = 2 : index } : index + %3 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 3 : index, smax = 3 : index } : index + %4 = test.with_bounds { umin = 4 : index, umax = 4 : index, smin = 4 : index, smax = 4 : index } : index + %5 = vector.insert %1, %0[0] : index into vector<4xindex> + %6 = vector.insert %2, %5[1] : index into vector<4xindex> + %7 = vector.insert %3, %6[2] : index into vector<4xindex> + %8 = vector.insert %4, %7[3] : index into vector<4xindex> + + %9 = test.reflect_bounds %8 : vector<4xindex> + func.return %9 : vector<4xindex> +}