Skip to content

Commit c65d6cb

Browse files
authored
[SelectionDAG] Return std::optional<unsigned> from getValidShiftAmount and friends. NFC (#156224)
Instead of std::optional<uint64_t>. Shift amounts must be less than or equal to our maximum supported bit widths which fit in unsigned. Most of the callers already assumed it fit in unsigned.
1 parent b062aad commit c65d6cb

File tree

4 files changed

+37
-37
lines changed

4 files changed

+37
-37
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2350,35 +2350,35 @@ class SelectionDAG {
23502350

23512351
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
23522352
/// that is less than the element bit-width of the shift node, return it.
2353-
LLVM_ABI std::optional<uint64_t>
2353+
LLVM_ABI std::optional<unsigned>
23542354
getValidShiftAmount(SDValue V, const APInt &DemandedElts,
23552355
unsigned Depth = 0) const;
23562356

23572357
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
23582358
/// that is less than the element bit-width of the shift node, return it.
2359-
LLVM_ABI std::optional<uint64_t>
2359+
LLVM_ABI std::optional<unsigned>
23602360
getValidShiftAmount(SDValue V, unsigned Depth = 0) const;
23612361

23622362
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
23632363
/// element bit-width of the shift node, return the minimum possible value.
2364-
LLVM_ABI std::optional<uint64_t>
2364+
LLVM_ABI std::optional<unsigned>
23652365
getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
23662366
unsigned Depth = 0) const;
23672367

23682368
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
23692369
/// element bit-width of the shift node, return the minimum possible value.
2370-
LLVM_ABI std::optional<uint64_t>
2370+
LLVM_ABI std::optional<unsigned>
23712371
getValidMinimumShiftAmount(SDValue V, unsigned Depth = 0) const;
23722372

23732373
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
23742374
/// element bit-width of the shift node, return the maximum possible value.
2375-
LLVM_ABI std::optional<uint64_t>
2375+
LLVM_ABI std::optional<unsigned>
23762376
getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
23772377
unsigned Depth = 0) const;
23782378

23792379
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
23802380
/// element bit-width of the shift node, return the maximum possible value.
2381-
LLVM_ABI std::optional<uint64_t>
2381+
LLVM_ABI std::optional<unsigned>
23822382
getValidMaximumShiftAmount(SDValue V, unsigned Depth = 0) const;
23832383

23842384
/// Match a binop + shuffle pyramid that represents a horizontal reduction

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3299,7 +3299,7 @@ SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
32993299
return std::nullopt;
33003300
}
33013301

3302-
std::optional<uint64_t>
3302+
std::optional<unsigned>
33033303
SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
33043304
unsigned Depth) const {
33053305
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3312,7 +3312,7 @@ SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
33123312
return std::nullopt;
33133313
}
33143314

3315-
std::optional<uint64_t>
3315+
std::optional<unsigned>
33163316
SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
33173317
EVT VT = V.getValueType();
33183318
APInt DemandedElts = VT.isFixedLengthVector()
@@ -3321,7 +3321,7 @@ SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
33213321
return getValidShiftAmount(V, DemandedElts, Depth);
33223322
}
33233323

3324-
std::optional<uint64_t>
3324+
std::optional<unsigned>
33253325
SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
33263326
unsigned Depth) const {
33273327
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3333,7 +3333,7 @@ SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
33333333
return std::nullopt;
33343334
}
33353335

3336-
std::optional<uint64_t>
3336+
std::optional<unsigned>
33373337
SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
33383338
EVT VT = V.getValueType();
33393339
APInt DemandedElts = VT.isFixedLengthVector()
@@ -3342,7 +3342,7 @@ SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
33423342
return getValidMinimumShiftAmount(V, DemandedElts, Depth);
33433343
}
33443344

3345-
std::optional<uint64_t>
3345+
std::optional<unsigned>
33463346
SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
33473347
unsigned Depth) const {
33483348
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
@@ -3354,7 +3354,7 @@ SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
33543354
return std::nullopt;
33553355
}
33563356

3357-
std::optional<uint64_t>
3357+
std::optional<unsigned>
33583358
SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
33593359
EVT VT = V.getValueType();
33603360
APInt DemandedElts = VT.isFixedLengthVector()
@@ -3828,7 +3828,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
38283828
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);
38293829

38303830
// Minimum shift low bits are known zero.
3831-
if (std::optional<uint64_t> ShMinAmt =
3831+
if (std::optional<unsigned> ShMinAmt =
38323832
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
38333833
Known.Zero.setLowBits(*ShMinAmt);
38343834
break;
@@ -3840,7 +3840,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
38403840
Op->getFlags().hasExact());
38413841

38423842
// Minimum shift high bits are known zero.
3843-
if (std::optional<uint64_t> ShMinAmt =
3843+
if (std::optional<unsigned> ShMinAmt =
38443844
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
38453845
Known.Zero.setHighBits(*ShMinAmt);
38463846
break;
@@ -4887,15 +4887,15 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
48874887
case ISD::SRA:
48884888
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
48894889
// SRA X, C -> adds C sign bits.
4890-
if (std::optional<uint64_t> ShAmt =
4890+
if (std::optional<unsigned> ShAmt =
48914891
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
4892-
Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
4892+
Tmp = std::min(Tmp + *ShAmt, VTBits);
48934893
return Tmp;
48944894
case ISD::SHL:
48954895
if (std::optional<ConstantRange> ShAmtRange =
48964896
getValidShiftAmountRange(Op, DemandedElts, Depth + 1)) {
4897-
uint64_t MaxShAmt = ShAmtRange->getUnsignedMax().getZExtValue();
4898-
uint64_t MinShAmt = ShAmtRange->getUnsignedMin().getZExtValue();
4897+
unsigned MaxShAmt = ShAmtRange->getUnsignedMax().getZExtValue();
4898+
unsigned MinShAmt = ShAmtRange->getUnsignedMin().getZExtValue();
48994899
// Try to look through ZERO/SIGN/ANY_EXTEND. If all extended bits are
49004900
// shifted out, then we can compute the number of sign bits for the
49014901
// operand being extended. A future improvement could be to pass along the
@@ -4906,7 +4906,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
49064906
EVT ExtVT = Ext.getValueType();
49074907
SDValue Extendee = Ext.getOperand(0);
49084908
EVT ExtendeeVT = Extendee.getValueType();
4909-
uint64_t SizeDifference =
4909+
unsigned SizeDifference =
49104910
ExtVT.getScalarSizeInBits() - ExtendeeVT.getScalarSizeInBits();
49114911
if (SizeDifference <= MinShAmt) {
49124912
Tmp = SizeDifference +

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
832832
case ISD::SHL: {
833833
// If we are only demanding sign bits then we can use the shift source
834834
// directly.
835-
if (std::optional<uint64_t> MaxSA =
835+
if (std::optional<unsigned> MaxSA =
836836
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
837837
SDValue Op0 = Op.getOperand(0);
838838
unsigned ShAmt = *MaxSA;
@@ -847,7 +847,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
847847
case ISD::SRL: {
848848
// If we are only demanding sign bits then we can use the shift source
849849
// directly.
850-
if (std::optional<uint64_t> MaxSA =
850+
if (std::optional<unsigned> MaxSA =
851851
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
852852
SDValue Op0 = Op.getOperand(0);
853853
unsigned ShAmt = *MaxSA;
@@ -1780,7 +1780,7 @@ bool TargetLowering::SimplifyDemandedBits(
17801780
SDValue Op1 = Op.getOperand(1);
17811781
EVT ShiftVT = Op1.getValueType();
17821782

1783-
if (std::optional<uint64_t> KnownSA =
1783+
if (std::optional<unsigned> KnownSA =
17841784
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
17851785
unsigned ShAmt = *KnownSA;
17861786
if (ShAmt == 0)
@@ -1792,7 +1792,7 @@ bool TargetLowering::SimplifyDemandedBits(
17921792
// TODO - support non-uniform vector amounts.
17931793
if (Op0.getOpcode() == ISD::SRL) {
17941794
if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
1795-
if (std::optional<uint64_t> InnerSA =
1795+
if (std::optional<unsigned> InnerSA =
17961796
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
17971797
unsigned C1 = *InnerSA;
17981798
unsigned Opc = ISD::SHL;
@@ -1832,7 +1832,7 @@ bool TargetLowering::SimplifyDemandedBits(
18321832
// TODO - support non-uniform vector amounts.
18331833
if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
18341834
InnerOp.hasOneUse()) {
1835-
if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
1835+
if (std::optional<unsigned> SA2 = TLO.DAG.getValidShiftAmount(
18361836
InnerOp, DemandedElts, Depth + 2)) {
18371837
unsigned InnerShAmt = *SA2;
18381838
if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
@@ -1949,7 +1949,7 @@ bool TargetLowering::SimplifyDemandedBits(
19491949

19501950
// If we are only demanding sign bits then we can use the shift source
19511951
// directly.
1952-
if (std::optional<uint64_t> MaxSA =
1952+
if (std::optional<unsigned> MaxSA =
19531953
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
19541954
unsigned ShAmt = *MaxSA;
19551955
unsigned NumSignBits =
@@ -1965,7 +1965,7 @@ bool TargetLowering::SimplifyDemandedBits(
19651965
SDValue Op1 = Op.getOperand(1);
19661966
EVT ShiftVT = Op1.getValueType();
19671967

1968-
if (std::optional<uint64_t> KnownSA =
1968+
if (std::optional<unsigned> KnownSA =
19691969
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
19701970
unsigned ShAmt = *KnownSA;
19711971
if (ShAmt == 0)
@@ -1977,7 +1977,7 @@ bool TargetLowering::SimplifyDemandedBits(
19771977
// TODO - support non-uniform vector amounts.
19781978
if (Op0.getOpcode() == ISD::SHL) {
19791979
if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
1980-
if (std::optional<uint64_t> InnerSA =
1980+
if (std::optional<unsigned> InnerSA =
19811981
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
19821982
unsigned C1 = *InnerSA;
19831983
unsigned Opc = ISD::SRL;
@@ -1997,7 +1997,7 @@ bool TargetLowering::SimplifyDemandedBits(
19971997
// single sra. We can do this if the top bits are never demanded.
19981998
if (Op0.getOpcode() == ISD::SRA && Op0.hasOneUse()) {
19991999
if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
2000-
if (std::optional<uint64_t> InnerSA =
2000+
if (std::optional<unsigned> InnerSA =
20012001
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
20022002
unsigned C1 = *InnerSA;
20032003
// Clamp the combined shift amount if it exceeds the bit width.
@@ -2062,7 +2062,7 @@ bool TargetLowering::SimplifyDemandedBits(
20622062

20632063
// If we are only demanding sign bits then we can use the shift source
20642064
// directly.
2065-
if (std::optional<uint64_t> MaxSA =
2065+
if (std::optional<unsigned> MaxSA =
20662066
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
20672067
unsigned ShAmt = *MaxSA;
20682068
// Must already be signbits in DemandedBits bounds, and can't demand any
@@ -2101,7 +2101,7 @@ bool TargetLowering::SimplifyDemandedBits(
21012101
if (DemandedBits.isOne())
21022102
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
21032103

2104-
if (std::optional<uint64_t> KnownSA =
2104+
if (std::optional<unsigned> KnownSA =
21052105
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
21062106
unsigned ShAmt = *KnownSA;
21072107
if (ShAmt == 0)
@@ -2110,7 +2110,7 @@ bool TargetLowering::SimplifyDemandedBits(
21102110
// fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
21112111
// supports sext_inreg.
21122112
if (Op0.getOpcode() == ISD::SHL) {
2113-
if (std::optional<uint64_t> InnerSA =
2113+
if (std::optional<unsigned> InnerSA =
21142114
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
21152115
unsigned LowBits = BitWidth - ShAmt;
21162116
EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
@@ -2657,11 +2657,11 @@ bool TargetLowering::SimplifyDemandedBits(
26572657
break;
26582658
}
26592659

2660-
std::optional<uint64_t> ShAmtC =
2660+
std::optional<unsigned> ShAmtC =
26612661
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
26622662
if (!ShAmtC || *ShAmtC >= BitWidth)
26632663
break;
2664-
uint64_t ShVal = *ShAmtC;
2664+
unsigned ShVal = *ShAmtC;
26652665

26662666
APInt HighBits =
26672667
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21252,7 +21252,7 @@ static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
2125221252
// the truncation then we can use PACKSS by converting the srl to a sra.
2125321253
// SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
2125421254
if (In.getOpcode() == ISD::SRL && In->hasOneUse())
21255-
if (std::optional<uint64_t> ShAmt = DAG.getValidShiftAmount(In)) {
21255+
if (std::optional<unsigned> ShAmt = DAG.getValidShiftAmount(In)) {
2125621256
if (*ShAmt == MinSignBits) {
2125721257
PackOpcode = X86ISD::PACKSS;
2125821258
return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
@@ -48383,7 +48383,7 @@ static SDValue checkSignTestSetCCCombine(SDValue Cmp, X86::CondCode &CC,
4838348383
// If Src came from a SHL (probably from an expanded SIGN_EXTEND_INREG), then
4838448384
// peek through and adjust the TEST bit.
4838548385
if (Src.getOpcode() == ISD::SHL) {
48386-
if (std::optional<uint64_t> ShiftAmt = DAG.getValidShiftAmount(Src)) {
48386+
if (std::optional<unsigned> ShiftAmt = DAG.getValidShiftAmount(Src)) {
4838748387
Src = Src.getOperand(0);
4838848388
BitMask.lshrInPlace(*ShiftAmt);
4838948389
}
@@ -54169,10 +54169,10 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
5416954169
static SDValue combinei64TruncSrlConstant(SDValue N, EVT VT, SelectionDAG &DAG,
5417054170
const SDLoc &DL) {
5417154171
assert(N.getOpcode() == ISD::SRL && "Unknown shift opcode");
54172-
std::optional<uint64_t> ValidSrlConst = DAG.getValidShiftAmount(N);
54172+
std::optional<unsigned> ValidSrlConst = DAG.getValidShiftAmount(N);
5417354173
if (!ValidSrlConst)
5417454174
return SDValue();
54175-
uint64_t SrlConstVal = *ValidSrlConst;
54175+
unsigned SrlConstVal = *ValidSrlConst;
5417654176

5417754177
SDValue Op = N.getOperand(0);
5417854178
unsigned Opcode = Op.getOpcode();

0 commit comments

Comments
 (0)