4848#include " llvm/CodeGen/ValueTypes.h"
4949#include " llvm/CodeGenTypes/MachineValueType.h"
5050#include " llvm/IR/Constant.h"
51- #include " llvm/IR/ConstantRange.h"
5251#include " llvm/IR/Constants.h"
5352#include " llvm/IR/DataLayout.h"
5453#include " llvm/IR/DebugInfoMetadata.h"
@@ -3009,102 +3008,117 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
30093008 return SDValue ();
30103009}
30113010
3012- const APInt *
3013- SelectionDAG::getValidShiftAmountConstant (SDValue V,
3014- const APInt &DemandedElts ) const {
3011+ std::optional<ConstantRange>
3012+ SelectionDAG::getValidShiftAmountRange (SDValue V, const APInt &DemandedElts ,
3013+ unsigned Depth ) const {
30153014 assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
30163015 V.getOpcode () == ISD::SRA) &&
30173016 " Unknown shift node" );
3017+ // Shifting more than the bitwidth is not valid.
30183018 unsigned BitWidth = V.getScalarValueSizeInBits ();
3019- if (ConstantSDNode *SA = isConstOrConstSplat (V.getOperand (1 ), DemandedElts)) {
3020- // Shifting more than the bitwidth is not valid.
3021- const APInt &ShAmt = SA->getAPIntValue ();
3022- if (ShAmt.ult (BitWidth))
3023- return &ShAmt;
3019+
3020+ if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand (1 ))) {
3021+ const APInt &ShAmt = Cst->getAPIntValue ();
3022+ if (ShAmt.uge (BitWidth))
3023+ return std::nullopt ;
3024+ return ConstantRange (ShAmt);
30243025 }
3025- return nullptr ;
3026+
3027+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3028+ const APInt *MinAmt = nullptr , *MaxAmt = nullptr ;
3029+ for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3030+ if (!DemandedElts[i])
3031+ continue ;
3032+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3033+ if (!SA) {
3034+ MinAmt = MaxAmt = nullptr ;
3035+ break ;
3036+ }
3037+ const APInt &ShAmt = SA->getAPIntValue ();
3038+ if (ShAmt.uge (BitWidth))
3039+ return std::nullopt ;
3040+ if (!MinAmt || MinAmt->ugt (ShAmt))
3041+ MinAmt = &ShAmt;
3042+ if (!MaxAmt || MaxAmt->ult (ShAmt))
3043+ MaxAmt = &ShAmt;
3044+ }
3045+ assert (((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
3046+ " Failed to find matching min/max shift amounts" );
3047+ if (MinAmt && MaxAmt)
3048+ return ConstantRange (*MinAmt, *MaxAmt + 1 );
3049+ }
3050+
3051+ // Use computeKnownBits to find a hidden constant/knownbits (usually type
3052+ // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3053+ KnownBits KnownAmt = computeKnownBits (V.getOperand (1 ), DemandedElts, Depth);
3054+ if (KnownAmt.getMaxValue ().ult (BitWidth))
3055+ return ConstantRange::fromKnownBits (KnownAmt, /* IsSigned=*/ false );
3056+
3057+ return std::nullopt ;
30263058}
30273059
3028- const APInt *SelectionDAG::getValidShiftAmountConstant (SDValue V) const {
3060+ std::optional<uint64_t >
3061+ SelectionDAG::getValidShiftAmount (SDValue V, const APInt &DemandedElts,
3062+ unsigned Depth) const {
3063+ assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3064+ V.getOpcode () == ISD::SRA) &&
3065+ " Unknown shift node" );
3066+ if (std::optional<ConstantRange> AmtRange =
3067+ getValidShiftAmountRange (V, DemandedElts, Depth))
3068+ if (const APInt *ShAmt = AmtRange->getSingleElement ())
3069+ return ShAmt->getZExtValue ();
3070+ return std::nullopt ;
3071+ }
3072+
3073+ std::optional<uint64_t >
3074+ SelectionDAG::getValidShiftAmount (SDValue V, unsigned Depth) const {
30293075 EVT VT = V.getValueType ();
30303076 APInt DemandedElts = VT.isFixedLengthVector ()
30313077 ? APInt::getAllOnes (VT.getVectorNumElements ())
30323078 : APInt (1 , 1 );
3033- return getValidShiftAmountConstant (V, DemandedElts);
3079+ return getValidShiftAmount (V, DemandedElts, Depth );
30343080}
30353081
3036- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (
3037- SDValue V, const APInt &DemandedElts) const {
3082+ std::optional<uint64_t >
3083+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, const APInt &DemandedElts,
3084+ unsigned Depth) const {
30383085 assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
30393086 V.getOpcode () == ISD::SRA) &&
30403087 " Unknown shift node" );
3041- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3042- return ValidAmt;
3043- unsigned BitWidth = V.getScalarValueSizeInBits ();
3044- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3045- if (!BV)
3046- return nullptr ;
3047- const APInt *MinShAmt = nullptr ;
3048- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3049- if (!DemandedElts[i])
3050- continue ;
3051- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3052- if (!SA)
3053- return nullptr ;
3054- // Shifting more than the bitwidth is not valid.
3055- const APInt &ShAmt = SA->getAPIntValue ();
3056- if (ShAmt.uge (BitWidth))
3057- return nullptr ;
3058- if (MinShAmt && MinShAmt->ule (ShAmt))
3059- continue ;
3060- MinShAmt = &ShAmt;
3061- }
3062- return MinShAmt;
3088+ if (std::optional<ConstantRange> AmtRange =
3089+ getValidShiftAmountRange (V, DemandedElts, Depth))
3090+ return AmtRange->getUnsignedMin ().getZExtValue ();
3091+ return std::nullopt ;
30633092}
30643093
3065- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (SDValue V) const {
3094+ std::optional<uint64_t >
3095+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, unsigned Depth) const {
30663096 EVT VT = V.getValueType ();
30673097 APInt DemandedElts = VT.isFixedLengthVector ()
30683098 ? APInt::getAllOnes (VT.getVectorNumElements ())
30693099 : APInt (1 , 1 );
3070- return getValidMinimumShiftAmountConstant (V, DemandedElts);
3100+ return getValidMinimumShiftAmount (V, DemandedElts, Depth );
30713101}
30723102
3073- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (
3074- SDValue V, const APInt &DemandedElts) const {
3103+ std::optional<uint64_t >
3104+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, const APInt &DemandedElts,
3105+ unsigned Depth) const {
30753106 assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
30763107 V.getOpcode () == ISD::SRA) &&
30773108 " Unknown shift node" );
3078- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3079- return ValidAmt;
3080- unsigned BitWidth = V.getScalarValueSizeInBits ();
3081- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3082- if (!BV)
3083- return nullptr ;
3084- const APInt *MaxShAmt = nullptr ;
3085- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3086- if (!DemandedElts[i])
3087- continue ;
3088- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3089- if (!SA)
3090- return nullptr ;
3091- // Shifting more than the bitwidth is not valid.
3092- const APInt &ShAmt = SA->getAPIntValue ();
3093- if (ShAmt.uge (BitWidth))
3094- return nullptr ;
3095- if (MaxShAmt && MaxShAmt->uge (ShAmt))
3096- continue ;
3097- MaxShAmt = &ShAmt;
3098- }
3099- return MaxShAmt;
3109+ if (std::optional<ConstantRange> AmtRange =
3110+ getValidShiftAmountRange (V, DemandedElts, Depth))
3111+ return AmtRange->getUnsignedMax ().getZExtValue ();
3112+ return std::nullopt ;
31003113}
31013114
3102- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (SDValue V) const {
3115+ std::optional<uint64_t >
3116+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, unsigned Depth) const {
31033117 EVT VT = V.getValueType ();
31043118 APInt DemandedElts = VT.isFixedLengthVector ()
31053119 ? APInt::getAllOnes (VT.getVectorNumElements ())
31063120 : APInt (1 , 1 );
3107- return getValidMaximumShiftAmountConstant (V, DemandedElts);
3121+ return getValidMaximumShiftAmount (V, DemandedElts, Depth );
31083122}
31093123
31103124// / Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3583,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35693583 Known = KnownBits::shl (Known, Known2, NUW, NSW, ShAmtNonZero);
35703584
35713585 // Minimum shift low bits are known zero.
3572- if (const APInt * ShMinAmt =
3573- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3574- Known.Zero .setLowBits (ShMinAmt-> getZExtValue () );
3586+ if (std::optional< uint64_t > ShMinAmt =
3587+ getValidMinimumShiftAmount (Op, DemandedElts, Depth + 1 ))
3588+ Known.Zero .setLowBits (* ShMinAmt);
35753589 break ;
35763590 }
35773591 case ISD::SRL:
@@ -3581,9 +3595,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
35813595 Op->getFlags ().hasExact ());
35823596
35833597 // Minimum shift high bits are known zero.
3584- if (const APInt * ShMinAmt =
3585- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3586- Known.Zero .setHighBits (ShMinAmt-> getZExtValue () );
3598+ if (std::optional< uint64_t > ShMinAmt =
3599+ getValidMinimumShiftAmount (Op, DemandedElts, Depth + 1 ))
3600+ Known.Zero .setHighBits (* ShMinAmt);
35873601 break ;
35883602 case ISD::SRA:
35893603 Known = computeKnownBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
@@ -4587,17 +4601,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
45874601 case ISD::SRA:
45884602 Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
45894603 // SRA X, C -> adds C sign bits.
4590- if (const APInt * ShAmt =
4591- getValidMinimumShiftAmountConstant (Op, DemandedElts))
4592- Tmp = std::min<uint64_t >(Tmp + ShAmt-> getZExtValue () , VTBits);
4604+ if (std::optional< uint64_t > ShAmt =
4605+ getValidMinimumShiftAmount (Op, DemandedElts, Depth + 1 ))
4606+ Tmp = std::min<uint64_t >(Tmp + * ShAmt, VTBits);
45934607 return Tmp;
45944608 case ISD::SHL:
4595- if (const APInt * ShAmt =
4596- getValidMaximumShiftAmountConstant (Op, DemandedElts)) {
4609+ if (std::optional< uint64_t > ShAmt =
4610+ getValidMaximumShiftAmount (Op, DemandedElts, Depth + 1 )) {
45974611 // shl destroys sign bits, ensure it doesn't shift out all sign bits.
45984612 Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
4599- if (ShAmt-> ult ( Tmp) )
4600- return Tmp - ShAmt-> getZExtValue () ;
4613+ if (* ShAmt < Tmp)
4614+ return Tmp - * ShAmt;
46014615 }
46024616 break ;
46034617 case ISD::AND:
@@ -5270,7 +5284,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
52705284 case ISD::SRL:
52715285 case ISD::SRA:
52725286 // If the max shift amount isn't in range, then the shift can create poison.
5273- return !getValidMaximumShiftAmountConstant (Op, DemandedElts);
5287+ return !getValidMaximumShiftAmount (Op, DemandedElts, Depth + 1 );
52745288
52755289 case ISD::SCALAR_TO_VECTOR:
52765290 // Check if we demand any upper (undef) elements.
0 commit comments