diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index ca56e4aa81575..ad888ce744a93 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -127,6 +127,12 @@ LLVM_ABI bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, LLVM_ABI bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights); +/// visit each element of the value profile, calling a function where the first +/// argument will be the key and the next on the value. +LLVM_ABI void visitValueProfile( + const MDNode &ProfData, + llvm::function_ref Visitor); + /// Retrieve the total of all weights from an instruction. /// /// \param I The instruction to extract the total weight from diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index b1b5f67689e6d..1ab281cec0795 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -12,6 +12,7 @@ #include "llvm/IR/ProfDataUtils.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -270,6 +271,17 @@ void setBranchWeights(Instruction &I, ArrayRef Weights, I.setMetadata(LLVMContext::MD_prof, BranchWeights); } +void visitValueProfile( + const MDNode &ProfData, + llvm::function_ref Visitor) { + assert(isValueProfileMD(&ProfData) && + "Expected valid Value Profile Metadata"); + for (unsigned Idx = 1; Idx < ProfData.getNumOperands(); Idx += 2) { + if (!Visitor(ProfData.getOperand(Idx), ProfData.getOperand(Idx + 1))) + break; + } +} + void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { assert(T != 0 && "Caller should guarantee"); auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); @@ -303,24 +315,25 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { Vals.push_back(MDB.createConstant(ConstantInt::get( Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); } else if (ProfDataName->getString() == MDProfLabels::ValueProfile) - for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) { + visitValueProfile(*ProfileData, [&](const MDOperand &Key, + const MDOperand &Value) { // The first value is the key of the value profile, which will not change. - Vals.push_back(ProfileData->getOperand(Idx)); + Vals.push_back(Key); uint64_t Count = - mdconst::dyn_extract(ProfileData->getOperand(Idx + 1)) - ->getValue() - .getZExtValue(); + mdconst::dyn_extract(Value)->getValue().getZExtValue(); // Don't scale the magic number. if (Count == NOMORE_ICP_MAGICNUM) { - Vals.push_back(ProfileData->getOperand(Idx + 1)); - continue; + Vals.push_back(Value); + return true; } // Using APInt::div may be expensive, but most cases should fit 64 bits. APInt Val(128, Count); Val *= APS; Vals.push_back(MDB.createConstant(ConstantInt::get( Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); - } + return true; + }); + I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); }