@@ -196,6 +196,8 @@ static cl::opt<unsigned> MaxSwitchCasesPerResult(
196196STATISTIC (NumBitMaps, " Number of switch instructions turned into bitmaps" );
197197STATISTIC (NumLinearMaps,
198198 " Number of switch instructions turned into linear mapping" );
199+ STATISTIC (NumTruncatedInts, " Number of switch instructions turned into lookup "
200+ " tables with a truncated value" );
199201STATISTIC (NumLookupTables,
200202 " Number of switch instructions turned into lookup tables" );
201203STATISTIC (
@@ -6486,6 +6488,13 @@ class SwitchLookupTable {
64866488 // shift and mask operations.
64876489 BitMapKind,
64886490
6491+ // For tables with integer elements which only differ in a small subset
6492+ // of their bits, we can truncate the values to a smaller type to discard
6493+ // the bits which are the same for all values, saving memory. A value is
6494+ // recovered by looking up its truncated form in a nested lookup table, then
6495+ // zero-extending and, where needed, shifting left and OR-ing in constant bits.
6496+ TruncatedIntKind,
6497+
64896498 // The table is stored as an array of values. Values are retrieved by load
64906499 // instructions from the table.
64916500 ArrayKind
@@ -6503,6 +6512,14 @@ class SwitchLookupTable {
65036512 ConstantInt *LinearMultiplier = nullptr ;
65046513 bool LinearMapValWrapped = false ;
65056514
6515+ // For TruncatedIntKind, these are the truncated lookup table and the
6516+ // constants used to reconstruct original values from their truncated forms.
6517+ std::unique_ptr<SwitchLookupTable> TruncatedLookupTable = nullptr ;
6518+ IntegerType *TruncatedOrigTy = nullptr ;
6519+ ConstantInt *TruncatedShift = nullptr ;
6520+ ConstantInt *TruncatedMask = nullptr ;
6521+ bool TruncatedValWrapped = false ;
6522+
65066523 // For ArrayKind, this is the array.
65076524 GlobalVariable *Array = nullptr ;
65086525};
@@ -6622,6 +6639,132 @@ SwitchLookupTable::SwitchLookupTable(
66226639 return ;
66236640 }
66246641
6642+ // Check if we can truncate the value to a smaller form by discarding leading
6643+ // and trailing bits which are the same for all values.
6644+ if (isa<IntegerType>(ValueType)) {
6645+ IntegerType *IT = cast<IntegerType>(ValueType);
6646+ unsigned OrigWidth = IT->getBitWidth ();
6647+
6648+ // Figure out which bits are always the same.
6649+ APInt AlwaysOneBits = APInt::getAllOnes (OrigWidth);
6650+ APInt AlwaysZeroBits = APInt::getAllOnes (OrigWidth);
6651+ for (Constant *TableEntry : TableContents) {
6652+ // Each entry in the lookup table may be a constant integer, otherwise
6653+ // it must be undefined (Undef or poison) in which case we can simply
6654+ // skip it.
6655+ if (isa<ConstantInt>(TableEntry)) {
6656+ const APInt &Val = cast<ConstantInt>(TableEntry)->getValue ();
6657+ AlwaysOneBits &= Val;
6658+ AlwaysZeroBits &= ~Val;
6659+ }
6660+ }
6661+ assert ((AlwaysOneBits & AlwaysZeroBits).isZero () &&
6662+ " A bit cannot be both zero and one in every case" );
6663+ APInt NotChangingBits = AlwaysOneBits | AlwaysZeroBits;
6664+
6665+ unsigned DiscardedHighBits = NotChangingBits.countLeadingOnes ();
6666+ unsigned DiscardedLowBits = NotChangingBits.countTrailingOnes ();
6667+ unsigned TotalDiscardedBits = DiscardedHighBits + DiscardedLowBits;
6668+ unsigned TruncatedWidth = OrigWidth - TotalDiscardedBits;
6669+
6670+ // If the original type's width is a power of two, we want to ensure
6671+ // that the truncated size is also a power of two. Otherwise, the higher
6672+ // cost of indexing into the array with non-power-of-two-sized elements is
6673+ // probably going to cancel out any benefits we might get from making the
6674+ // table smaller.
6675+ if (has_single_bit (OrigWidth)) {
6676+ unsigned OldTruncatedWidth = TruncatedWidth;
6677+ TruncatedWidth = bit_ceil (TruncatedWidth);
6678+
6679+ // If the truncated size increased, we need to decrease DiscardedHighBits
6680+ // and/or DiscardedLowBits accordingly.
6681+ unsigned TruncatedWidthIncrease = TruncatedWidth - OldTruncatedWidth;
6682+ if (TruncatedWidthIncrease) {
6683+ TotalDiscardedBits -= TruncatedWidthIncrease;
6684+
6685+ // Prioritize decreasing the number of least significant bits discarded:
6686+ // if we can get this down to 0, no left shift is needed on lookup.
6687+ unsigned LowBitsDecrease =
6688+ std::min (TruncatedWidthIncrease, DiscardedLowBits);
6689+ DiscardedLowBits -= LowBitsDecrease;
6690+
6691+ unsigned HighBitsDecrease = std::min (
6692+ TruncatedWidthIncrease - LowBitsDecrease, DiscardedHighBits);
6693+ DiscardedHighBits -= HighBitsDecrease;
6694+
6695+ assert (TruncatedWidthIncrease == LowBitsDecrease + HighBitsDecrease);
6696+
6697+ assert (DiscardedHighBits + DiscardedLowBits == TotalDiscardedBits &&
6698+ TotalDiscardedBits + TruncatedWidth == OrigWidth);
6699+ }
6700+ }
6701+
6702+ // We'll only truncate the values if the truncated values would be at most
6703+ // half the size of the original values, as otherwise there's unlikely to be
6704+ // any benefit.
6705+ if (TruncatedWidth <= OrigWidth / 2 ) {
6706+ IntegerType *TruncatedTy =
6707+ IntegerType::get (M.getContext (), TruncatedWidth);
6708+
6709+ PoisonValue *TruncatedPoison = PoisonValue::get (TruncatedTy);
6710+
6711+ // Truncate the values and build a new lookup table containing them
6712+ SmallVector<std::pair<ConstantInt *, Constant *>, 64 > TruncatedValues (
6713+ TableSize);
6714+ for (uint64_t I = 0 ; I < TableSize; ++I) {
6715+ ConstantInt *CaseVal =
6716+ ConstantInt::get (M.getContext (), Offset->getValue () + I);
6717+ Constant *CaseRes = TableContents[I];
6718+
6719+ Constant *TruncatedCaseRes;
6720+ if (ConstantInt *IntCaseRes = dyn_cast<ConstantInt>(CaseRes)) {
6721+ APInt TruncatedVal = IntCaseRes->getValue ().extractBits (
6722+ TruncatedWidth, DiscardedLowBits);
6723+ TruncatedCaseRes = ConstantInt::get (M.getContext (), TruncatedVal);
6724+ } else if (isa<PoisonValue>(CaseRes)) {
6725+ TruncatedCaseRes = TruncatedPoison;
6726+ } else {
6727+ assert (isa<UndefValue>(CaseRes));
6728+ // To avoid making a call to the deprecated
6729+ // 'UndefValue::get(TruncatedTy)', we'll simply replace undefined
6730+ // table entries with zero.
6731+ TruncatedCaseRes =
6732+ ConstantInt::get (M.getContext (), APInt::getZero (TruncatedWidth));
6733+ }
6734+
6735+ assert (TruncatedCaseRes->getType () == TruncatedTy);
6736+
6737+ TruncatedValues[I] = {CaseVal, TruncatedCaseRes};
6738+ }
6739+
6740+ // Recursively construct a new lookup table with the truncated values.
6741+ // This enables us to use more efficient table kinds which weren't
6742+ // possible originally, such as a bitmap.
6743+ TruncatedLookupTable = std::make_unique<SwitchLookupTable>(
6744+ M, TableSize, Offset, TruncatedValues, TruncatedPoison, DL,
6745+ (" switch.truncated." + FuncName).str ());
6746+ TruncatedOrigTy = IT;
6747+
6748+ if (DiscardedLowBits > 0 ) {
6749+ TruncatedShift = ConstantInt::get (IT, DiscardedLowBits);
6750+ TruncatedValWrapped = DiscardedHighBits == 0 ;
6751+ }
6752+
6753+ // The mask we OR on at the end consists of all the bits which are always
6754+ // one, excluding the bits which fit into the truncated value and didn't
6755+ // need to be changed.
6756+ APInt Mask =
6757+ AlwaysOneBits & ~APInt::getBitsSet (OrigWidth, DiscardedLowBits,
6758+ OrigWidth - DiscardedHighBits);
6759+ if (!Mask.isZero ())
6760+ TruncatedMask = ConstantInt::get (M.getContext (), Mask);
6761+
6762+ Kind = TruncatedIntKind;
6763+ ++NumTruncatedInts;
6764+ return ;
6765+ }
6766+ }
6767+
66256768 // Store the table in an array.
66266769 ArrayType *ArrayTy = ArrayType::get (ValueType, TableSize);
66276770 Constant *Initializer = ConstantArray::get (ArrayTy, TableContents);
@@ -6677,6 +6820,23 @@ Value *SwitchLookupTable::buildLookup(Value *Index, IRBuilder<> &Builder) {
66776820 // Mask off.
66786821 return Builder.CreateTrunc (DownShifted, BitMapElementTy, " switch.masked" );
66796822 }
6823+ case TruncatedIntKind: {
6824+ // Load the truncated value from the lookup table
6825+ Value *TruncatedVal = TruncatedLookupTable->buildLookup (Index, Builder);
6826+
6827+ // Derive the original value from the truncated version
6828+ Value *Result = Builder.CreateIntCast (TruncatedVal, TruncatedOrigTy, false ,
6829+ " switch.truncatedint.cast" );
6830+ if (TruncatedShift)
6831+ Result =
6832+ Builder.CreateShl (Result, TruncatedShift, " switch.truncatedint.shift" ,
6833+ /* HasNUW =*/ true ,
6834+ /* HasNSW =*/ !TruncatedValWrapped);
6835+ if (TruncatedMask)
6836+ Result =
6837+ Builder.CreateOr (Result, TruncatedMask, " switch.truncatedint.mask" );
6838+ return Result;
6839+ }
66806840 case ArrayKind: {
66816841 // Make sure the table index will not overflow when treated as signed.
66826842 IntegerType *IT = cast<IntegerType>(Index->getType ());
0 commit comments