Skip to content

Commit 5911754

Browse files
wdx727lifengxiang1025zcfh
authored
Adding Matching and Inference Functionality to Propeller-PR4: Implement matching and inference and create clusters (#167622)
This PR re-submits the previously reverted PR(#165868) and fixes the return type mismatch error. Co-authored-by: lifengxiang1025 <[email protected]> Co-authored-by: zcfh <[email protected]>
1 parent 785cadd commit 5911754

13 files changed

+490
-6
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- llvm/CodeGen/BasicBlockMatchingAndInference.h ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Infer weights for all basic blocks using matching and inference.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H
14+
#define LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H
15+
16+
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
17+
#include "llvm/CodeGen/MachineFunctionPass.h"
18+
#include "llvm/Transforms/Utils/SampleProfileInference.h"
19+
20+
namespace llvm {
21+
22+
class BasicBlockMatchingAndInference : public MachineFunctionPass {
23+
private:
24+
using Edge = std::pair<const MachineBasicBlock *, const MachineBasicBlock *>;
25+
using BlockWeightMap = DenseMap<const MachineBasicBlock *, uint64_t>;
26+
using EdgeWeightMap = DenseMap<Edge, uint64_t>;
27+
using BlockEdgeMap = DenseMap<const MachineBasicBlock *,
28+
SmallVector<const MachineBasicBlock *, 8>>;
29+
30+
struct WeightInfo {
31+
// Weight of basic blocks.
32+
BlockWeightMap BlockWeights;
33+
// Weight of edges.
34+
EdgeWeightMap EdgeWeights;
35+
};
36+
37+
public:
38+
static char ID;
39+
BasicBlockMatchingAndInference();
40+
41+
StringRef getPassName() const override {
42+
return "Basic Block Matching and Inference";
43+
}
44+
45+
void getAnalysisUsage(AnalysisUsage &AU) const override;
46+
47+
bool runOnMachineFunction(MachineFunction &F) override;
48+
49+
std::optional<WeightInfo> getWeightInfo(StringRef FuncName) const;
50+
51+
private:
52+
StringMap<WeightInfo> ProgramWeightInfo;
53+
54+
WeightInfo initWeightInfoByMatching(MachineFunction &MF);
55+
56+
void generateWeightInfoByInference(MachineFunction &MF,
57+
WeightInfo &MatchWeight);
58+
};
59+
60+
} // end namespace llvm
61+
62+
#endif // LLVM_CODEGEN_BASIC_BLOCK_AND_INFERENCE_H

llvm/include/llvm/CodeGen/BasicBlockSectionsProfileReader.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ class BasicBlockSectionsProfileReader {
8686
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
8787
const UniqueBBID &SinkBBID) const;
8888

89+
// Return the complete function path and cluster info for the given function.
90+
std::pair<bool, FunctionPathAndClusterInfo>
91+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
92+
8993
private:
9094
StringRef getAliasName(StringRef FuncName) const {
9195
auto R = FuncAliasMap.find(FuncName);
@@ -195,6 +199,9 @@ class BasicBlockSectionsProfileReaderWrapperPass : public ImmutablePass {
195199
uint64_t getEdgeCount(StringRef FuncName, const UniqueBBID &SrcBBID,
196200
const UniqueBBID &DestBBID) const;
197201

202+
std::pair<bool, FunctionPathAndClusterInfo>
203+
getFunctionPathAndClusterInfo(StringRef FuncName) const;
204+
198205
// Initializes the FunctionNameToDIFilename map for the current module and
199206
// then reads the profile for the matching functions.
200207
bool doInitialization(Module &M) override;

llvm/include/llvm/CodeGen/MachineBlockHashInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ struct BlendedBlockHash {
8080
return Dist;
8181
}
8282

83+
uint16_t getOpcodeHash() const { return OpcodeHash; }
84+
8385
private:
8486
/// The offset of the basic block from the function start.
8587
uint16_t Offset{0};

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ LLVM_ABI MachineFunctionPass *createBasicBlockSectionsPass();
6969

7070
LLVM_ABI MachineFunctionPass *createBasicBlockPathCloningPass();
7171

72+
/// createBasicBlockMatchingAndInferencePass - This pass enables matching
73+
/// and inference when using propeller.
74+
LLVM_ABI MachineFunctionPass *createBasicBlockMatchingAndInferencePass();
75+
7276
/// createMachineBlockHashInfoPass - This pass computes basic block hashes.
7377
LLVM_ABI MachineFunctionPass *createMachineBlockHashInfoPass();
7478

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ LLVM_ABI void initializeAlwaysInlinerLegacyPassPass(PassRegistry &);
5555
LLVM_ABI void initializeAssignmentTrackingAnalysisPass(PassRegistry &);
5656
LLVM_ABI void initializeAssumptionCacheTrackerPass(PassRegistry &);
5757
LLVM_ABI void initializeAtomicExpandLegacyPass(PassRegistry &);
58+
LLVM_ABI void initializeBasicBlockMatchingAndInferencePass(PassRegistry &);
5859
LLVM_ABI void initializeBasicBlockPathCloningPass(PassRegistry &);
5960
LLVM_ABI void
6061
initializeBasicBlockSectionsProfileReaderWrapperPassPass(PassRegistry &);

llvm/include/llvm/Transforms/Utils/SampleProfileInference.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ template <typename FT> class SampleProfileInference {
130130
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
131131
BlockWeightMap &SampleBlockWeights)
132132
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights) {}
133+
SampleProfileInference(FunctionT &F, BlockEdgeMap &Successors,
134+
BlockWeightMap &SampleBlockWeights,
135+
EdgeWeightMap &SampleEdgeWeights)
136+
: F(F), Successors(Successors), SampleBlockWeights(SampleBlockWeights),
137+
SampleEdgeWeights(SampleEdgeWeights) {}
133138

134139
/// Apply the profile inference algorithm for a given function
135140
void apply(BlockWeightMap &BlockWeights, EdgeWeightMap &EdgeWeights);
@@ -157,6 +162,9 @@ template <typename FT> class SampleProfileInference {
157162

158163
/// Map basic blocks to their sampled weights.
159164
BlockWeightMap &SampleBlockWeights;
165+
166+
/// Map edges to their sampled weights.
167+
EdgeWeightMap SampleEdgeWeights;
160168
};
161169

162170
template <typename BT>
@@ -266,6 +274,14 @@ FlowFunction SampleProfileInference<BT>::createFlowFunction(
266274
FlowJump Jump;
267275
Jump.Source = BlockIndex[BB];
268276
Jump.Target = BlockIndex[Succ];
277+
auto It = SampleEdgeWeights.find(std::make_pair(BB, Succ));
278+
if (It != SampleEdgeWeights.end()) {
279+
Jump.HasUnknownWeight = false;
280+
Jump.Weight = It->second;
281+
} else {
282+
Jump.HasUnknownWeight = true;
283+
Jump.Weight = 0;
284+
}
269285
Func.Jumps.push_back(Jump);
270286
}
271287
}
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
//===- llvm/CodeGen/BasicBlockMatchingAndInference.cpp ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// In Propeller's profile, we have already read the hash values of basic blocks,
10+
// as well as the weights of basic blocks and edges in the CFG. In this file,
11+
// we first match the basic blocks in the profile with those in the current
12+
// MachineFunction using the basic block hash, thereby obtaining the weights of
13+
// some basic blocks and edges. Subsequently, we infer the weights of all basic
14+
// blocks using an inference algorithm.
15+
//
16+
// TODO: Integrate part of the code in this file with BOLT's implementation into
17+
// the LLVM infrastructure, enabling both BOLT and Propeller to reuse it.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
21+
#include "llvm/CodeGen/BasicBlockMatchingAndInference.h"
22+
#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
23+
#include "llvm/CodeGen/MachineBlockHashInfo.h"
24+
#include "llvm/CodeGen/Passes.h"
25+
#include "llvm/InitializePasses.h"
26+
#include <llvm/Support/CommandLine.h>
27+
#include <unordered_map>
28+
29+
using namespace llvm;
30+
31+
static cl::opt<float>
32+
PropellerInferThreshold("propeller-infer-threshold",
33+
cl::desc("Threshold for infer stale profile"),
34+
cl::init(0.6), cl::Optional);
35+
36+
/// The object is used to identify and match basic blocks given their hashes.
37+
class StaleMatcher {
38+
public:
39+
/// Initialize stale matcher.
40+
void init(const std::vector<MachineBasicBlock *> &Blocks,
41+
const std::vector<BlendedBlockHash> &Hashes) {
42+
assert(Blocks.size() == Hashes.size() &&
43+
"incorrect matcher initialization");
44+
for (size_t I = 0; I < Blocks.size(); I++) {
45+
MachineBasicBlock *Block = Blocks[I];
46+
uint16_t OpHash = Hashes[I].getOpcodeHash();
47+
OpHashToBlocks[OpHash].push_back(std::make_pair(Hashes[I], Block));
48+
}
49+
}
50+
51+
/// Find the most similar block for a given hash.
52+
MachineBasicBlock *matchBlock(BlendedBlockHash BlendedHash) const {
53+
auto BlockIt = OpHashToBlocks.find(BlendedHash.getOpcodeHash());
54+
if (BlockIt == OpHashToBlocks.end()) {
55+
return nullptr;
56+
}
57+
MachineBasicBlock *BestBlock = nullptr;
58+
uint64_t BestDist = std::numeric_limits<uint64_t>::max();
59+
for (auto It : BlockIt->second) {
60+
MachineBasicBlock *Block = It.second;
61+
BlendedBlockHash Hash = It.first;
62+
uint64_t Dist = Hash.distance(BlendedHash);
63+
if (BestBlock == nullptr || Dist < BestDist) {
64+
BestDist = Dist;
65+
BestBlock = Block;
66+
}
67+
}
68+
return BestBlock;
69+
}
70+
71+
private:
72+
using HashBlockPairType = std::pair<BlendedBlockHash, MachineBasicBlock *>;
73+
std::unordered_map<uint16_t, std::vector<HashBlockPairType>> OpHashToBlocks;
74+
};
75+
76+
INITIALIZE_PASS_BEGIN(BasicBlockMatchingAndInference,
77+
"machine-block-match-infer",
78+
"Machine Block Matching and Inference Analysis", true,
79+
true)
80+
INITIALIZE_PASS_DEPENDENCY(MachineBlockHashInfo)
81+
INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass)
82+
INITIALIZE_PASS_END(BasicBlockMatchingAndInference, "machine-block-match-infer",
83+
"Machine Block Matching and Inference Analysis", true, true)
84+
85+
char BasicBlockMatchingAndInference::ID = 0;
86+
87+
BasicBlockMatchingAndInference::BasicBlockMatchingAndInference()
88+
: MachineFunctionPass(ID) {
89+
initializeBasicBlockMatchingAndInferencePass(
90+
*PassRegistry::getPassRegistry());
91+
}
92+
93+
void BasicBlockMatchingAndInference::getAnalysisUsage(AnalysisUsage &AU) const {
94+
AU.addRequired<MachineBlockHashInfo>();
95+
AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>();
96+
AU.setPreservesAll();
97+
MachineFunctionPass::getAnalysisUsage(AU);
98+
}
99+
100+
std::optional<BasicBlockMatchingAndInference::WeightInfo>
101+
BasicBlockMatchingAndInference::getWeightInfo(StringRef FuncName) const {
102+
auto It = ProgramWeightInfo.find(FuncName);
103+
if (It == ProgramWeightInfo.end()) {
104+
return std::nullopt;
105+
}
106+
return It->second;
107+
}
108+
109+
BasicBlockMatchingAndInference::WeightInfo
110+
BasicBlockMatchingAndInference::initWeightInfoByMatching(MachineFunction &MF) {
111+
std::vector<MachineBasicBlock *> Blocks;
112+
std::vector<BlendedBlockHash> Hashes;
113+
auto BSPR = &getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>();
114+
auto MBHI = &getAnalysis<MachineBlockHashInfo>();
115+
for (auto &Block : MF) {
116+
Blocks.push_back(&Block);
117+
Hashes.push_back(BlendedBlockHash(MBHI->getMBBHash(Block)));
118+
}
119+
StaleMatcher Matcher;
120+
Matcher.init(Blocks, Hashes);
121+
BasicBlockMatchingAndInference::WeightInfo MatchWeight;
122+
auto [IsValid, PathAndClusterInfo] =
123+
BSPR->getFunctionPathAndClusterInfo(MF.getName());
124+
if (!IsValid)
125+
return MatchWeight;
126+
for (auto &BlockCount : PathAndClusterInfo.NodeCounts) {
127+
if (PathAndClusterInfo.BBHashes.count(BlockCount.first.BaseID)) {
128+
auto Hash = PathAndClusterInfo.BBHashes[BlockCount.first.BaseID];
129+
MachineBasicBlock *Block = Matcher.matchBlock(BlendedBlockHash(Hash));
130+
// When a basic block has clone copies, sum their counts.
131+
if (Block != nullptr)
132+
MatchWeight.BlockWeights[Block] += BlockCount.second;
133+
}
134+
}
135+
for (auto &PredItem : PathAndClusterInfo.EdgeCounts) {
136+
auto PredID = PredItem.first.BaseID;
137+
if (!PathAndClusterInfo.BBHashes.count(PredID))
138+
continue;
139+
auto PredHash = PathAndClusterInfo.BBHashes[PredID];
140+
MachineBasicBlock *PredBlock =
141+
Matcher.matchBlock(BlendedBlockHash(PredHash));
142+
if (PredBlock == nullptr)
143+
continue;
144+
for (auto &SuccItem : PredItem.second) {
145+
auto SuccID = SuccItem.first.BaseID;
146+
auto EdgeWeight = SuccItem.second;
147+
if (PathAndClusterInfo.BBHashes.count(SuccID)) {
148+
auto SuccHash = PathAndClusterInfo.BBHashes[SuccID];
149+
MachineBasicBlock *SuccBlock =
150+
Matcher.matchBlock(BlendedBlockHash(SuccHash));
151+
// When an edge has clone copies, sum their counts.
152+
if (SuccBlock != nullptr)
153+
MatchWeight.EdgeWeights[std::make_pair(PredBlock, SuccBlock)] +=
154+
EdgeWeight;
155+
}
156+
}
157+
}
158+
return MatchWeight;
159+
}
160+
161+
void BasicBlockMatchingAndInference::generateWeightInfoByInference(
162+
MachineFunction &MF,
163+
BasicBlockMatchingAndInference::WeightInfo &MatchWeight) {
164+
BlockEdgeMap Successors;
165+
for (auto &Block : MF) {
166+
for (auto *Succ : Block.successors())
167+
Successors[&Block].push_back(Succ);
168+
}
169+
SampleProfileInference<MachineFunction> SPI(
170+
MF, Successors, MatchWeight.BlockWeights, MatchWeight.EdgeWeights);
171+
BlockWeightMap BlockWeights;
172+
EdgeWeightMap EdgeWeights;
173+
SPI.apply(BlockWeights, EdgeWeights);
174+
ProgramWeightInfo.try_emplace(
175+
MF.getName(), BasicBlockMatchingAndInference::WeightInfo{
176+
std::move(BlockWeights), std::move(EdgeWeights)});
177+
}
178+
179+
bool BasicBlockMatchingAndInference::runOnMachineFunction(MachineFunction &MF) {
180+
if (MF.empty())
181+
return false;
182+
auto MatchWeight = initWeightInfoByMatching(MF);
183+
// If the ratio of the number of MBBs in matching to the total number of MBBs
184+
// in the function is less than the threshold value, the processing should be
185+
// abandoned.
186+
if (static_cast<float>(MatchWeight.BlockWeights.size()) / MF.size() <
187+
PropellerInferThreshold) {
188+
return false;
189+
}
190+
generateWeightInfoByInference(MF, MatchWeight);
191+
return false;
192+
}
193+
194+
MachineFunctionPass *llvm::createBasicBlockMatchingAndInferencePass() {
195+
return new BasicBlockMatchingAndInference();
196+
}

0 commit comments

Comments
 (0)