Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ struct BBClusterInfo {
unsigned PositionInCluster;
};

// Identifies a prefetch target position inside a basic block.
// The prefetch symbol is emitted immediately after the call of the given index,
// in block `BBID` (First call has an index of 1). Zero callsite index means the
// start of the block.
struct CallsiteID {
// The basic block that contains the position.
UniqueBBID BBID;
// Index of the call after which the position lies; 0 means the start of the
// block.
unsigned CallsiteIndex;
};

// This represents the raw input profile for one function.
struct FunctionPathAndClusterInfo {
// BB Cluster information specified by `UniqueBBID`s.
Expand All @@ -50,9 +58,12 @@ struct FunctionPathAndClusterInfo {
// the edge a -> b (a is not cloned). The index of the path in this vector
// determines the `UniqueBBID::CloneID` of the cloned blocks in that path.
SmallVector<SmallVector<unsigned>> ClonePaths;
// Code prefetch targets. Each one is specified by a callsite ID; the position
// immediately after that callsite is what must be targeted for prefetching.
SmallVector<CallsiteID> PrefetchTargets;
// Node counts for each basic block.
DenseMap<UniqueBBID, uint64_t> NodeCounts;
// Edge counts for each edge, stored as a nested map.
// Edge counts for each edge.
DenseMap<UniqueBBID, DenseMap<UniqueBBID, uint64_t>> EdgeCounts;
// Hash for each basic block. The Hashes are stored for every original block
// (not cloned blocks), hence the map key being unsigned instead of
Expand Down Expand Up @@ -86,6 +97,11 @@ class BasicBlockSectionsProfileReader {
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
const UniqueBBID &SinkBBID) const;

// Returns the prefetch targets (identified by their containing callsite IDs)
// for function `FuncName`.
SmallVector<CallsiteID>
getPrefetchTargetsForFunction(StringRef FuncName) const;

private:
StringRef getAliasName(StringRef FuncName) const {
auto R = FuncAliasMap.find(FuncName);
Expand Down Expand Up @@ -195,6 +211,9 @@ class BasicBlockSectionsProfileReaderWrapperPass : public ImmutablePass {
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
const UniqueBBID &DestBBID) const;

// Returns the prefetch targets (identified by their containing callsite IDs)
// for function `FuncName`, as reported by the underlying profile reader.
SmallVector<CallsiteID>
getPrefetchTargetsForFunction(StringRef FuncName) const;

// Initializes the FunctionNameToDIFilename map for the current module and
// then reads the profile for the matching functions.
bool doInitialization(Module &M) override;
Expand Down
14 changes: 14 additions & 0 deletions llvm/include/llvm/CodeGen/MachineBasicBlock.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ class MachineBasicBlock
/// is only computed once and is cached.
mutable MCSymbol *CachedMCSymbol = nullptr;

/// Contains the callsite indices in this block that are targets of code
/// prefetching. The index `i` specifies the `i`th call, with zero
/// representing the beginning of the block and `1` representing the first call.
/// Must be in ascending order and without duplicates.
SmallVector<unsigned> PrefetchTargetCallsiteIndexes;

/// Cached MCSymbol for this block (used if IsEHContTarget).
mutable MCSymbol *CachedEHContMCSymbol = nullptr;

Expand Down Expand Up @@ -710,6 +716,14 @@ class MachineBasicBlock

std::optional<UniqueBBID> getBBID() const { return BBID; }

/// Returns the callsite indices in this block that are targets of code
/// prefetching. The result is a reference to the block's own vector, not a
/// copy.
const SmallVector<unsigned> &getPrefetchTargetCallsiteIndexes() const {
return PrefetchTargetCallsiteIndexes;
}

/// Replaces this block's prefetch-target callsite indices with a copy of
/// \p V. The stored list is documented as ascending and duplicate-free; this
/// setter does not check that, so the caller must pass \p V in that form.
void setPrefetchTargetCallsiteIndexes(const SmallVector<unsigned> &V) {
PrefetchTargetCallsiteIndexes = V;
}

/// Returns the section ID of this basic block.
MBBSectionID getSectionID() const { return SectionID; }

Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ LLVM_ABI MachineFunctionPass *createBasicBlockSectionsPass();

LLVM_ABI MachineFunctionPass *createBasicBlockPathCloningPass();

/// createInsertCodePrefetchPass - This pass sets the prefetch targets of
/// machine basic blocks based on the basic block sections profile.
LLVM_ABI MachineFunctionPass *createInsertCodePrefetchPass();

/// createMachineBlockHashInfoPass - This pass computes basic block hashes.
LLVM_ABI MachineFunctionPass *createMachineBlockHashInfoPass();

Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ LLVM_ABI void initializeAssignmentTrackingAnalysisPass(PassRegistry &);
LLVM_ABI void initializeAssumptionCacheTrackerPass(PassRegistry &);
LLVM_ABI void initializeAtomicExpandLegacyPass(PassRegistry &);
LLVM_ABI void initializeBasicBlockPathCloningPass(PassRegistry &);
LLVM_ABI void initializeInsertCodePrefetchPass(PassRegistry &);
LLVM_ABI void
initializeBasicBlockSectionsProfileReaderWrapperPassPass(PassRegistry &);
LLVM_ABI void initializeBasicBlockSectionsPass(PassRegistry &);
Expand Down
35 changes: 33 additions & 2 deletions llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,7 +1985,33 @@ void AsmPrinter::emitFunctionBody() {
// Print a label for the basic block.
emitBasicBlockStart(MBB);
DenseMap<StringRef, unsigned> MnemonicCounts;

// Callsite indices in this block that need a prefetch target symbol. Bind by
// const reference: the getter hands back a reference to the block's own
// vector and this code only reads it, so copying it for every basic block is
// wasted work.
const SmallVector<unsigned> &PrefetchTargets =
    MBB.getPrefetchTargetCallsiteIndexes();
// Next prefetch target in this block that has not been emitted yet.
auto PrefetchTargetIt = PrefetchTargets.begin();
// Number of call instructions emitted so far in this block.
unsigned LastCallsiteIndex = 0;
// Helper to emit a symbol for the prefetch target and proceed to the next
// one.
auto EmitPrefetchTargetSymbolIfNeeded = [&]() {
if (PrefetchTargetIt != PrefetchTargets.end() &&
*PrefetchTargetIt == LastCallsiteIndex) {
MCSymbol *PrefetchTargetSymbol = OutContext.getOrCreateSymbol(
Twine("__llvm_prefetch_target_") + MF->getName() + Twine("_") +
utostr(MBB.getBBID()->BaseID) + Twine("_") +
utostr(static_cast<unsigned>(*PrefetchTargetIt)));
// If the function is weak-linkage it may be replaced by a strong
// version, in which case the prefetch targets should also be replaced.
OutStreamer->emitSymbolAttribute(
PrefetchTargetSymbol,
MF->getFunction().isWeakForLinker() ? MCSA_Weak : MCSA_Global);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with internal linkage? If you make this symbol global, it would conflict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. We are relying on funique, as mentioned in the RFC. Should we add a check here to ensure funique has been used? (It's possible that funique is not applied to some functions.)

OutStreamer->emitLabel(PrefetchTargetSymbol);
++PrefetchTargetIt;
}
};

for (auto &MI : MBB) {
EmitPrefetchTargetSymbolIfNeeded();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have EmitPrefetchTargetSymbolIfNeeded take a parameter, which is the unsigned callsite index. I would move the updates to PrefetchTargetIt out of the lambda and into the loop, and replace this line with:

if (PrefetchTargetIt != PrefetchTargets.end() &&
     *PrefetchTargetIt == LastCallSiteIndex) {
  EmitPrefetchTargetSymbolIfNeeded(*PrefetchTargetIt);
  ++PrefetchTargetIt;
}


// Print the assembly for the instruction.
if (!MI.isPosition() && !MI.isImplicitDef() && !MI.isKill() &&
!MI.isDebugInstr()) {
Expand Down Expand Up @@ -2123,8 +2149,11 @@ void AsmPrinter::emitFunctionBody() {
break;
}

if (MI.isCall() && MF->getTarget().Options.BBAddrMap)
OutStreamer->emitLabel(createCallsiteEndSymbol(MBB));
if (MI.isCall()) {
if (MF->getTarget().Options.BBAddrMap)
OutStreamer->emitLabel(createCallsiteEndSymbol(MBB));
LastCallsiteIndex++;
}

if (TM.Options.EmitCallGraphSection && MI.isCall())
handleCallsiteForCallgraph(FuncCGInfo, CallSitesInfoMap, MI);
Expand All @@ -2136,6 +2165,8 @@ void AsmPrinter::emitFunctionBody() {
for (auto &Handler : Handlers)
Handler->endInstruction();
}
// Emit the last prefetch target in case the last instruction was a call.
EmitPrefetchTargetSymbolIfNeeded();

// We must emit temporary symbol for the end of this basic block, if either
// we have BBLabels enabled or if this basic blocks marks the end of a
Expand Down
64 changes: 64 additions & 0 deletions llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ uint64_t BasicBlockSectionsProfileReader::getEdgeCount(
return EdgeIt->second;
}

// Returns the prefetch targets recorded in the profile for function `FuncName`
// (after resolving it through getAliasName), or an empty vector when the
// profile has no entry for that function.
SmallVector<CallsiteID>
BasicBlockSectionsProfileReader::getPrefetchTargetsForFunction(
    StringRef FuncName) const {
  // Use find() rather than lookup(): lookup() returns the whole per-function
  // profile entry by value, which copies every field of the entry just to
  // read this one.
  auto It = ProgramPathAndClusterInfo.find(getAliasName(FuncName));
  if (It == ProgramPathAndClusterInfo.end())
    return {};
  return It->second.PrefetchTargets;
}

// Reads the version 1 basic block sections profile. Profile for each function
// is encoded as follows:
// m <module_name>
Expand Down Expand Up @@ -148,6 +155,36 @@ uint64_t BasicBlockSectionsProfileReader::getEdgeCount(
// +-->: 5 :
// ....
// ****************************************************************************
// This profile can also specify prefetch targets (starting with 't') which
// instruct the compiler to emit a prefetch symbol for the given target.
// A prefetch target is specified by a pair "<bbid>,<subblock_index>" where
// bbid specifies the target basic block and subblock_index is a zero-based
// index. Subblock 0 refers to the region at the beginning of the block up to
// the first callsite. Subblock `i > 0` refers to the region immediately after
// the `i`-th callsite up to the `i+1`-th callsite (or the end of the block).
// The prefetch target is always emitted at the beginning of the subblock.
// This is the beginning of the basic block for `i = 0` and immediately after
// the `i`-th call for every `i > 0`.
//
// Example: A basic block in function "foo" with BBID 10 and two call
// instructions (call_A, call_B). This block is conceptually split into
// subblocks, with the prefetch target symbol emitted at the beginning of each
// subblock.
//
//   +----------------------------------+
//   | __llvm_prefetch_target_foo_10_0: | <--- Subblock 0 (before call_A)
//   |  Instruction 1                   |
//   |  Instruction 2                   |
//   |  call_A (Callsite 1)             |
//   | __llvm_prefetch_target_foo_10_1: | <--- Subblock 1 (after call_A,
//   |                                  |      before call_B)
//   |  Instruction 3                   |
//   |  call_B (Callsite 2)             |
//   | __llvm_prefetch_target_foo_10_2: | <--- Subblock 2 (after call_B,
//   |                                  |      up to the end of the block)
//   |  Instruction 4                   |
//   +----------------------------------+
//
Error BasicBlockSectionsProfileReader::ReadV1Profile() {
auto FI = ProgramPathAndClusterInfo.end();

Expand Down Expand Up @@ -308,6 +345,27 @@ Error BasicBlockSectionsProfileReader::ReadV1Profile() {
}
continue;
}
case 't': { // Callsite target specifier.
  // Skip this entry when the profile iterator (FI) refers to the
  // past-the-end element.
  if (FI == ProgramPathAndClusterInfo.end())
    continue;
  // A target is written as "<bbid>,<callsite_index>".
  SmallVector<StringRef, 2> PrefetchTargetStr;
  Values[0].split(PrefetchTargetStr, ',');
  if (PrefetchTargetStr.size() != 2)
    return createProfileParseError(Twine("Callsite target expected: ") +
                                   Values[0]);
  auto TargetBBID = parseUniqueBBID(PrefetchTargetStr[0]);
  if (!TargetBBID)
    return TargetBBID.takeError();
  unsigned long long CallsiteIndex;
  // getAsUnsignedInteger returns true on failure. The value being parsed is
  // unsigned, so say so in the diagnostic, and close the quote opened around
  // the offending text.
  if (getAsUnsignedInteger(PrefetchTargetStr[1], 10, CallsiteIndex))
    return createProfileParseError(Twine("unsigned integer expected: '") +
                                   PrefetchTargetStr[1] + "'");
  FI->second.PrefetchTargets.push_back(
      CallsiteID{*TargetBBID, static_cast<unsigned>(CallsiteIndex)});
  continue;
}
default:
return createProfileParseError(Twine("invalid specifier: '") +
Twine(Specifier) + "'");
Expand Down Expand Up @@ -514,6 +572,12 @@ uint64_t BasicBlockSectionsProfileReaderWrapperPass::getEdgeCount(
return BBSPR.getEdgeCount(FuncName, SrcBBID, SinkBBID);
}

// Returns the prefetch targets for function `FuncName` by forwarding the query
// to the wrapped profile reader (BBSPR).
SmallVector<CallsiteID>
BasicBlockSectionsProfileReaderWrapperPass::getPrefetchTargetsForFunction(
StringRef FuncName) const {
return BBSPR.getPrefetchTargetsForFunction(FuncName);
}

BasicBlockSectionsProfileReader &
BasicBlockSectionsProfileReaderWrapperPass::getBBSPR() {
return BBSPR;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ add_llvm_component_library(LLVMCodeGen
IndirectBrExpandPass.cpp
InitUndef.cpp
InlineSpiller.cpp
InsertCodePrefetch.cpp
InterferenceCache.cpp
InterleavedAccessPass.cpp
InterleavedLoadCombinePass.cpp
Expand Down
101 changes: 101 additions & 0 deletions llvm/lib/CodeGen/InsertCodePrefetch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===- InsertCodePrefetch.cpp - Code prefetch insertion pass --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Code Prefetch Insertion Pass.
//===----------------------------------------------------------------------===//
/// This pass inserts code prefetch instructions according to the prefetch
/// directives in the basic block section profile. The target of a prefetch can
/// be the beginning of any dynamic basic block, that is the beginning of a
/// machine basic block, or immediately after a callsite. A global symbol is
/// emitted at the position of the target so it can be addressed from the
/// prefetch instruction from any module.
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/BasicBlockSectionUtils.h"
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/InitializePasses.h"

using namespace llvm;
#define DEBUG_TYPE "insert-code-prefetch"

namespace {
// Machine function pass that records, on each machine basic block, the
// prefetch-target callsite indices listed for that block in the basic block
// sections profile.
class InsertCodePrefetch : public MachineFunctionPass {
public:
// Pass identification.
static char ID;

// Registers the pass with the global pass registry on construction.
InsertCodePrefetch() : MachineFunctionPass(ID) {
initializeInsertCodePrefetchPass(*PassRegistry::getPassRegistry());
}

// Human-readable name of the pass.
StringRef getPassName() const override {
return "Code Prefetch Inserter Pass";
}

// Declares the analyses this pass requires and preserves.
void getAnalysisUsage(AnalysisUsage &AU) const override;

// Sets prefetch targets based on the bb section profile.
bool runOnMachineFunction(MachineFunction &MF) override;
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Implementation
//===----------------------------------------------------------------------===//

char InsertCodePrefetch::ID = 0;
INITIALIZE_PASS_BEGIN(InsertCodePrefetch, DEBUG_TYPE, "Code prefetch insertion",
true, false)
INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass)
INITIALIZE_PASS_END(InsertCodePrefetch, DEBUG_TYPE, "Code prefetch insertion",
true, false)

// Copies the prefetch targets that the basic block sections profile lists for
// `MF` onto the matching machine basic blocks. Always returns false.
bool InsertCodePrefetch::runOnMachineFunction(MachineFunction &MF) {
assert(MF.getTarget().getBBSectionsType() == BasicBlockSection::List &&
"BB Sections list not enabled!");
// Do nothing when hasInstrProfHashMismatch reports a mismatch for this
// function (presumably the profile no longer matches the code; confirm
// against the helper's definition).
if (hasInstrProfHashMismatch(MF))
return false;
// Set each block's prefetch targets so AsmPrinter can emit a special symbol
// there.
SmallVector<CallsiteID> PrefetchTargets =
getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>()
.getPrefetchTargetsForFunction(MF.getName());
// Group the callsite indices by the basic block that contains them.
DenseMap<UniqueBBID, SmallVector<unsigned>> PrefetchTargetsByBBID;
for (const auto &Target : PrefetchTargets)
PrefetchTargetsByBBID[Target.BBID].push_back(Target.CallsiteIndex);
// Sort and uniquify the callsite indices for every block.
for (auto &[K, V] : PrefetchTargetsByBBID) {
llvm::sort(V);
V.erase(llvm::unique(V), V.end());
}
// Attach each group to the machine basic block with the matching BBID.
for (auto &MBB : MF) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to iterate over the BBIDs and use the MF->getBlockNumbered(ID) method to get the MBB?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't store a map from BBIDs to MBBs, so there is no way around iterating over the MBBs once, if that's what you mean. (We could create the map here, but that would still require an iteration.)

auto R = PrefetchTargetsByBBID.find(*MBB.getBBID());
if (R == PrefetchTargetsByBBID.end())
continue;
MBB.setPrefetchTargetCallsiteIndexes(R->second);
}
return false;
}

// Declares that this pass preserves all analyses and requires the basic block
// sections profile reader, then lets the base class add its own requirements.
void InsertCodePrefetch::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}

// Factory for the pass; returns a newly allocated InsertCodePrefetch instance.
MachineFunctionPass *llvm::createInsertCodePrefetchPass() {
return new InsertCodePrefetch();
}
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/TargetPassConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,7 @@ void TargetPassConfig::addMachinePasses() {
addPass(llvm::createBasicBlockSectionsProfileReaderWrapperPass(
TM->getBBSectionsFuncListBuf()));
addPass(llvm::createBasicBlockPathCloningPass());
addPass(llvm::createInsertCodePrefetchPass());
}
addPass(llvm::createBasicBlockSectionsPass());
}
Expand Down
Loading