110110#include " llvm/Transforms/Instrumentation.h"
111111#include " llvm/Transforms/Instrumentation/BlockCoverageInference.h"
112112#include " llvm/Transforms/Instrumentation/CFGMST.h"
113+ #include " llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
113114#include " llvm/Transforms/Utils/BasicBlockUtils.h"
114115#include " llvm/Transforms/Utils/MisExpect.h"
115116#include " llvm/Transforms/Utils/ModuleUtils.h"
@@ -333,6 +334,20 @@ extern cl::opt<bool> EnableVTableValueProfiling;
333334extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate;
334335} // namespace llvm
335336
337+ bool shouldInstrumentEntryBB () {
338+ return PGOInstrumentEntry ||
339+ PGOCtxProfLoweringPass::isContextualIRPGOEnabled ();
340+ }
341+
342+ // FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls. Ctx
343+ // profiling implicitly captures indirect call cases, but not other values.
344+ // Supporting other values is relatively straight-forward - just another counter
345+ // range within the context.
346+ bool isValueProfilingDisabled () {
347+ return DisableValueProfiling ||
348+ PGOCtxProfLoweringPass::isContextualIRPGOEnabled ();
349+ }
350+
336351// Return a string describing the branch condition that can be
337352// used in static branch probability heuristics:
338353static std::string getBranchCondString (Instruction *TI) {
@@ -379,7 +394,7 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
379394 uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
380395 if (IsCS)
381396 ProfileVersion |= VARIANT_MASK_CSIR_PROF;
382- if (PGOInstrumentEntry )
397+ if (shouldInstrumentEntryBB () )
383398 ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;
384399 if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)
385400 ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;
@@ -861,7 +876,7 @@ static void instrumentOneFunc(
861876 }
862877
863878 FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo (
864- F, TLI, ComdatMembers, true , BPI, BFI, IsCS, PGOInstrumentEntry ,
879+ F, TLI, ComdatMembers, true , BPI, BFI, IsCS, shouldInstrumentEntryBB () ,
865880 PGOBlockCoverage);
866881
867882 auto Name = FuncInfo.FuncNameVar ;
@@ -883,6 +898,43 @@ static void instrumentOneFunc(
883898 unsigned NumCounters =
884899 InstrumentBBs.size () + FuncInfo.SIVisitor .getNumOfSelectInsts ();
885900
901+ if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled ()) {
902+ auto *CSIntrinsic =
903+ Intrinsic::getDeclaration (M, Intrinsic::instrprof_callsite);
904+ // We want to count the instrumentable callsites, then instrument them. This
905+ // is because the llvm.instrprof.callsite intrinsic has an argument (like
906+ // the other instrprof intrinsics) capturing the total number of
907+ // instrumented objects (counters, or callsites, in this case). In this
908+ // case, we want that value so we can readily pass it to the compiler-rt
909+ // APIs that may have to allocate memory based on the nr of callsites.
910+ // The traversal logic is the same for both counting and instrumentation,
911+ // just needs to be done in succession.
912+ auto Visit = [&](llvm::function_ref<void (CallBase * CB)> Visitor) {
913+ for (auto &BB : F)
914+ for (auto &Instr : BB)
915+ if (auto *CS = dyn_cast<CallBase>(&Instr)) {
916+ if ((CS->getCalledFunction () &&
917+ CS->getCalledFunction ()->isIntrinsic ()) ||
918+ dyn_cast<InlineAsm>(CS->getCalledOperand ()))
919+ continue ;
920+ Visitor (CS);
921+ }
922+ };
923+ // First, count callsites.
924+ uint32_t TotalNrCallsites = 0 ;
925+ Visit ([&TotalNrCallsites](auto *) { ++TotalNrCallsites; });
926+
927+ // Now instrument.
928+ uint32_t CallsiteIndex = 0 ;
929+ Visit ([&](auto *CB) {
930+ IRBuilder<> Builder (CB);
931+ Builder.CreateCall (CSIntrinsic,
932+ {Name, CFGHash, Builder.getInt32 (TotalNrCallsites),
933+ Builder.getInt32 (CallsiteIndex++),
934+ CB->getCalledOperand ()});
935+ });
936+ }
937+
886938 uint32_t I = 0 ;
887939 if (PGOTemporalInstrumentation) {
888940 NumCounters += PGOBlockCoverage ? 8 : 1 ;
@@ -914,7 +966,7 @@ static void instrumentOneFunc(
914966 FuncInfo.FunctionHash );
915967 assert (I == NumCounters);
916968
917- if (DisableValueProfiling )
969+ if (isValueProfilingDisabled () )
918970 return ;
919971
920972 NumOfPGOICall += FuncInfo.ValueSites [IPVK_IndirectCallTarget].size ();
@@ -1676,7 +1728,7 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
16761728
16771729// Traverse all valuesites and annotate the instructions for all value kind.
16781730void PGOUseFunc::annotateValueSites () {
1679- if (DisableValueProfiling )
1731+ if (isValueProfilingDisabled () )
16801732 return ;
16811733
16821734 // Create the PGOFuncName meta data.
@@ -1779,7 +1831,7 @@ static bool InstrumentAllFunctions(
17791831 function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
17801832 // For the context-sensitve instrumentation, we should have a separated pass
17811833 // (before LTO/ThinLTO linking) to create these variables.
1782- if (!IsCS)
1834+ if (!IsCS && ! PGOCtxProfLoweringPass::isContextualIRPGOEnabled () )
17831835 createIRLevelProfileFlagVar (M, /* IsCS=*/ false );
17841836
17851837 Triple TT (M.getTargetTriple ());
@@ -2018,6 +2070,8 @@ static bool annotateAllFunctions(
20182070 bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled ();
20192071 if (PGOInstrumentEntry.getNumOccurrences () > 0 )
20202072 InstrumentFuncEntry = PGOInstrumentEntry;
2073+ InstrumentFuncEntry |= PGOCtxProfLoweringPass::isContextualIRPGOEnabled ();
2074+
20212075 bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage ();
20222076 for (auto &F : M) {
20232077 if (skipPGOUse (F))
0 commit comments