@@ -13995,9 +13995,11 @@ bool BoUpSLP::collectValuesToDemote(
1399513995 if (MultiNodeScalars.contains(V))
1399613996 return false;
1399713997 uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
13998- APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
13999- if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
14000- return true;
13998+ if (OrigBitWidth < BitWidth) {
13999+ APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
14000+ if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
14001+ return true;
14002+ }
1400114003 auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
1400214004 unsigned BitWidth1 = OrigBitWidth - NumSignBits;
1400314005 if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
@@ -14042,6 +14044,30 @@ bool BoUpSLP::collectValuesToDemote(
1404214044 }
1404314045 return true;
1404414046 };
14047+ auto AttemptCheckBitwidth =
14048+ [&](function_ref<bool(unsigned, unsigned)> Checker, bool &NeedToExit) {
14049+ // Try all bitwidth < OrigBitWidth.
14050+ NeedToExit = false;
14051+ uint32_t OrigBitWidth = DL->getTypeSizeInBits(I->getType());
14052+ unsigned BestFailBitwidth = 0;
14053+ for (; BitWidth < OrigBitWidth; BitWidth *= 2) {
14054+ if (Checker(BitWidth, OrigBitWidth))
14055+ return true;
14056+ if (BestFailBitwidth == 0 && FinalAnalysis())
14057+ BestFailBitwidth = BitWidth;
14058+ }
14059+ if (BitWidth >= OrigBitWidth) {
14060+ if (BestFailBitwidth == 0) {
14061+ BitWidth = OrigBitWidth;
14062+ return false;
14063+ }
14064+ MaxDepthLevel = 1;
14065+ BitWidth = BestFailBitwidth;
14066+ NeedToExit = true;
14067+ return true;
14068+ }
14069+ return false;
14070+ };
1404514071 bool NeedToExit = false;
1404614072 switch (I->getOpcode()) {
1404714073
@@ -14074,6 +14100,71 @@ bool BoUpSLP::collectValuesToDemote(
1407414100 return false;
1407514101 break;
1407614102 }
14103+ case Instruction::Shl: {
14104+ // Several vectorized uses? Check if we can truncate it, otherwise - exit.
14105+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14106+ return false;
14107+ // If we are truncating the result of this SHL, and if it's a shift of an
14108+ // inrange amount, we can always perform a SHL in a smaller type.
14109+ if (!AttemptCheckBitwidth(
14110+ [&](unsigned BitWidth, unsigned) {
14111+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
14112+ return AmtKnownBits.getMaxValue().ult(BitWidth);
14113+ },
14114+ NeedToExit))
14115+ return false;
14116+ if (NeedToExit)
14117+ return true;
14118+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
14119+ return false;
14120+ break;
14121+ }
14122+ case Instruction::LShr: {
14123+ // Several vectorized uses? Check if we can truncate it, otherwise - exit.
14124+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14125+ return false;
14126+ // If this is a truncate of a logical shr, we can truncate it to a smaller
14127+ // lshr iff we know that the bits we would otherwise be shifting in are
14128+ // already zeros.
14129+ if (!AttemptCheckBitwidth(
14130+ [&](unsigned BitWidth, unsigned OrigBitWidth) {
14131+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
14132+ APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
14133+ return AmtKnownBits.getMaxValue().ult(BitWidth) &&
14134+ MaskedValueIsZero(I->getOperand(0), ShiftedBits,
14135+ SimplifyQuery(*DL));
14136+ },
14137+ NeedToExit))
14138+ return false;
14139+ if (NeedToExit)
14140+ return true;
14141+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
14142+ return false;
14143+ break;
14144+ }
14145+ case Instruction::AShr: {
14146+ // Several vectorized uses? Check if we can truncate it, otherwise - exit.
14147+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14148+ return false;
14149+ // If this is a truncate of an arithmetic shr, we can truncate it to a
14150+ // smaller ashr iff we know that all the bits from the sign bit of the
14151+ // original type and the sign bit of the truncate type are similar.
14152+ if (!AttemptCheckBitwidth(
14153+ [&](unsigned BitWidth, unsigned OrigBitWidth) {
14154+ KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
14155+ unsigned ShiftedBits = OrigBitWidth - BitWidth;
14156+ return AmtKnownBits.getMaxValue().ult(BitWidth) &&
14157+ ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0,
14158+ AC, nullptr, DT);
14159+ },
14160+ NeedToExit))
14161+ return false;
14162+ if (NeedToExit)
14163+ return true;
14164+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
14165+ return false;
14166+ break;
14167+ }
1407714168
1407814169 // We can demote selects if we can demote their true and false values.
1407914170 case Instruction::Select: {
0 commit comments