From 6bfe29db340572ad3b67c6e6deaea6b237be36c9 Mon Sep 17 00:00:00 2001 From: Mircea Trofin Date: Thu, 25 Sep 2025 02:18:30 +0000 Subject: [PATCH] [SimplifyCFG][profcheck] Fix artificially-failing `preserve-branchweights.ll` --- llvm/include/llvm/IR/Instructions.h | 7 +- llvm/include/llvm/IR/ProfDataUtils.h | 8 +- llvm/lib/IR/Instructions.cpp | 17 --- llvm/lib/IR/ProfDataUtils.cpp | 38 ++++- llvm/lib/Transforms/IPO/SampleProfile.cpp | 8 +- .../Instrumentation/IndirectCallPromotion.cpp | 4 +- llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 134 ++++++------------ llvm/utils/profcheck-xfail.txt | 2 - 8 files changed, 100 insertions(+), 118 deletions(-) diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index 95a0a7fd2f97e..de7a237098594 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -32,6 +32,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/OperandTraits.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/Support/AtomicOrdering.h" @@ -3536,8 +3537,6 @@ class SwitchInstProfUpdateWrapper { bool Changed = false; protected: - LLVM_ABI MDNode *buildProfBranchWeightsMD(); - LLVM_ABI void init(); public: @@ -3549,8 +3548,8 @@ class SwitchInstProfUpdateWrapper { SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); } ~SwitchInstProfUpdateWrapper() { - if (Changed) - SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD()); + if (Changed && Weights.has_value() && Weights->size() >= 2) + setBranchWeights(SI, Weights.value(), /*IsExpected=*/false); } /// Delegate the call to the underlying SwitchInst::removeCase() and remove diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index e97160e59c795..a0876b169e0b8 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -145,7 +145,13 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I, /// \param Weights an array of weights to set on instruction I. /// \param IsExpected were these weights added from an llvm.expect* intrinsic. LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef Weights, - bool IsExpected); + bool IsExpected, bool ElideAllZero = false); + +/// Variant of `setBranchWeights` where the `Weights` will be fit first to +/// uint32_t by shifting right. +LLVM_ABI void setFittedBranchWeights(Instruction &I, ArrayRef Weights, + bool IsExpected, + bool ElideAllZero = false); /// downscale the given weights preserving the ratio. If the maximum value is /// not already known and not provided via \param KnownMaxCount , it will be diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index dd83168ab3c6e..941e41f3127d5 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -4141,23 +4141,6 @@ void SwitchInst::growOperands() { growHungoffUses(ReservedSpace); } -MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { - assert(Changed && "called only if metadata has changed"); - - if (!Weights) - return nullptr; - - assert(SI.getNumSuccessors() == Weights->size() && - "num of prof branch_weights must accord with num of successors"); - - bool AllZeroes = all_of(*Weights, [](uint32_t W) { return W == 0; }); - - if (AllZeroes || Weights->size() < 2) - return nullptr; - - return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights); -} - void SwitchInstProfUpdateWrapper::init() { MDNode *ProfileData = getBranchWeightMDNode(SI); if (!ProfileData) diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index 99029c1719507..edeca976d293e 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -12,6 +12,7 @@ #include "llvm/IR/ProfDataUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -19,6 +20,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/Support/CommandLine.h" using namespace llvm; @@ -84,10 +86,31 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData, } } +/// Push the weights right to fit in uint32_t. +static SmallVector fitWeights(ArrayRef Weights) { + SmallVector Ret; + Ret.reserve(Weights.size()); + uint64_t Max = *llvm::max_element(Weights); + if (Max > UINT_MAX) { + unsigned Offset = 32 - llvm::countl_zero(Max); + for (const uint64_t &Value : Weights) + Ret.push_back(static_cast(Value >> Offset)); + } else { + append_range(Ret, Weights); + } + return Ret; +} + } // namespace namespace llvm { - +cl::opt ElideAllZeroBranchWeights("elide-all-zero-branch-weights", +#if defined(LLVM_ENABLE_PROFCHECK) + cl::init(false) +#else + cl::init(true) +#endif +); const char *MDProfLabels::BranchWeights = "branch_weights"; const char *MDProfLabels::ExpectedBranchWeights = "expected"; const char *MDProfLabels::ValueProfile = "VP"; @@ -282,12 +305,23 @@ bool hasExplicitlyUnknownBranchWeights(const Instruction &I) { } void setBranchWeights(Instruction &I, ArrayRef Weights, - bool IsExpected) { + bool IsExpected, bool ElideAllZero) { + if ((ElideAllZeroBranchWeights && ElideAllZero) && + llvm::all_of(Weights, [](uint32_t V) { return V == 0; })) { + I.setMetadata(LLVMContext::MD_prof, nullptr); + return; + } + MDBuilder MDB(I.getContext()); MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); I.setMetadata(LLVMContext::MD_prof, BranchWeights); } +void setFittedBranchWeights(Instruction &I, ArrayRef Weights, + bool IsExpected, bool ElideAllZero) { + setBranchWeights(I, fitWeights(Weights), IsExpected, ElideAllZero); +} + SmallVector downscaleWeights(ArrayRef Weights, std::optional KnownMaxCount) { uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value() diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp index 5bc7e34938127..99b8b88ebedbb 100644 --- a/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -1664,8 +1664,9 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { else if (OverwriteExistingWeights) I.setMetadata(LLVMContext::MD_prof, nullptr); } else if (!isa(&I)) { - setBranchWeights(I, {static_cast(BlockWeights[BB])}, - /*IsExpected=*/false); + setBranchWeights( + I, ArrayRef{static_cast(BlockWeights[BB])}, + /*IsExpected=*/false); } } } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) { @@ -1676,7 +1677,8 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { if (cast(I).isIndirectCall()) { I.setMetadata(LLVMContext::MD_prof, nullptr); } else { - setBranchWeights(I, {uint32_t(0)}, /*IsExpected=*/false); + setBranchWeights(I, ArrayRef{uint32_t(0)}, + /*IsExpected=*/false); } } } diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index f451c2b471aa6..0249f210f4754 100644 --- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -672,8 +672,8 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee, createBranchWeights(CB.getContext(), Count, TotalCount - Count)); if (AttachProfToDirectCall) - setBranchWeights(NewInst, {static_cast(Count)}, - /*IsExpected=*/false); + setFittedBranchWeights(NewInst, {Count}, + /*IsExpected=*/false); using namespace ore; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 216bdf4eb9efb..4d1f768e2177a 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -955,33 +955,6 @@ static bool valuesOverlap(std::vector &C1, return false; } -// Set branch weights on SwitchInst. This sets the metadata if there is at -// least one non-zero weight. -static void setBranchWeights(SwitchInst *SI, ArrayRef Weights, - bool IsExpected) { - // Check that there is at least one non-zero weight. Otherwise, pass - // nullptr to setMetadata which will erase the existing metadata. - MDNode *N = nullptr; - if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; })) - N = MDBuilder(SI->getParent()->getContext()) - .createBranchWeights(Weights, IsExpected); - SI->setMetadata(LLVMContext::MD_prof, N); -} - -// Similar to the above, but for branch and select instructions that take -// exactly 2 weights. -static void setBranchWeights(Instruction *I, uint32_t TrueWeight, - uint32_t FalseWeight, bool IsExpected) { - assert(isa(I) || isa(I)); - // Check that there is at least one non-zero weight. Otherwise, pass - // nullptr to setMetadata which will erase the existing metadata. - MDNode *N = nullptr; - if (TrueWeight || FalseWeight) - N = MDBuilder(I->getParent()->getContext()) - .createBranchWeights(TrueWeight, FalseWeight, IsExpected); - I->setMetadata(LLVMContext::MD_prof, N); -} - /// If TI is known to be a terminator instruction and its block is known to /// only have a single predecessor block, check to see if that predecessor is /// also a value comparison with the same value, and if that comparison @@ -1181,16 +1154,6 @@ static void getBranchWeights(Instruction *TI, } } -/// Keep halving the weights until all can fit in uint32_t. -static void fitWeights(MutableArrayRef Weights) { - uint64_t Max = *llvm::max_element(Weights); - if (Max > UINT_MAX) { - unsigned Offset = 32 - llvm::countl_zero(Max); - for (uint64_t &I : Weights) - I >>= Offset; - } -} - static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses( BasicBlock *BB, BasicBlock *PredBlock, ValueToValueMapTy &VMap) { Instruction *PTI = PredBlock->getTerminator(); @@ -1446,14 +1409,9 @@ bool SimplifyCFGOpt::performValueComparisonIntoPredecessorFolding( for (ValueEqualityComparisonCase &V : PredCases) NewSI->addCase(V.Value, V.Dest); - if (PredHasWeights || SuccHasWeights) { - // Halve the weights if any of them cannot fit in an uint32_t - fitWeights(Weights); - - SmallVector MDWeights(Weights.begin(), Weights.end()); - - setBranchWeights(NewSI, MDWeights, /*IsExpected=*/false); - } + if (PredHasWeights || SuccHasWeights) + setFittedBranchWeights(*NewSI, Weights, /*IsExpected=*/false, + /*ElideAllZero=*/true); eraseTerminatorAndDCECond(PTI); @@ -4053,39 +4011,34 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, // Try to update branch weights. uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight; - SmallVector MDWeights; + SmallVector MDWeights; if (extractPredSuccWeights(PBI, BI, PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight)) { - SmallVector NewWeights; if (PBI->getSuccessor(0) == BB) { // PBI: br i1 %x, BB, FalseDest // BI: br i1 %y, UniqueSucc, FalseDest // TrueWeight is TrueWeight for PBI * TrueWeight for BI. - NewWeights.push_back(PredTrueWeight * SuccTrueWeight); + MDWeights.push_back(PredTrueWeight * SuccTrueWeight); // FalseWeight is FalseWeight for PBI * TotalWeight for BI + // TrueWeight for PBI * FalseWeight for BI. // We assume that total weights of a BranchInst can fit into 32 bits. // Therefore, we will not have overflow using 64-bit arithmetic. - NewWeights.push_back(PredFalseWeight * - (SuccFalseWeight + SuccTrueWeight) + - PredTrueWeight * SuccFalseWeight); + MDWeights.push_back(PredFalseWeight * (SuccFalseWeight + SuccTrueWeight) + + PredTrueWeight * SuccFalseWeight); } else { // PBI: br i1 %x, TrueDest, BB // BI: br i1 %y, TrueDest, UniqueSucc // TrueWeight is TrueWeight for PBI * TotalWeight for BI + // FalseWeight for PBI * TrueWeight for BI. - NewWeights.push_back(PredTrueWeight * (SuccFalseWeight + SuccTrueWeight) + - PredFalseWeight * SuccTrueWeight); + MDWeights.push_back(PredTrueWeight * (SuccFalseWeight + SuccTrueWeight) + + PredFalseWeight * SuccTrueWeight); // FalseWeight is FalseWeight for PBI * FalseWeight for BI. - NewWeights.push_back(PredFalseWeight * SuccFalseWeight); + MDWeights.push_back(PredFalseWeight * SuccFalseWeight); } - // Halve the weights if any of them cannot fit in an uint32_t - fitWeights(NewWeights); - - append_range(MDWeights, NewWeights); - setBranchWeights(PBI, MDWeights[0], MDWeights[1], /*IsExpected=*/false); + setFittedBranchWeights(*PBI, MDWeights, /*IsExpected=*/false, + /*ElideAllZero=*/true); // TODO: If BB is reachable from all paths through PredBlock, then we // could replace PBI's branch probabilities with BI's. @@ -4125,8 +4078,8 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, if (auto *SI = dyn_cast(PBI->getCondition())) if (!MDWeights.empty()) { assert(isSelectInRoleOfConjunctionOrDisjunction(SI)); - setBranchWeights(SI, MDWeights[0], MDWeights[1], - /*IsExpected=*/false); + setFittedBranchWeights(*SI, {MDWeights[0], MDWeights[1]}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } ++NumFoldBranchToCommonDest; @@ -4478,9 +4431,9 @@ static bool mergeConditionalStoreToAddress( if (InvertQCond) std::swap(QWeights[0], QWeights[1]); auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights); - setBranchWeights(PostBB->getTerminator(), CombinedWeights[0], - CombinedWeights[1], - /*IsExpected=*/false); + setFittedBranchWeights(*PostBB->getTerminator(), + {CombinedWeights[0], CombinedWeights[1]}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } QB.SetInsertPoint(T); @@ -4836,10 +4789,9 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, uint64_t NewWeights[2] = {PredCommon * (SuccCommon + SuccOther) + PredOther * SuccCommon, PredOther * SuccOther}; - // Halve the weights if any of them cannot fit in an uint32_t - fitWeights(NewWeights); - setBranchWeights(PBI, NewWeights[0], NewWeights[1], /*IsExpected=*/false); + setFittedBranchWeights(*PBI, NewWeights, /*IsExpected=*/false, + /*ElideAllZero=*/true); // Cond may be a select instruction with the first operand set to "true", or // the second to "false" (see how createLogicalOp works for `and` and `or`) if (!ProfcheckDisableMetadataFixes) @@ -4849,8 +4801,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, assert(dyn_cast(SI)->getCondition() == PBICond); // The corresponding probabilities are what was referred to above as // PredCommon and PredOther. - setBranchWeights(SI, PredCommon, PredOther, - /*IsExpected=*/false); + setFittedBranchWeights(*SI, {PredCommon, PredOther}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } @@ -4876,8 +4828,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, if (HasWeights) { uint64_t TrueWeight = PBIOp ? PredFalseWeight : PredTrueWeight; uint64_t FalseWeight = PBIOp ? PredTrueWeight : PredFalseWeight; - setBranchWeights(NV, TrueWeight, FalseWeight, - /*IsExpected=*/false); + setFittedBranchWeights(*NV, {TrueWeight, FalseWeight}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } } @@ -4940,7 +4892,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm, // Create a conditional branch sharing the condition of the select. BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB); if (TrueWeight != FalseWeight) - setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false); + setBranchWeights(*NewBI, {TrueWeight, FalseWeight}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) { // Neither of the selected blocks were successors, so this @@ -5889,7 +5842,8 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, TrueWeight /= 2; FalseWeight /= 2; } - setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false); + setFittedBranchWeights(*NewBI, {TrueWeight, FalseWeight}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } @@ -6364,9 +6318,9 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, // BranchWeights. We want the probability and negative probability of // Condition == SecondCase. assert(BranchWeights.size() == 3); - setBranchWeights(SI, BranchWeights[2], - BranchWeights[0] + BranchWeights[1], - /*IsExpected=*/false); + setBranchWeights( + *SI, {BranchWeights[2], BranchWeights[0] + BranchWeights[1]}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } Value *ValueCompare = @@ -6381,9 +6335,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, size_t FirstCasePos = (Condition != nullptr); size_t SecondCasePos = FirstCasePos + 1; uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0; - setBranchWeights(SI, BranchWeights[FirstCasePos], - DefaultCase + BranchWeights[SecondCasePos], - /*IsExpected=*/false); + setBranchWeights(*SI, + {BranchWeights[FirstCasePos], + DefaultCase + BranchWeights[SecondCasePos]}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } return Ret; } @@ -6427,8 +6382,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, // We know there's a Default case. We base the resulting branch // weights off its probability. assert(BranchWeights.size() >= 2); - setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0), - BranchWeights[0], /*IsExpected=*/false); + setBranchWeights( + *SI, + {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } return Ret; } @@ -6451,8 +6408,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); if (auto *SI = dyn_cast(Ret); SI && HasBranchWeights) { assert(BranchWeights.size() >= 2); - setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0), - BranchWeights[0], /*IsExpected=*/false); + setBranchWeights( + *SI, + {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } return Ret; } @@ -6469,8 +6428,9 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); if (auto *SI = dyn_cast(Ret); SI && HasBranchWeights) { assert(BranchWeights.size() >= 2); - setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0), - BranchWeights[0], /*IsExpected=*/false); + setBranchWeights( + *SI, {accumulate(drop_begin(BranchWeights), 0U), BranchWeights[0]}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } return Ret; } @@ -8152,8 +8112,8 @@ static bool mergeNestedCondBranch(BranchInst *BI, DomTreeUpdater *DTU) { if (HasWeight) { uint64_t Weights[2] = {BBTWeight * BB1FWeight + BBFWeight * BB2TWeight, BBTWeight * BB1TWeight + BBFWeight * BB2FWeight}; - fitWeights(Weights); - setBranchWeights(BI, Weights[0], Weights[1], /*IsExpected=*/false); + setFittedBranchWeights(*BI, Weights, /*IsExpected=*/false, + /*ElideAllZero=*/true); } return true; } diff --git a/llvm/utils/profcheck-xfail.txt b/llvm/utils/profcheck-xfail.txt index 08c89441ec855..77e6ab7c5a6ea 100644 --- a/llvm/utils/profcheck-xfail.txt +++ b/llvm/utils/profcheck-xfail.txt @@ -1414,7 +1414,6 @@ Transforms/SimplifyCFG/merge-cond-stores.ll Transforms/SimplifyCFG/multiple-phis.ll Transforms/SimplifyCFG/PhiBlockMerge.ll Transforms/SimplifyCFG/pr48641.ll -Transforms/SimplifyCFG/preserve-branchweights.ll Transforms/SimplifyCFG/preserve-store-alignment.ll Transforms/SimplifyCFG/rangereduce.ll Transforms/SimplifyCFG/RISCV/select-trunc-i64.ll @@ -1424,7 +1423,6 @@ Transforms/SimplifyCFG/safe-abs.ll Transforms/SimplifyCFG/SimplifyEqualityComparisonWithOnlyPredecessor-domtree-preservation-edgecase.ll Transforms/SimplifyCFG/speculate-blocks.ll Transforms/SimplifyCFG/speculate-derefable-load.ll -Transforms/SimplifyCFG/suppress-zero-branch-weights.ll Transforms/SimplifyCFG/switch_create-custom-dl.ll Transforms/SimplifyCFG/switch_create.ll Transforms/SimplifyCFG/switch-dup-bbs.ll