From 0b86cf2dc28c3a04af4de7678842a582fb0358d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolai=20H=C3=A4hnle?= Date: Sun, 2 Feb 2025 10:05:06 +0100 Subject: [PATCH] CFGPrinter: fix accidentally quadratic behavior Initialize a ModuleStateTracker at most once per BasicBlock instead of once per Instruction. When the CFG info is provided, it is initialized once per function. --- llvm/include/llvm/Analysis/CFGPrinter.h | 31 ++++++--------- llvm/lib/Analysis/CFGPrinter.cpp | 52 +++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/llvm/include/llvm/Analysis/CFGPrinter.h b/llvm/include/llvm/Analysis/CFGPrinter.h index cd785331d1f14..b844e3f11c4a5 100644 --- a/llvm/include/llvm/Analysis/CFGPrinter.h +++ b/llvm/include/llvm/Analysis/CFGPrinter.h @@ -31,6 +31,8 @@ #include "llvm/Support/FormatVariadic.h" namespace llvm { +class ModuleSlotTracker; + template struct GraphTraits; class CFGViewerPass : public PassInfoMixin { public: @@ -61,6 +63,7 @@ class DOTFuncInfo { const Function *F; const BlockFrequencyInfo *BFI; const BranchProbabilityInfo *BPI; + std::unique_ptr MSTStorage; uint64_t MaxFreq; bool ShowHeat; bool EdgeWeights; @@ -68,14 +71,10 @@ class DOTFuncInfo { public: DOTFuncInfo(const Function *F) : DOTFuncInfo(F, nullptr, nullptr, 0) {} + ~DOTFuncInfo(); DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI, - const BranchProbabilityInfo *BPI, uint64_t MaxFreq) - : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq) { - ShowHeat = false; - EdgeWeights = !!BPI; // Print EdgeWeights when BPI is available. - RawWeights = !!BFI; // Print RawWeights when BFI is available. - } + const BranchProbabilityInfo *BPI, uint64_t MaxFreq); const BlockFrequencyInfo *getBFI() const { return BFI; } @@ -83,6 +82,8 @@ class DOTFuncInfo { const Function *getFunction() const { return this->F; } + ModuleSlotTracker *getModuleSlotTracker(); + uint64_t getMaxFreq() const { return MaxFreq; } uint64_t getFreq(const BasicBlock *BB) const { @@ -203,22 +204,12 @@ struct DOTGraphTraits : public DefaultDOTGraphTraits { return SimpleNodeLabelString(Node); } - static void printBasicBlock(raw_string_ostream &OS, const BasicBlock &Node) { - // Prepend label name - Node.printAsOperand(OS, false); - OS << ":\n"; - for (const Instruction &Inst : Node) - OS << Inst << "\n"; - } - static std::string getCompleteNodeLabel( const BasicBlock *Node, DOTFuncInfo *, function_ref - HandleBasicBlock = printBasicBlock, - function_ref - HandleComment = eraseComment) { - return CompleteNodeLabelString(Node, HandleBasicBlock, HandleComment); - } + HandleBasicBlock = {}, + function_ref HandleComment = + eraseComment); std::string getNodeLabel(const BasicBlock *Node, DOTFuncInfo *CFGInfo) { @@ -337,6 +328,6 @@ struct DOTGraphTraits : public DefaultDOTGraphTraits { bool isNodeHidden(const BasicBlock *Node, const DOTFuncInfo *CFGInfo); void computeDeoptOrUnreachablePaths(const Function *F); }; -} // End llvm namespace +} // namespace llvm #endif diff --git a/llvm/lib/Analysis/CFGPrinter.cpp b/llvm/lib/Analysis/CFGPrinter.cpp index af18fb6626e3b..38aad849755be 100644 --- a/llvm/lib/Analysis/CFGPrinter.cpp +++ b/llvm/lib/Analysis/CFGPrinter.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/CFGPrinter.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/IR/ModuleSlotTracker.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/GraphWriter.h" @@ -90,6 +91,22 @@ static void viewCFG(Function &F, const BlockFrequencyInfo *BFI, ViewGraph(&CFGInfo, "cfg." + F.getName(), CFGOnly); } +DOTFuncInfo::DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI, + const BranchProbabilityInfo *BPI, uint64_t MaxFreq) + : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq) { + ShowHeat = false; + EdgeWeights = !!BPI; // Print EdgeWeights when BPI is available. + RawWeights = !!BFI; // Print RawWeights when BFI is available. +} + +DOTFuncInfo::~DOTFuncInfo() = default; + +ModuleSlotTracker *DOTFuncInfo::getModuleSlotTracker() { + if (!MSTStorage) + MSTStorage = std::make_unique(F->getParent()); + return &*MSTStorage; +} + PreservedAnalyses CFGViewerPass::run(Function &F, FunctionAnalysisManager &AM) { if (!CFGFuncName.empty() && !F.getName().contains(CFGFuncName)) return PreservedAnalyses::all(); @@ -208,3 +225,38 @@ bool DOTGraphTraits::isNodeHidden(const BasicBlock *Node, } return false; } + +std::string DOTGraphTraits::getCompleteNodeLabel( + const BasicBlock *Node, DOTFuncInfo *CFGInfo, + function_ref + HandleBasicBlock, + function_ref HandleComment) { + if (HandleBasicBlock) + return CompleteNodeLabelString(Node, HandleBasicBlock, HandleComment); + + // Default basic block printing + std::optional MSTStorage; + ModuleSlotTracker *MST = nullptr; + + if (CFGInfo) { + MST = CFGInfo->getModuleSlotTracker(); + } else { + MSTStorage.emplace(Node->getModule()); + MST = &*MSTStorage; + } + + return CompleteNodeLabelString( + Node, + function_ref( + [MST](raw_string_ostream &OS, const BasicBlock &Node) -> void { + // Prepend label name + Node.printAsOperand(OS, false, *MST); + OS << ":\n"; + + for (const Instruction &Inst : Node) { + Inst.print(OS, *MST, /* IsForDebug */ false); + OS << '\n'; + } + }), + HandleComment); +}