Skip to content

Commit a185946

Browse files
committed
[GlobalISel] Remove workarounds for cache assertion while adding G_ABS knownbits
1 parent 7e84ad1 commit a185946

File tree

3 files changed

+24
-45
lines changed

3 files changed

+24
-45
lines changed

llvm/include/llvm/CodeGen/GlobalISel/GISelValueTracking.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class LLVM_ABI GISelValueTracking : public GISelChangeObserver {
6767
void computeKnownBitsImpl(Register R, KnownBits &Known,
6868
const APInt &DemandedElts, unsigned Depth = 0);
6969

70-
virtual unsigned computeNumSignBitsImpl(Register R, const APInt &DemandedElts,
71-
unsigned Depth = 0);
72-
7370
unsigned computeNumSignBits(Register R, const APInt &DemandedElts,
7471
unsigned Depth = 0);
7572
unsigned computeNumSignBits(Register R, unsigned Depth = 0);

llvm/lib/CodeGen/GlobalISel/GISelValueTracking.cpp

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ void GISelValueTracking::computeKnownBitsImpl(Register R, KnownBits &Known,
680680
computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
681681
Known = Known.abs();
682682
Known.Zero.setHighBits(
683-
computeNumSignBitsImpl(SrcReg, DemandedElts, Depth + 1) - 1);
683+
computeNumSignBits(SrcReg, DemandedElts, Depth + 1) - 1);
684684
break;
685685
}
686686
}
@@ -1731,11 +1731,10 @@ unsigned GISelValueTracking::computeNumSignBitsMin(Register Src0, Register Src1,
17311731
const APInt &DemandedElts,
17321732
unsigned Depth) {
17331733
// Test src1 first, since we canonicalize simpler expressions to the RHS.
1734-
unsigned Src1SignBits = computeNumSignBitsImpl(Src1, DemandedElts, Depth);
1734+
unsigned Src1SignBits = computeNumSignBits(Src1, DemandedElts, Depth);
17351735
if (Src1SignBits == 1)
17361736
return 1;
1737-
return std::min(computeNumSignBitsImpl(Src0, DemandedElts, Depth),
1738-
Src1SignBits);
1737+
return std::min(computeNumSignBits(Src0, DemandedElts, Depth), Src1SignBits);
17391738
}
17401739

17411740
/// Compute the known number of sign bits with attached range metadata in the
@@ -1765,9 +1764,9 @@ static unsigned computeNumSignBitsFromRangeMetadata(const GAnyLoad *Ld,
17651764
CR.getSignedMax().getNumSignBits());
17661765
}
17671766

1768-
unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
1769-
const APInt &DemandedElts,
1770-
unsigned Depth) {
1767+
unsigned GISelValueTracking::computeNumSignBits(Register R,
1768+
const APInt &DemandedElts,
1769+
unsigned Depth) {
17711770
MachineInstr &MI = *MRI.getVRegDef(R);
17721771
unsigned Opcode = MI.getOpcode();
17731772

@@ -1797,7 +1796,7 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
17971796
if (Src.getReg().isVirtual() && Src.getSubReg() == 0 &&
17981797
MRI.getType(Src.getReg()).isValid()) {
17991798
// Don't increment Depth for this one since we didn't do any work.
1800-
return computeNumSignBitsImpl(Src.getReg(), DemandedElts, Depth);
1799+
return computeNumSignBits(Src.getReg(), DemandedElts, Depth);
18011800
}
18021801

18031802
return 1;
@@ -1806,15 +1805,15 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
18061805
Register Src = MI.getOperand(1).getReg();
18071806
LLT SrcTy = MRI.getType(Src);
18081807
unsigned Tmp = DstTy.getScalarSizeInBits() - SrcTy.getScalarSizeInBits();
1809-
return computeNumSignBitsImpl(Src, DemandedElts, Depth + 1) + Tmp;
1808+
return computeNumSignBits(Src, DemandedElts, Depth + 1) + Tmp;
18101809
}
18111810
case TargetOpcode::G_ASSERT_SEXT:
18121811
case TargetOpcode::G_SEXT_INREG: {
18131812
// Max of the input and what this extends.
18141813
Register Src = MI.getOperand(1).getReg();
18151814
unsigned SrcBits = MI.getOperand(2).getImm();
18161815
unsigned InRegBits = TyBits - SrcBits + 1;
1817-
return std::max(computeNumSignBitsImpl(Src, DemandedElts, Depth + 1),
1816+
return std::max(computeNumSignBits(Src, DemandedElts, Depth + 1),
18181817
InRegBits);
18191818
}
18201819
case TargetOpcode::G_LOAD: {
@@ -1859,19 +1858,19 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
18591858
case TargetOpcode::G_XOR: {
18601859
Register Src1 = MI.getOperand(1).getReg();
18611860
unsigned Src1NumSignBits =
1862-
computeNumSignBitsImpl(Src1, DemandedElts, Depth + 1);
1861+
computeNumSignBits(Src1, DemandedElts, Depth + 1);
18631862
if (Src1NumSignBits != 1) {
18641863
Register Src2 = MI.getOperand(2).getReg();
18651864
unsigned Src2NumSignBits =
1866-
computeNumSignBitsImpl(Src2, DemandedElts, Depth + 1);
1865+
computeNumSignBits(Src2, DemandedElts, Depth + 1);
18671866
FirstAnswer = std::min(Src1NumSignBits, Src2NumSignBits);
18681867
}
18691868
break;
18701869
}
18711870
case TargetOpcode::G_ASHR: {
18721871
Register Src1 = MI.getOperand(1).getReg();
18731872
Register Src2 = MI.getOperand(2).getReg();
1874-
FirstAnswer = computeNumSignBitsImpl(Src1, DemandedElts, Depth + 1);
1873+
FirstAnswer = computeNumSignBits(Src1, DemandedElts, Depth + 1);
18751874
if (auto C = getValidMinimumShiftAmount(Src2, DemandedElts, Depth + 1))
18761875
FirstAnswer = std::min<uint64_t>(FirstAnswer + *C, TyBits);
18771876
break;
@@ -1921,8 +1920,7 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
19211920
// Check if the sign bits of source go down as far as the truncated value.
19221921
unsigned DstTyBits = DstTy.getScalarSizeInBits();
19231922
unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
1924-
unsigned NumSrcSignBits =
1925-
computeNumSignBitsImpl(Src, DemandedElts, Depth + 1);
1923+
unsigned NumSrcSignBits = computeNumSignBits(Src, DemandedElts, Depth + 1);
19261924
if (NumSrcSignBits > (NumSrcBits - DstTyBits))
19271925
return NumSrcSignBits - (NumSrcBits - DstTyBits);
19281926
break;
@@ -1982,7 +1980,7 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
19821980
continue;
19831981

19841982
unsigned Tmp2 =
1985-
computeNumSignBitsImpl(MO.getReg(), SingleDemandedElt, Depth + 1);
1983+
computeNumSignBits(MO.getReg(), SingleDemandedElt, Depth + 1);
19861984
FirstAnswer = std::min(FirstAnswer, Tmp2);
19871985

19881986
// If we don't know any bits, early out.
@@ -2004,8 +2002,7 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
20042002
DemandedElts.extractBits(NumSubVectorElts, I * NumSubVectorElts);
20052003
if (!DemandedSub)
20062004
continue;
2007-
unsigned Tmp2 =
2008-
computeNumSignBitsImpl(MO.getReg(), DemandedSub, Depth + 1);
2005+
unsigned Tmp2 = computeNumSignBits(MO.getReg(), DemandedSub, Depth + 1);
20092006

20102007
FirstAnswer = std::min(FirstAnswer, Tmp2);
20112008

@@ -2026,22 +2023,21 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
20262023
return 1;
20272024

20282025
if (!!DemandedLHS)
2029-
FirstAnswer = computeNumSignBitsImpl(Src1, DemandedLHS, Depth + 1);
2026+
FirstAnswer = computeNumSignBits(Src1, DemandedLHS, Depth + 1);
20302027
// If we don't know anything, early out and try computeKnownBits fall-back.
20312028
if (FirstAnswer == 1)
20322029
break;
20332030
if (!!DemandedRHS) {
2034-
unsigned Tmp2 = computeNumSignBitsImpl(MI.getOperand(2).getReg(),
2035-
DemandedRHS, Depth + 1);
2031+
unsigned Tmp2 =
2032+
computeNumSignBits(MI.getOperand(2).getReg(), DemandedRHS, Depth + 1);
20362033
FirstAnswer = std::min(FirstAnswer, Tmp2);
20372034
}
20382035
break;
20392036
}
20402037
case TargetOpcode::G_SPLAT_VECTOR: {
20412038
// Check if the sign bits of source go down as far as the truncated value.
20422039
Register Src = MI.getOperand(1).getReg();
2043-
unsigned NumSrcSignBits =
2044-
computeNumSignBitsImpl(Src, APInt(1, 1), Depth + 1);
2040+
unsigned NumSrcSignBits = computeNumSignBits(Src, APInt(1, 1), Depth + 1);
20452041
unsigned NumSrcBits = MRI.getType(Src).getSizeInBits();
20462042
if (NumSrcSignBits > (NumSrcBits - TyBits))
20472043
return NumSrcSignBits - (NumSrcBits - TyBits);
@@ -2062,8 +2058,7 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
20622058

20632059
// Finally, if we can prove that the top bits of the result are 0's or 1's,
20642060
// use this information.
2065-
KnownBits Known;
2066-
computeKnownBitsImpl(R, Known, DemandedElts, Depth);
2061+
KnownBits Known = getKnownBits(R, DemandedElts, Depth);
20672062
APInt Mask;
20682063
if (Known.isNonNegative()) { // sign bit is 0
20692064
Mask = Known.Zero;
@@ -2080,15 +2075,6 @@ unsigned GISelValueTracking::computeNumSignBitsImpl(Register R,
20802075
return std::max(FirstAnswer, Mask.countl_one());
20812076
}
20822077

2083-
unsigned GISelValueTracking::computeNumSignBits(Register R,
2084-
const APInt &DemandedElts,
2085-
unsigned Depth) {
2086-
assert(ComputeKnownBitsCache.empty() && "Cache should be empty");
2087-
unsigned NumSignBits = computeNumSignBitsImpl(R, DemandedElts, Depth);
2088-
ComputeKnownBitsCache.clear();
2089-
return NumSignBits;
2090-
}
2091-
20922078
unsigned GISelValueTracking::computeNumSignBits(Register R, unsigned Depth) {
20932079
LLT Ty = MRI.getType(R);
20942080
APInt DemandedElts =
@@ -2139,8 +2125,7 @@ std::optional<ConstantRange> GISelValueTracking::getValidShiftAmountRange(
21392125

21402126
// Use computeKnownBits to find a hidden constant/knownbits (usually type
21412127
// legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
2142-
KnownBits KnownAmt;
2143-
computeKnownBitsImpl(R, KnownAmt, DemandedElts, Depth);
2128+
KnownBits KnownAmt = getKnownBits(R, DemandedElts, Depth);
21442129
if (KnownAmt.getMaxValue().ult(BitWidth))
21452130
return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false);
21462131

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6139,16 +6139,13 @@ unsigned AMDGPUTargetLowering::computeNumSignBitsForTargetInstr(
61396139
case AMDGPU::G_AMDGPU_SMED3:
61406140
case AMDGPU::G_AMDGPU_UMED3: {
61416141
auto [Dst, Src0, Src1, Src2] = MI->getFirst4Regs();
6142-
unsigned Tmp2 =
6143-
Analysis.computeNumSignBitsImpl(Src2, DemandedElts, Depth + 1);
6142+
unsigned Tmp2 = Analysis.computeNumSignBits(Src2, DemandedElts, Depth + 1);
61446143
if (Tmp2 == 1)
61456144
return 1;
6146-
unsigned Tmp1 =
6147-
Analysis.computeNumSignBitsImpl(Src1, DemandedElts, Depth + 1);
6145+
unsigned Tmp1 = Analysis.computeNumSignBits(Src1, DemandedElts, Depth + 1);
61486146
if (Tmp1 == 1)
61496147
return 1;
6150-
unsigned Tmp0 =
6151-
Analysis.computeNumSignBitsImpl(Src0, DemandedElts, Depth + 1);
6148+
unsigned Tmp0 = Analysis.computeNumSignBits(Src0, DemandedElts, Depth + 1);
61526149
if (Tmp0 == 1)
61536150
return 1;
61546151
return std::min({Tmp0, Tmp1, Tmp2});

0 commit comments

Comments
 (0)