diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h index 10467cc96f9c6..e772095a266cc 100644 --- a/llvm/include/llvm/IR/ConstantFPRange.h +++ b/llvm/include/llvm/IR/ConstantFPRange.h @@ -231,6 +231,15 @@ class [[nodiscard]] ConstantFPRange { /// from a subtraction of a value in this range and a value in \p Other. LLVM_ABI ConstantFPRange sub(const ConstantFPRange &Other) const; + /// Return a new range representing the possible values resulting + /// from a multiplication of a value in this range and a value in \p Other. + LLVM_ABI ConstantFPRange mul(const ConstantFPRange &Other) const; + + /// Return a new range representing the possible values resulting + /// from a division of a value in this range and a value in + /// \p Other. + LLVM_ABI ConstantFPRange div(const ConstantFPRange &Other) const; + /// Flush denormal values to zero according to the specified mode. /// For dynamic mode, we return the union of all possible results. LLVM_ABI void flushDenormals(DenormalMode::DenormalModeKind Mode); diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp index e9c058ee8ec25..5b8768601928e 100644 --- a/llvm/lib/IR/ConstantFPRange.cpp +++ b/llvm/lib/IR/ConstantFPRange.cpp @@ -528,3 +528,147 @@ void ConstantFPRange::flushDenormals(DenormalMode::DenormalModeKind Mode) { Lower = minnum(Lower, APFloat::getZero(Sem, ZeroLowerNegative)); Upper = maxnum(Upper, APFloat::getZero(Sem, ZeroUpperNegative)); } + +/// Represent a contiguous range of values sharing the same sign. +struct SameSignRange { + bool HasZero; + bool HasNonZero; + bool HasInf; + // The lower and upper bounds of the range (inclusive). + // The sign is dropped and infinities are excluded. + std::optional> FinitePart; + + explicit SameSignRange(const APFloat &Lower, const APFloat &Upper) + : HasZero(Lower.isZero()), HasNonZero(!Upper.isZero()), + HasInf(Upper.isInfinity()) { + assert(!Lower.isNegative() && !Upper.isNegative() && + "The sign should be dropped."); + assert(strictCompare(Lower, Upper) != APFloat::cmpGreaterThan && + "Empty set."); + if (!Lower.isInfinity()) + FinitePart = {Lower, + HasInf ? APFloat::getLargest(Lower.getSemantics()) : Upper}; + } +}; + +/// Split the range into positive and negative components. +static void splitPosNeg(const APFloat &Lower, const APFloat &Upper, + std::optional &NegPart, + std::optional &PosPart) { + assert(strictCompare(Lower, Upper) != APFloat::cmpGreaterThan && + "Non-NaN part is empty."); + if (Lower.isNegative() == Upper.isNegative()) { + if (Lower.isNegative()) + NegPart = SameSignRange{abs(Upper), abs(Lower)}; + else + PosPart = SameSignRange{Lower, Upper}; + return; + } + auto &Sem = Lower.getSemantics(); + NegPart = SameSignRange{APFloat::getZero(Sem), abs(Lower)}; + PosPart = SameSignRange{APFloat::getZero(Sem), Upper}; +} + +ConstantFPRange ConstantFPRange::mul(const ConstantFPRange &Other) const { + auto &Sem = getSemantics(); + bool ResMayBeQNaN = ((MayBeQNaN || MayBeSNaN) && !Other.isEmptySet()) || + ((Other.MayBeQNaN || Other.MayBeSNaN) && !isEmptySet()); + if (isNaNOnly() || Other.isNaNOnly()) + return getNaNOnly(Sem, /*MayBeQNaN=*/ResMayBeQNaN, + /*MayBeSNaN=*/false); + std::optional LHSNeg, LHSPos, RHSNeg, RHSPos; + splitPosNeg(Lower, Upper, LHSNeg, LHSPos); + splitPosNeg(Other.Lower, Other.Upper, RHSNeg, RHSPos); + APFloat ResLower = APFloat::getInf(Sem, /*Negative=*/false); + APFloat ResUpper = APFloat::getInf(Sem, /*Negative=*/true); + auto Update = [&](std::optional &LHS, + std::optional &RHS, bool Negative) { + if (!LHS || !RHS) + return; + // 0 * inf = QNaN + ResMayBeQNaN |= LHS->HasZero && RHS->HasInf; + ResMayBeQNaN |= RHS->HasZero && LHS->HasInf; + // NonZero * inf = inf + if ((LHS->HasInf && RHS->HasNonZero) || (RHS->HasInf && LHS->HasNonZero)) + (Negative ? ResLower : ResUpper) = APFloat::getInf(Sem, Negative); + // Finite * Finite + if (LHS->FinitePart && RHS->FinitePart) { + APFloat NewLower = LHS->FinitePart->first * RHS->FinitePart->first; + APFloat NewUpper = LHS->FinitePart->second * RHS->FinitePart->second; + if (Negative) { + ResLower = minnum(ResLower, -NewUpper); + ResUpper = maxnum(ResUpper, -NewLower); + } else { + ResLower = minnum(ResLower, NewLower); + ResUpper = maxnum(ResUpper, NewUpper); + } + } + }; + Update(LHSNeg, RHSNeg, /*Negative=*/false); + Update(LHSNeg, RHSPos, /*Negative=*/true); + Update(LHSPos, RHSNeg, /*Negative=*/true); + Update(LHSPos, RHSPos, /*Negative=*/false); + return ConstantFPRange(ResLower, ResUpper, ResMayBeQNaN, /*MayBeSNaN=*/false); +} + +ConstantFPRange ConstantFPRange::div(const ConstantFPRange &Other) const { + auto &Sem = getSemantics(); + bool ResMayBeQNaN = ((MayBeQNaN || MayBeSNaN) && !Other.isEmptySet()) || + ((Other.MayBeQNaN || Other.MayBeSNaN) && !isEmptySet()); + if (isNaNOnly() || Other.isNaNOnly()) + return getNaNOnly(Sem, /*MayBeQNaN=*/ResMayBeQNaN, + /*MayBeSNaN=*/false); + std::optional LHSNeg, LHSPos, RHSNeg, RHSPos; + splitPosNeg(Lower, Upper, LHSNeg, LHSPos); + splitPosNeg(Other.Lower, Other.Upper, RHSNeg, RHSPos); + APFloat ResLower = APFloat::getInf(Sem, /*Negative=*/false); + APFloat ResUpper = APFloat::getInf(Sem, /*Negative=*/true); + auto Update = [&](std::optional &LHS, + std::optional &RHS, bool Negative) { + if (!LHS || !RHS) + return; + // inf / inf = QNaN 0 / 0 = QNaN + ResMayBeQNaN |= LHS->HasInf && RHS->HasInf; + ResMayBeQNaN |= LHS->HasZero && RHS->HasZero; + // It is not straightforward to infer HasNonZeroFinite = HasFinite && + // HasNonZero. By definitions we have: + // HasFinite = HasNonZeroFinite || HasZero + // HasNonZero = HasNonZeroFinite || HasInf + // Since the range is contiguous, if both HasFinite and HasNonZero are true, + // HasNonZeroFinite must be true. + bool LHSHasNonZeroFinite = LHS->FinitePart && LHS->HasNonZero; + bool RHSHasNonZeroFinite = RHS->FinitePart && RHS->HasNonZero; + // inf / Finite = inf FiniteNonZero / 0 = inf + if ((LHS->HasInf && RHS->FinitePart) || + (LHSHasNonZeroFinite && RHS->HasZero)) + (Negative ? ResLower : ResUpper) = APFloat::getInf(Sem, Negative); + // Finite / inf = 0 + if (LHS->FinitePart && RHS->HasInf) { + APFloat Zero = APFloat::getZero(Sem, /*Negative=*/Negative); + ResLower = minnum(ResLower, Zero); + ResUpper = maxnum(ResUpper, Zero); + } + // Finite / FiniteNonZero + if (LHS->FinitePart && RHSHasNonZeroFinite) { + assert(!RHS->FinitePart->second.isZero() && + "Divisor should be non-zero."); + APFloat NewLower = LHS->FinitePart->first / RHS->FinitePart->second; + APFloat NewUpper = LHS->FinitePart->second / + (RHS->FinitePart->first.isZero() + ? APFloat::getSmallest(Sem, /*Negative=*/false) + : RHS->FinitePart->first); + if (Negative) { + ResLower = minnum(ResLower, -NewUpper); + ResUpper = maxnum(ResUpper, -NewLower); + } else { + ResLower = minnum(ResLower, NewLower); + ResUpper = maxnum(ResUpper, NewUpper); + } + } + }; + Update(LHSNeg, RHSNeg, /*Negative=*/false); + Update(LHSNeg, RHSPos, /*Negative=*/true); + Update(LHSPos, RHSNeg, /*Negative=*/true); + Update(LHSPos, RHSPos, /*Negative=*/false); + return ConstantFPRange(ResLower, ResUpper, ResMayBeQNaN, /*MayBeSNaN=*/false); +} diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp index 2431db90a40bd..67fee962379e1 100644 --- a/llvm/unittests/IR/ConstantFPRangeTest.cpp +++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp @@ -1066,6 +1066,115 @@ TEST_F(ConstantFPRangeTest, sub) { #endif } +TEST_F(ConstantFPRangeTest, mul) { + EXPECT_EQ(Full.mul(Full), NonNaN.unionWith(QNaN)); + EXPECT_EQ(Full.mul(Empty), Empty); + EXPECT_EQ(Empty.mul(Full), Empty); + EXPECT_EQ(Empty.mul(Empty), Empty); + EXPECT_EQ(One.mul(One), ConstantFPRange(APFloat(1.0))); + EXPECT_EQ(Some.mul(Some), + ConstantFPRange::getNonNaN(APFloat(-9.0), APFloat(9.0))); + EXPECT_EQ(SomePos.mul(SomeNeg), + ConstantFPRange::getNonNaN(APFloat(-9.0), APFloat(-0.0))); + EXPECT_EQ(PosInf.mul(PosInf), PosInf); + EXPECT_EQ(NegInf.mul(NegInf), PosInf); + EXPECT_EQ(PosInf.mul(Finite), NonNaN.unionWith(QNaN)); + EXPECT_EQ(NegInf.mul(Finite), NonNaN.unionWith(QNaN)); + EXPECT_EQ(PosInf.mul(NegInf), NegInf); + EXPECT_EQ(NegInf.mul(PosInf), NegInf); + EXPECT_EQ(PosZero.mul(NegZero), NegZero); + EXPECT_EQ(PosZero.mul(Zero), Zero); + EXPECT_EQ(NegZero.mul(NegZero), PosZero); + EXPECT_EQ(NegZero.mul(Zero), Zero); + EXPECT_EQ(NaN.mul(NaN), QNaN); + EXPECT_EQ(NaN.mul(Finite), QNaN); + +#if defined(EXPENSIVE_CHECKS) + EnumerateTwoInterestingConstantFPRanges( + [](const ConstantFPRange &LHS, const ConstantFPRange &RHS) { + ConstantFPRange Res = LHS.mul(RHS); + ConstantFPRange Expected = + ConstantFPRange::getEmpty(LHS.getSemantics()); + EnumerateValuesInConstantFPRange( + LHS, + [&](const APFloat &LHSC) { + EnumerateValuesInConstantFPRange( + RHS, + [&](const APFloat &RHSC) { + APFloat Prod = LHSC * RHSC; + EXPECT_TRUE(Res.contains(Prod)) + << "Wrong result for " << LHS << " * " << RHS + << ". The result " << Res << " should contain " << Prod; + if (!Expected.contains(Prod)) + Expected = Expected.unionWith(ConstantFPRange(Prod)); + }, + /*IgnoreNaNPayload=*/true); + }, + /*IgnoreNaNPayload=*/true); + EXPECT_EQ(Res, Expected) + << "Suboptimal result for " << LHS << " * " << RHS << ". Expected " + << Expected << ", but got " << Res; + }, + SparseLevel::SpecialValuesOnly); +#endif +} + +TEST_F(ConstantFPRangeTest, div) { + EXPECT_EQ(Full.div(Full), NonNaN.unionWith(QNaN)); + EXPECT_EQ(Full.div(Empty), Empty); + EXPECT_EQ(Empty.div(Full), Empty); + EXPECT_EQ(Empty.div(Empty), Empty); + EXPECT_EQ(One.div(One), ConstantFPRange(APFloat(1.0))); + EXPECT_EQ(Some.div(Some), NonNaN.unionWith(QNaN)); + EXPECT_EQ(SomePos.div(SomeNeg), + ConstantFPRange(APFloat::getInf(Sem, /*Negative=*/true), + APFloat::getZero(Sem, /*Negative=*/true), + /*MayBeQNaN=*/true, /*MayBeSNaN=*/false)); + EXPECT_EQ(PosInf.div(PosInf), QNaN); + EXPECT_EQ(NegInf.div(NegInf), QNaN); + EXPECT_EQ(PosInf.div(Finite), NonNaN); + EXPECT_EQ(NegInf.div(Finite), NonNaN); + EXPECT_EQ(PosInf.div(NegInf), QNaN); + EXPECT_EQ(NegInf.div(PosInf), QNaN); + EXPECT_EQ(Zero.div(Zero), QNaN); + EXPECT_EQ(SomePos.div(PosInf), PosZero); + EXPECT_EQ(SomeNeg.div(PosInf), NegZero); + EXPECT_EQ(PosInf.div(SomePos), PosInf); + EXPECT_EQ(NegInf.div(SomeNeg), PosInf); + EXPECT_EQ(NegInf.div(Some), NonNaN); + EXPECT_EQ(NaN.div(NaN), QNaN); + EXPECT_EQ(NaN.div(Finite), QNaN); + +#if defined(EXPENSIVE_CHECKS) + EnumerateTwoInterestingConstantFPRanges( + [](const ConstantFPRange &LHS, const ConstantFPRange &RHS) { + ConstantFPRange Res = LHS.div(RHS); + ConstantFPRange Expected = + ConstantFPRange::getEmpty(LHS.getSemantics()); + EnumerateValuesInConstantFPRange( + LHS, + [&](const APFloat &LHSC) { + EnumerateValuesInConstantFPRange( + RHS, + [&](const APFloat &RHSC) { + APFloat Val = LHSC / RHSC; + EXPECT_TRUE(Res.contains(Val)) + << "Wrong result for " << LHS << " / " << RHS + << ". The result " << Res << " should contain " << Val; + if (!Expected.contains(Val)) + Expected = Expected.unionWith(ConstantFPRange(Val)); + }, + /*IgnoreNaNPayload=*/true); + }, + /*IgnoreNaNPayload=*/true); + EXPECT_EQ(Res, Expected) + << "Suboptimal result for " << LHS << " / " << RHS << ". Expected " + << Expected << ", but got " << Res; + }, + SparseLevel::SpecialValuesOnly); +#endif +} + TEST_F(ConstantFPRangeTest, flushDenormals) { const fltSemantics &FP8Sem = APFloat::Float8E4M3(); APFloat NormalVal = APFloat::getSmallestNormalized(FP8Sem);