Skip to content

Commit a042cd0

Browse files
authored
[ConstantFPRange] Add support for mul/div (llvm#163063)
This patch adds support for fmul/fdiv operations.
1 parent 09d9f50 commit a042cd0

File tree

3 files changed

+262
-0
lines changed

3 files changed

+262
-0
lines changed

llvm/include/llvm/IR/ConstantFPRange.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,15 @@ class [[nodiscard]] ConstantFPRange {
231231
/// from a subtraction of a value in this range and a value in \p Other.
232232
LLVM_ABI ConstantFPRange sub(const ConstantFPRange &Other) const;
233233

234+
/// Return a new range representing the possible values resulting
235+
/// from a multiplication of a value in this range and a value in \p Other.
236+
LLVM_ABI ConstantFPRange mul(const ConstantFPRange &Other) const;
237+
238+
/// Return a new range representing the possible values resulting
239+
/// from a division of a value in this range and a value in
240+
/// \p Other.
241+
LLVM_ABI ConstantFPRange div(const ConstantFPRange &Other) const;
242+
234243
/// Flush denormal values to zero according to the specified mode.
235244
/// For dynamic mode, we return the union of all possible results.
236245
LLVM_ABI void flushDenormals(DenormalMode::DenormalModeKind Mode);

llvm/lib/IR/ConstantFPRange.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,3 +528,147 @@ void ConstantFPRange::flushDenormals(DenormalMode::DenormalModeKind Mode) {
528528
Lower = minnum(Lower, APFloat::getZero(Sem, ZeroLowerNegative));
529529
Upper = maxnum(Upper, APFloat::getZero(Sem, ZeroUpperNegative));
530530
}
531+
532+
/// Represent a contiguous range of values sharing the same sign.
533+
struct SameSignRange {
534+
bool HasZero;
535+
bool HasNonZero;
536+
bool HasInf;
537+
// The lower and upper bounds of the range (inclusive).
538+
// The sign is dropped and infinities are excluded.
539+
std::optional<std::pair<APFloat, APFloat>> FinitePart;
540+
541+
explicit SameSignRange(const APFloat &Lower, const APFloat &Upper)
542+
: HasZero(Lower.isZero()), HasNonZero(!Upper.isZero()),
543+
HasInf(Upper.isInfinity()) {
544+
assert(!Lower.isNegative() && !Upper.isNegative() &&
545+
"The sign should be dropped.");
546+
assert(strictCompare(Lower, Upper) != APFloat::cmpGreaterThan &&
547+
"Empty set.");
548+
if (!Lower.isInfinity())
549+
FinitePart = {Lower,
550+
HasInf ? APFloat::getLargest(Lower.getSemantics()) : Upper};
551+
}
552+
};
553+
554+
/// Split the range into positive and negative components.
555+
static void splitPosNeg(const APFloat &Lower, const APFloat &Upper,
556+
std::optional<SameSignRange> &NegPart,
557+
std::optional<SameSignRange> &PosPart) {
558+
assert(strictCompare(Lower, Upper) != APFloat::cmpGreaterThan &&
559+
"Non-NaN part is empty.");
560+
if (Lower.isNegative() == Upper.isNegative()) {
561+
if (Lower.isNegative())
562+
NegPart = SameSignRange{abs(Upper), abs(Lower)};
563+
else
564+
PosPart = SameSignRange{Lower, Upper};
565+
return;
566+
}
567+
auto &Sem = Lower.getSemantics();
568+
NegPart = SameSignRange{APFloat::getZero(Sem), abs(Lower)};
569+
PosPart = SameSignRange{APFloat::getZero(Sem), Upper};
570+
}
571+
572+
ConstantFPRange ConstantFPRange::mul(const ConstantFPRange &Other) const {
573+
auto &Sem = getSemantics();
574+
bool ResMayBeQNaN = ((MayBeQNaN || MayBeSNaN) && !Other.isEmptySet()) ||
575+
((Other.MayBeQNaN || Other.MayBeSNaN) && !isEmptySet());
576+
if (isNaNOnly() || Other.isNaNOnly())
577+
return getNaNOnly(Sem, /*MayBeQNaN=*/ResMayBeQNaN,
578+
/*MayBeSNaN=*/false);
579+
std::optional<SameSignRange> LHSNeg, LHSPos, RHSNeg, RHSPos;
580+
splitPosNeg(Lower, Upper, LHSNeg, LHSPos);
581+
splitPosNeg(Other.Lower, Other.Upper, RHSNeg, RHSPos);
582+
APFloat ResLower = APFloat::getInf(Sem, /*Negative=*/false);
583+
APFloat ResUpper = APFloat::getInf(Sem, /*Negative=*/true);
584+
auto Update = [&](std::optional<SameSignRange> &LHS,
585+
std::optional<SameSignRange> &RHS, bool Negative) {
586+
if (!LHS || !RHS)
587+
return;
588+
// 0 * inf = QNaN
589+
ResMayBeQNaN |= LHS->HasZero && RHS->HasInf;
590+
ResMayBeQNaN |= RHS->HasZero && LHS->HasInf;
591+
// NonZero * inf = inf
592+
if ((LHS->HasInf && RHS->HasNonZero) || (RHS->HasInf && LHS->HasNonZero))
593+
(Negative ? ResLower : ResUpper) = APFloat::getInf(Sem, Negative);
594+
// Finite * Finite
595+
if (LHS->FinitePart && RHS->FinitePart) {
596+
APFloat NewLower = LHS->FinitePart->first * RHS->FinitePart->first;
597+
APFloat NewUpper = LHS->FinitePart->second * RHS->FinitePart->second;
598+
if (Negative) {
599+
ResLower = minnum(ResLower, -NewUpper);
600+
ResUpper = maxnum(ResUpper, -NewLower);
601+
} else {
602+
ResLower = minnum(ResLower, NewLower);
603+
ResUpper = maxnum(ResUpper, NewUpper);
604+
}
605+
}
606+
};
607+
Update(LHSNeg, RHSNeg, /*Negative=*/false);
608+
Update(LHSNeg, RHSPos, /*Negative=*/true);
609+
Update(LHSPos, RHSNeg, /*Negative=*/true);
610+
Update(LHSPos, RHSPos, /*Negative=*/false);
611+
return ConstantFPRange(ResLower, ResUpper, ResMayBeQNaN, /*MayBeSNaN=*/false);
612+
}
613+
614+
ConstantFPRange ConstantFPRange::div(const ConstantFPRange &Other) const {
615+
auto &Sem = getSemantics();
616+
bool ResMayBeQNaN = ((MayBeQNaN || MayBeSNaN) && !Other.isEmptySet()) ||
617+
((Other.MayBeQNaN || Other.MayBeSNaN) && !isEmptySet());
618+
if (isNaNOnly() || Other.isNaNOnly())
619+
return getNaNOnly(Sem, /*MayBeQNaN=*/ResMayBeQNaN,
620+
/*MayBeSNaN=*/false);
621+
std::optional<SameSignRange> LHSNeg, LHSPos, RHSNeg, RHSPos;
622+
splitPosNeg(Lower, Upper, LHSNeg, LHSPos);
623+
splitPosNeg(Other.Lower, Other.Upper, RHSNeg, RHSPos);
624+
APFloat ResLower = APFloat::getInf(Sem, /*Negative=*/false);
625+
APFloat ResUpper = APFloat::getInf(Sem, /*Negative=*/true);
626+
auto Update = [&](std::optional<SameSignRange> &LHS,
627+
std::optional<SameSignRange> &RHS, bool Negative) {
628+
if (!LHS || !RHS)
629+
return;
630+
// inf / inf = QNaN 0 / 0 = QNaN
631+
ResMayBeQNaN |= LHS->HasInf && RHS->HasInf;
632+
ResMayBeQNaN |= LHS->HasZero && RHS->HasZero;
633+
// It is not straightforward to infer HasNonZeroFinite = HasFinite &&
634+
// HasNonZero. By definitions we have:
635+
// HasFinite = HasNonZeroFinite || HasZero
636+
// HasNonZero = HasNonZeroFinite || HasInf
637+
// Since the range is contiguous, if both HasFinite and HasNonZero are true,
638+
// HasNonZeroFinite must be true.
639+
bool LHSHasNonZeroFinite = LHS->FinitePart && LHS->HasNonZero;
640+
bool RHSHasNonZeroFinite = RHS->FinitePart && RHS->HasNonZero;
641+
// inf / Finite = inf FiniteNonZero / 0 = inf
642+
if ((LHS->HasInf && RHS->FinitePart) ||
643+
(LHSHasNonZeroFinite && RHS->HasZero))
644+
(Negative ? ResLower : ResUpper) = APFloat::getInf(Sem, Negative);
645+
// Finite / inf = 0
646+
if (LHS->FinitePart && RHS->HasInf) {
647+
APFloat Zero = APFloat::getZero(Sem, /*Negative=*/Negative);
648+
ResLower = minnum(ResLower, Zero);
649+
ResUpper = maxnum(ResUpper, Zero);
650+
}
651+
// Finite / FiniteNonZero
652+
if (LHS->FinitePart && RHSHasNonZeroFinite) {
653+
assert(!RHS->FinitePart->second.isZero() &&
654+
"Divisor should be non-zero.");
655+
APFloat NewLower = LHS->FinitePart->first / RHS->FinitePart->second;
656+
APFloat NewUpper = LHS->FinitePart->second /
657+
(RHS->FinitePart->first.isZero()
658+
? APFloat::getSmallest(Sem, /*Negative=*/false)
659+
: RHS->FinitePart->first);
660+
if (Negative) {
661+
ResLower = minnum(ResLower, -NewUpper);
662+
ResUpper = maxnum(ResUpper, -NewLower);
663+
} else {
664+
ResLower = minnum(ResLower, NewLower);
665+
ResUpper = maxnum(ResUpper, NewUpper);
666+
}
667+
}
668+
};
669+
Update(LHSNeg, RHSNeg, /*Negative=*/false);
670+
Update(LHSNeg, RHSPos, /*Negative=*/true);
671+
Update(LHSPos, RHSNeg, /*Negative=*/true);
672+
Update(LHSPos, RHSPos, /*Negative=*/false);
673+
return ConstantFPRange(ResLower, ResUpper, ResMayBeQNaN, /*MayBeSNaN=*/false);
674+
}

llvm/unittests/IR/ConstantFPRangeTest.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,115 @@ TEST_F(ConstantFPRangeTest, sub) {
10661066
#endif
10671067
}
10681068

1069+
TEST_F(ConstantFPRangeTest, mul) {
1070+
EXPECT_EQ(Full.mul(Full), NonNaN.unionWith(QNaN));
1071+
EXPECT_EQ(Full.mul(Empty), Empty);
1072+
EXPECT_EQ(Empty.mul(Full), Empty);
1073+
EXPECT_EQ(Empty.mul(Empty), Empty);
1074+
EXPECT_EQ(One.mul(One), ConstantFPRange(APFloat(1.0)));
1075+
EXPECT_EQ(Some.mul(Some),
1076+
ConstantFPRange::getNonNaN(APFloat(-9.0), APFloat(9.0)));
1077+
EXPECT_EQ(SomePos.mul(SomeNeg),
1078+
ConstantFPRange::getNonNaN(APFloat(-9.0), APFloat(-0.0)));
1079+
EXPECT_EQ(PosInf.mul(PosInf), PosInf);
1080+
EXPECT_EQ(NegInf.mul(NegInf), PosInf);
1081+
EXPECT_EQ(PosInf.mul(Finite), NonNaN.unionWith(QNaN));
1082+
EXPECT_EQ(NegInf.mul(Finite), NonNaN.unionWith(QNaN));
1083+
EXPECT_EQ(PosInf.mul(NegInf), NegInf);
1084+
EXPECT_EQ(NegInf.mul(PosInf), NegInf);
1085+
EXPECT_EQ(PosZero.mul(NegZero), NegZero);
1086+
EXPECT_EQ(PosZero.mul(Zero), Zero);
1087+
EXPECT_EQ(NegZero.mul(NegZero), PosZero);
1088+
EXPECT_EQ(NegZero.mul(Zero), Zero);
1089+
EXPECT_EQ(NaN.mul(NaN), QNaN);
1090+
EXPECT_EQ(NaN.mul(Finite), QNaN);
1091+
1092+
#if defined(EXPENSIVE_CHECKS)
1093+
EnumerateTwoInterestingConstantFPRanges(
1094+
[](const ConstantFPRange &LHS, const ConstantFPRange &RHS) {
1095+
ConstantFPRange Res = LHS.mul(RHS);
1096+
ConstantFPRange Expected =
1097+
ConstantFPRange::getEmpty(LHS.getSemantics());
1098+
EnumerateValuesInConstantFPRange(
1099+
LHS,
1100+
[&](const APFloat &LHSC) {
1101+
EnumerateValuesInConstantFPRange(
1102+
RHS,
1103+
[&](const APFloat &RHSC) {
1104+
APFloat Prod = LHSC * RHSC;
1105+
EXPECT_TRUE(Res.contains(Prod))
1106+
<< "Wrong result for " << LHS << " * " << RHS
1107+
<< ". The result " << Res << " should contain " << Prod;
1108+
if (!Expected.contains(Prod))
1109+
Expected = Expected.unionWith(ConstantFPRange(Prod));
1110+
},
1111+
/*IgnoreNaNPayload=*/true);
1112+
},
1113+
/*IgnoreNaNPayload=*/true);
1114+
EXPECT_EQ(Res, Expected)
1115+
<< "Suboptimal result for " << LHS << " * " << RHS << ". Expected "
1116+
<< Expected << ", but got " << Res;
1117+
},
1118+
SparseLevel::SpecialValuesOnly);
1119+
#endif
1120+
}
1121+
1122+
TEST_F(ConstantFPRangeTest, div) {
1123+
EXPECT_EQ(Full.div(Full), NonNaN.unionWith(QNaN));
1124+
EXPECT_EQ(Full.div(Empty), Empty);
1125+
EXPECT_EQ(Empty.div(Full), Empty);
1126+
EXPECT_EQ(Empty.div(Empty), Empty);
1127+
EXPECT_EQ(One.div(One), ConstantFPRange(APFloat(1.0)));
1128+
EXPECT_EQ(Some.div(Some), NonNaN.unionWith(QNaN));
1129+
EXPECT_EQ(SomePos.div(SomeNeg),
1130+
ConstantFPRange(APFloat::getInf(Sem, /*Negative=*/true),
1131+
APFloat::getZero(Sem, /*Negative=*/true),
1132+
/*MayBeQNaN=*/true, /*MayBeSNaN=*/false));
1133+
EXPECT_EQ(PosInf.div(PosInf), QNaN);
1134+
EXPECT_EQ(NegInf.div(NegInf), QNaN);
1135+
EXPECT_EQ(PosInf.div(Finite), NonNaN);
1136+
EXPECT_EQ(NegInf.div(Finite), NonNaN);
1137+
EXPECT_EQ(PosInf.div(NegInf), QNaN);
1138+
EXPECT_EQ(NegInf.div(PosInf), QNaN);
1139+
EXPECT_EQ(Zero.div(Zero), QNaN);
1140+
EXPECT_EQ(SomePos.div(PosInf), PosZero);
1141+
EXPECT_EQ(SomeNeg.div(PosInf), NegZero);
1142+
EXPECT_EQ(PosInf.div(SomePos), PosInf);
1143+
EXPECT_EQ(NegInf.div(SomeNeg), PosInf);
1144+
EXPECT_EQ(NegInf.div(Some), NonNaN);
1145+
EXPECT_EQ(NaN.div(NaN), QNaN);
1146+
EXPECT_EQ(NaN.div(Finite), QNaN);
1147+
1148+
#if defined(EXPENSIVE_CHECKS)
1149+
EnumerateTwoInterestingConstantFPRanges(
1150+
[](const ConstantFPRange &LHS, const ConstantFPRange &RHS) {
1151+
ConstantFPRange Res = LHS.div(RHS);
1152+
ConstantFPRange Expected =
1153+
ConstantFPRange::getEmpty(LHS.getSemantics());
1154+
EnumerateValuesInConstantFPRange(
1155+
LHS,
1156+
[&](const APFloat &LHSC) {
1157+
EnumerateValuesInConstantFPRange(
1158+
RHS,
1159+
[&](const APFloat &RHSC) {
1160+
APFloat Val = LHSC / RHSC;
1161+
EXPECT_TRUE(Res.contains(Val))
1162+
<< "Wrong result for " << LHS << " / " << RHS
1163+
<< ". The result " << Res << " should contain " << Val;
1164+
if (!Expected.contains(Val))
1165+
Expected = Expected.unionWith(ConstantFPRange(Val));
1166+
},
1167+
/*IgnoreNaNPayload=*/true);
1168+
},
1169+
/*IgnoreNaNPayload=*/true);
1170+
EXPECT_EQ(Res, Expected)
1171+
<< "Suboptimal result for " << LHS << " / " << RHS << ". Expected "
1172+
<< Expected << ", but got " << Res;
1173+
},
1174+
SparseLevel::SpecialValuesOnly);
1175+
#endif
1176+
}
1177+
10691178
TEST_F(ConstantFPRangeTest, flushDenormals) {
10701179
const fltSemantics &FP8Sem = APFloat::Float8E4M3();
10711180
APFloat NormalVal = APFloat::getSmallestNormalized(FP8Sem);

0 commit comments

Comments
 (0)