Skip to content

Commit 5580d84

Browse files
authored
Merge branch 'main' into x86-bittest-multiload
2 parents 6e9448d + f6d6d2d commit 5580d84

21 files changed

+615
-91
lines changed

llvm/include/llvm/Support/BranchProbability.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ class BranchProbability {
9797
/// \return \c Num divided by \c this.
9898
LLVM_ABI uint64_t scaleByInverse(uint64_t Num) const;
9999

100+
/// Compute \c this raised to the power \p N, i.e. the product of \p N copies
/// of this probability. The result is one when \p N is 0.
101+
LLVM_ABI BranchProbability pow(unsigned N) const;
102+
100103
BranchProbability &operator+=(BranchProbability RHS) {
101104
assert(N != UnknownN && RHS.N != UnknownN &&
102105
"Unknown probability cannot participate in arithmetics.");

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,40 @@ LLVM_ABI bool setLoopEstimatedTripCount(
365365
Loop *L, unsigned EstimatedTripCount,
366366
std::optional<unsigned> EstimatedLoopInvocationWeight = std::nullopt);
367367

368+
/// Based on branch weight metadata, return either:
369+
/// - An unknown probability if the implementation is unable to handle the loop
370+
/// form of \p L (e.g., \p L must have a latch block that controls the loop
371+
/// exit).
372+
/// - The probability \c P that, at the end of any iteration, the latch of \p L
373+
/// will start another iteration such that `1 - P` is the probability of
374+
/// exiting the loop.
375+
LLVM_ABI BranchProbability getLoopProbability(Loop *L);
376+
377+
/// Set branch weight metadata for the latch of \p L to indicate that, at the
378+
/// end of any iteration, \p P and `1 - P` are the probabilities of starting
379+
/// another iteration and exiting the loop, respectively. Return false if the
380+
/// implementation is unable to handle the loop form of \p L (e.g., \p L must
381+
/// have a latch block that controls the loop exit). Otherwise, return true.
382+
LLVM_ABI bool setLoopProbability(Loop *L, BranchProbability P);
383+
384+
/// Based on branch weight metadata, return either:
385+
/// - An unknown probability if the implementation cannot extract the
386+
/// probability (e.g., \p B must have exactly two target labels, so it must be
387+
/// a conditional branch).
388+
/// - The probability \c P that control flows from \p B to its first target
389+
/// label such that `1 - P` is the probability of control flowing to its
390+
/// second target label, or vice-versa if \p ForFirstTarget is false.
391+
LLVM_ABI BranchProbability getBranchProbability(BranchInst *B,
                                                bool ForFirstTarget);
392+
393+
/// Set branch weight metadata for \p B to indicate that \p P and `1 - P` are
394+
/// the probabilities of control flowing to its first and second target labels,
395+
/// respectively, or vice-versa if \p ForFirstTarget is false. Return false if
396+
/// the implementation cannot set the probability (e.g., \p B must have exactly
397+
/// two target labels, so it must be a conditional branch). Otherwise, return
398+
/// true.
399+
LLVM_ABI bool setBranchProbability(BranchInst *B, BranchProbability P,
400+
bool ForFirstTarget);
401+
368402
/// Check inner loop (L) backedge count is known to be invariant on all
369403
/// iterations of its outer loop. If the loop has no parent, this is trivially
370404
/// true.

llvm/include/llvm/Transforms/Utils/UnrollLoop.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ LLVM_ABI bool UnrollRuntimeLoopRemainder(
9797
LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
9898
const TargetTransformInfo *TTI, bool PreserveLCSSA,
9999
unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
100-
Loop **ResultLoop = nullptr);
100+
Loop **ResultLoop = nullptr,
101+
std::optional<unsigned> OriginalTripCount = std::nullopt,
102+
BranchProbability OriginalLoopProb = BranchProbability::getUnknown());
101103

102104
LLVM_ABI LoopUnrollResult UnrollAndJamLoop(
103105
Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,

llvm/lib/Support/BranchProbability.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,10 @@ uint64_t BranchProbability::scale(uint64_t Num) const {
111111
uint64_t BranchProbability::scaleByInverse(uint64_t Num) const {
112112
return ::scale<0>(Num, D, N);
113113
}
114+
115+
// Raise this probability to the N-th power by repeated multiplication.
// The accumulator starts at BranchProbability::getOne(), so N == 0 returns
// one; otherwise the loop applies operator*= exactly N times.
// NOTE(review): the parameter N hides the class member of the same name
// inside this function; the body only reads the parameter. Confirm the
// shadowing is intentional.
BranchProbability BranchProbability::pow(unsigned N) const {
116+
BranchProbability Res = BranchProbability::getOne();
117+
for (unsigned I = 0; I < N; ++I)
118+
Res *= *this;
119+
return Res;
120+
}

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ std::optional<unsigned> getFoldedOpcode(MachineFunction &MF, MachineInstr &MI,
869869
}
870870
}
871871

872-
// This is the version used during inline spilling
872+
// This is the version used during InlineSpiller::spillAroundUses
873873
MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl(
874874
MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned> Ops,
875875
MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS,

llvm/lib/Target/RISCV/RISCVInstrInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ def BGE : BranchCC_rri<0b101, "bge">;
768768
def BLTU : BranchCC_rri<0b110, "bltu">;
769769
def BGEU : BranchCC_rri<0b111, "bgeu">;
770770

771-
let IsSignExtendingOpW = 1 in {
771+
let IsSignExtendingOpW = 1, canFoldAsLoad = 1 in {
772772
def LB : Load_ri<0b000, "lb">, Sched<[WriteLDB, ReadMemBase]>;
773773
def LH : Load_ri<0b001, "lh">, Sched<[WriteLDH, ReadMemBase]>;
774774
def LW : Load_ri<0b010, "lw">, Sched<[WriteLDW, ReadMemBase]>;
@@ -889,8 +889,10 @@ def CSRRCI : CSR_ii<0b111, "csrrci">;
889889
/// RV64I instructions
890890

891891
let Predicates = [IsRV64] in {
892+
let canFoldAsLoad = 1 in {
892893
def LWU : Load_ri<0b110, "lwu">, Sched<[WriteLDW, ReadMemBase]>;
893894
def LD : Load_ri<0b011, "ld">, Sched<[WriteLDD, ReadMemBase]>;
895+
}
894896
def SD : Store_rri<0b011, "sd">, Sched<[WriteSTD, ReadStoreData, ReadMemBase]>;
895897

896898
let IsSignExtendingOpW = 1 in {

llvm/lib/Target/RISCV/RISCVInstrInfoD.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ defvar DExtsRV64 = [DExt, ZdinxExt];
7171
//===----------------------------------------------------------------------===//
7272

7373
let Predicates = [HasStdExtD] in {
74+
let canFoldAsLoad = 1 in
7475
def FLD : FPLoad_r<0b011, "fld", FPR64, WriteFLD64>;
7576

7677
// Operands for stores are in the order srcreg, base, offset rather than

llvm/lib/Target/RISCV/RISCVInstrInfoF.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ class PseudoFROUND<DAGOperand Ty, ValueType vt, ValueType intvt = XLenVT>
330330
//===----------------------------------------------------------------------===//
331331

332332
let Predicates = [HasStdExtF] in {
333+
let canFoldAsLoad = 1 in
333334
def FLW : FPLoad_r<0b010, "flw", FPR32, WriteFLD32>;
334335

335336
// Operands for stores are in the order srcreg, base, offset rather than

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3151,6 +3151,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
31513151
return selectInsertElt(ResVReg, ResType, I);
31523152
case Intrinsic::spv_gep:
31533153
return selectGEP(ResVReg, ResType, I);
3154+
case Intrinsic::spv_bitcast: {
3155+
Register OpReg = I.getOperand(2).getReg();
3156+
SPIRVType *OpType =
3157+
OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
3158+
if (!GR.isBitcastCompatible(ResType, OpType))
3159+
report_fatal_error("incompatible result and operand types in a bitcast");
3160+
return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast);
3161+
}
31543162
case Intrinsic::spv_unref_global:
31553163
case Intrinsic::spv_init_global: {
31563164
MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg());

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -192,31 +192,43 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
192192
.addUse(OpReg);
193193
}
194194

195-
// We do instruction selections early instead of calling MIB.buildBitcast()
196-
// generating the general op code G_BITCAST. When MachineVerifier validates
197-
// G_BITCAST we see a check of a kind: if Source Type is equal to Destination
198-
// Type then report error "bitcast must change the type". This doesn't take into
199-
// account the notion of a typed pointer that is important for SPIR-V where a
200-
// user may and should use bitcast between pointers with different pointee types
201-
// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast).
202-
// It's important for correct lowering in SPIR-V, because interpretation of the
203-
// data type is not left to instructions that utilize the pointer, but encoded
204-
// by the pointer declaration, and the SPIRV target can and must handle the
205-
// declaration and use of pointers that specify the type of data they point to.
206-
// It's not feasible to improve validation of G_BITCAST using just information
207-
// provided by low level types of source and destination. Therefore we don't
208-
// produce G_BITCAST as the general op code with semantics different from
209-
// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only
210-
// difference would be that CombinerHelper couldn't transform known patterns
211-
// around G_BUILD_VECTOR. See discussion
212-
// in https://github.com/llvm/llvm-project/pull/110270 for even more context.
213-
static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
214-
MachineIRBuilder MIB) {
195+
// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error.
196+
// The verifier checks if the source and destination LLTs of a G_BITCAST are
197+
// different, but this check is too strict for SPIR-V's typed pointers, which
198+
// may have the same LLT but different SPIRVType (e.g. pointers to different
199+
// pointee types). By lowering to OpBitcast here, we bypass the verifier's
200+
// check. See discussion in https://github.com/llvm/llvm-project/pull/110270
201+
// for more context.
202+
//
203+
// We also handle the llvm.spv.bitcast intrinsic here. If the source and
204+
// destination SPIR-V types are the same, we lower it to a COPY to enable
205+
// further optimizations like copy propagation.
206+
static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
207+
MachineIRBuilder MIB) {
215208
SmallVector<MachineInstr *, 16> ToErase;
216209
for (MachineBasicBlock &MBB : MF) {
217210
for (MachineInstr &MI : MBB) {
211+
if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
212+
Register DstReg = MI.getOperand(0).getReg();
213+
Register SrcReg = MI.getOperand(2).getReg();
214+
SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg);
215+
assert(
216+
DstType &&
217+
"Expected destination SPIR-V type to have been assigned already.");
218+
SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg);
219+
assert(SrcType &&
220+
"Expected source SPIR-V type to have been assigned already.");
221+
if (DstType == SrcType) {
222+
MIB.setInsertPt(*MI.getParent(), MI);
223+
MIB.buildCopy(DstReg, SrcReg);
224+
ToErase.push_back(&MI);
225+
continue;
226+
}
227+
}
228+
218229
if (MI.getOpcode() != TargetOpcode::G_BITCAST)
219230
continue;
231+
220232
MIB.setInsertPt(*MI.getParent(), MI);
221233
buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
222234
MI.getOperand(1).getReg());
@@ -237,16 +249,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
237249
SmallVector<MachineInstr *, 10> ToErase;
238250
for (MachineBasicBlock &MBB : MF) {
239251
for (MachineInstr &MI : MBB) {
240-
if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
241-
!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
252+
if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
242253
continue;
243254
assert(MI.getOperand(2).isReg());
244255
MIB.setInsertPt(*MI.getParent(), MI);
245256
ToErase.push_back(&MI);
246-
if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
247-
MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
248-
continue;
249-
}
250257
Register Def = MI.getOperand(0).getReg();
251258
Register Source = MI.getOperand(2).getReg();
252259
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
@@ -1089,7 +1096,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
10891096
removeImplicitFallthroughs(MF, MIB);
10901097
insertSpirvDecorations(MF, GR, MIB);
10911098
insertInlineAsm(MF, GR, ST, MIB);
1092-
selectOpBitcasts(MF, GR, MIB);
1099+
lowerBitcasts(MF, GR, MIB);
10931100

10941101
return true;
10951102
}

0 commit comments

Comments
 (0)