Skip to content

Commit ca7ffaa

Browse files
committed
[ConstantRange] add nuw support to truncate (NFC) (llvm#152990)
1 parent 6a81dac commit ca7ffaa

File tree

3 files changed

+92
-13
lines changed

3 files changed

+92
-13
lines changed

llvm/include/llvm/IR/ConstantRange.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,9 @@ class [[nodiscard]] ConstantRange {
380380
/// Return a new range in the specified integer type, which must be
381381
/// strictly smaller than the current type. The returned range will
382382
/// correspond to the possible range of values if the source range had been
383-
/// truncated to the specified type.
384-
LLVM_ABI ConstantRange truncate(uint32_t BitWidth) const;
383+
/// truncated to the specified type with wrap type \p NoWrapKind.
384+
LLVM_ABI ConstantRange truncate(uint32_t BitWidth,
385+
unsigned NoWrapKind = 0) const;
385386

386387
/// Make this range have the bit width given by \p BitWidth. The
387388
/// value is zero extended, truncated, or left alone to make it that width.

llvm/lib/IR/ConstantRange.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,8 @@ ConstantRange ConstantRange::signExtend(uint32_t DstTySize) const {
872872
return ConstantRange(Lower.sext(DstTySize), Upper.sext(DstTySize));
873873
}
874874

875-
ConstantRange ConstantRange::truncate(uint32_t DstTySize) const {
875+
ConstantRange ConstantRange::truncate(uint32_t DstTySize,
876+
unsigned NoWrapKind) const {
876877
assert(getBitWidth() > DstTySize && "Not a value truncation");
877878
if (isEmptySet())
878879
return getEmpty(DstTySize);
@@ -886,22 +887,36 @@ ConstantRange ConstantRange::truncate(uint32_t DstTySize) const {
886887
// We use the non-wrapped set code to analyze the [Lower, MaxValue) part, and
887888
// then we do the union with [MaxValue, Upper)
888889
if (isUpperWrapped()) {
889-
// If Upper is greater than or equal to MaxValue(DstTy), it covers the whole
890-
// truncated range.
891-
if (Upper.getActiveBits() > DstTySize || Upper.countr_one() == DstTySize)
890+
// If Upper is greater than MaxValue(DstTy), it covers the whole truncated
891+
// range.
892+
if (Upper.getActiveBits() > DstTySize)
892893
return getFull(DstTySize);
893894

894-
Union = ConstantRange(APInt::getMaxValue(DstTySize),Upper.trunc(DstTySize));
895-
UpperDiv.setAllBits();
896-
897-
// Union covers the MaxValue case, so return if the remaining range is just
898-
// MaxValue(DstTy).
899-
if (LowerDiv == UpperDiv)
900-
return Union;
895+
// For nuw the two parts are: [0, Upper) \/ [Lower, MaxValue(DstTy)]
896+
if (NoWrapKind & TruncInst::NoUnsignedWrap) {
897+
Union = ConstantRange(APInt::getZero(DstTySize), Upper.trunc(DstTySize));
898+
UpperDiv = APInt::getOneBitSet(getBitWidth(), DstTySize);
899+
} else {
900+
// If Upper is equal to MaxValue(DstTy), it covers the whole truncated
901+
// range.
902+
if (Upper.countr_one() == DstTySize)
903+
return getFull(DstTySize);
904+
Union =
905+
ConstantRange(APInt::getMaxValue(DstTySize), Upper.trunc(DstTySize));
906+
UpperDiv.setAllBits();
907+
// Union covers the MaxValue case, so return if the remaining range is
908+
// just MaxValue(DstTy).
909+
if (LowerDiv == UpperDiv)
910+
return Union;
911+
}
901912
}
902913

903914
// Chop off the most significant bits that are past the destination bitwidth.
904915
if (LowerDiv.getActiveBits() > DstTySize) {
916+
// For trunc nuw if LowerDiv is greater than MaxValue(DstTy), the range is
917+
// outside the whole truncated range.
918+
if (NoWrapKind & TruncInst::NoUnsignedWrap)
919+
return Union;
905920
// Mask to just the signficant bits and subtract from LowerDiv/UpperDiv.
906921
APInt Adjust = LowerDiv & APInt::getBitsSetFrom(getBitWidth(), DstTySize);
907922
LowerDiv -= Adjust;
@@ -913,6 +928,10 @@ ConstantRange ConstantRange::truncate(uint32_t DstTySize) const {
913928
return ConstantRange(LowerDiv.trunc(DstTySize),
914929
UpperDiv.trunc(DstTySize)).unionWith(Union);
915930

931+
if (!LowerDiv.isZero() && NoWrapKind & TruncInst::NoUnsignedWrap)
932+
return ConstantRange(LowerDiv.trunc(DstTySize), APInt::getZero(DstTySize))
933+
.unionWith(Union);
934+
916935
// The truncated value wraps around. Check if we can do better than fullset.
917936
if (UpperDivWidth == DstTySize + 1) {
918937
// Clear the MSB so that UpperDiv wraps around.

llvm/unittests/IR/ConstantRangeTest.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,65 @@ TEST_F(ConstantRangeTest, Trunc) {
451451
EXPECT_EQ(SevenOne.truncate(2), ConstantRange(APInt(2, 3), APInt(2, 1)));
452452
}
453453

454+
TEST_F(ConstantRangeTest, TruncNuw) {
455+
auto Range = [](unsigned NumBits, unsigned Lower, unsigned Upper) {
456+
return ConstantRange(APInt(NumBits, Lower), APInt(NumBits, Upper));
457+
};
458+
// trunc([0, 4), 3->2) = full
459+
EXPECT_TRUE(
460+
Range(3, 0, 4).truncate(2, TruncInst::NoUnsignedWrap).isFullSet());
461+
// trunc([0, 3), 3->2) = [0, 3)
462+
EXPECT_EQ(Range(3, 0, 3).truncate(2, TruncInst::NoUnsignedWrap),
463+
Range(2, 0, 3));
464+
// trunc([1, 3), 3->2) = [1, 3)
465+
EXPECT_EQ(Range(3, 1, 3).truncate(2, TruncInst::NoUnsignedWrap),
466+
Range(2, 1, 3));
467+
// trunc([1, 5), 3->2) = [1, 0)
468+
EXPECT_EQ(Range(3, 1, 5).truncate(2, TruncInst::NoUnsignedWrap),
469+
Range(2, 1, 0));
470+
// trunc([4, 7), 3->2) = empty
471+
EXPECT_TRUE(
472+
Range(3, 4, 7).truncate(2, TruncInst::NoUnsignedWrap).isEmptySet());
473+
// trunc([4, 0), 3->2) = empty
474+
EXPECT_TRUE(
475+
Range(3, 4, 0).truncate(2, TruncInst::NoUnsignedWrap).isEmptySet());
476+
// trunc([4, 1), 3->2) = [0, 1)
477+
EXPECT_EQ(Range(3, 4, 1).truncate(2, TruncInst::NoUnsignedWrap),
478+
Range(2, 0, 1));
479+
// trunc([3, 1), 3->2) = [3, 1)
480+
EXPECT_EQ(Range(3, 3, 1).truncate(2, TruncInst::NoUnsignedWrap),
481+
Range(2, 3, 1));
482+
// trunc([3, 0), 3->2) = [3, 0)
483+
EXPECT_EQ(Range(3, 3, 0).truncate(2, TruncInst::NoUnsignedWrap),
484+
Range(2, 3, 0));
485+
// trunc([1, 0), 2->1) = [1, 0)
486+
EXPECT_EQ(Range(2, 1, 0).truncate(1, TruncInst::NoUnsignedWrap),
487+
Range(1, 1, 0));
488+
// trunc([2, 1), 2->1) = [0, 1)
489+
EXPECT_EQ(Range(2, 2, 1).truncate(1, TruncInst::NoUnsignedWrap),
490+
Range(1, 0, 1));
491+
}
492+
493+
TEST_F(ConstantRangeTest, TruncNuwExhaustive) {
494+
EnumerateConstantRanges(4, [&](const ConstantRange &CR) {
495+
unsigned NumBits = 3;
496+
ConstantRange Trunc = CR.truncate(NumBits, TruncInst::NoUnsignedWrap);
497+
SmallBitVector Elems(1 << NumBits);
498+
ForeachNumInConstantRange(CR, [&](const APInt &N) {
499+
if (N.isIntN(NumBits))
500+
Elems.set(N.getZExtValue());
501+
});
502+
TestRange(Trunc, Elems, PreferSmallest, {CR});
503+
});
504+
EnumerateConstantRanges(3, [&](const ConstantRange &CR) {
505+
ConstantRange Trunc = CR.truncate(1, TruncInst::NoUnsignedWrap);
506+
EXPECT_EQ(CR.contains(APInt::getZero(3)),
507+
Trunc.contains(APInt::getZero(1)));
508+
EXPECT_EQ(CR.contains(APInt::getOneBitSet(3, 0)),
509+
Trunc.contains(APInt::getAllOnes(1)));
510+
});
511+
}
512+
454513
TEST_F(ConstantRangeTest, ZExt) {
455514
ConstantRange ZFull = Full.zeroExtend(20);
456515
ConstantRange ZEmpty = Empty.zeroExtend(20);

0 commit comments

Comments
 (0)