Skip to content

Commit 727a6b8

Browse files
committed
Restrict caching only to Flow-Aware computation
1 parent 0a61c67 commit 727a6b8

File tree

6 files changed

+123
-224
lines changed

6 files changed

+123
-224
lines changed

llvm/docs/MLGO.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
508508

509509
.. code-block:: c++
510510

511-
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
511+
ir2vec::Embedding FuncVector = Emb->getFunctionVector();
512512

513513
Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
514514
Basic Blocks, and Functions. Appropriate getters are provided to access the

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -533,21 +533,20 @@ class Embedder {
533533
/// in the IR instructions to generate the vector representation.
534534
const float OpcWeight, TypeWeight, ArgWeight;
535535

536-
// Utility maps - these are used to store the vector representations of
537-
// instructions, basic blocks and functions.
538-
mutable Embedding FuncVector;
539-
mutable BBEmbeddingsMap BBVecMap;
540-
mutable InstEmbeddingsMap InstVecMap;
541-
542-
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
536+
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
537+
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
538+
OpcWeight(ir2vec::OpcWeight), TypeWeight(ir2vec::TypeWeight),
539+
ArgWeight(ir2vec::ArgWeight) {}
543540

544-
/// Function to compute embeddings. It generates embeddings for all
545-
/// the instructions and basic blocks in the function F.
546-
void computeEmbeddings() const;
541+
/// Function to compute embeddings.
542+
Embedding computeEmbeddings() const;
547543

548544
/// Function to compute the embedding for a given basic block.
545+
Embedding computeEmbeddings(const BasicBlock &BB) const;
546+
547+
/// Function to compute the embedding for a given instruction.
549548
/// Specific to the kind of embeddings being computed.
550-
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
549+
virtual Embedding computeEmbeddings(const Instruction &I) const = 0;
551550

552551
public:
553552
virtual ~Embedder() = default;
@@ -556,31 +555,28 @@ class Embedder {
556555
LLVM_ABI static std::unique_ptr<Embedder>
557556
create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);
558557

559-
/// Returns a map containing instructions and the corresponding embeddings for
560-
/// the function F if it has been computed. If not, it computes the embeddings
561-
/// for the function and returns the map.
562-
LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const;
563-
564-
/// Returns a map containing basic block and the corresponding embeddings for
565-
/// the function F if it has been computed. If not, it computes the embeddings
566-
/// for the function and returns the map.
567-
LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const;
558+
/// Computes and returns the embedding for a given instruction in the function
559+
/// F
560+
LLVM_ABI Embedding getInstVector(const Instruction &I) const {
561+
return computeEmbeddings(I);
562+
}
568563

569-
/// Returns the embedding for a given basic block in the function F if it has
570-
/// been computed. If not, it computes the embedding for the basic block and
571-
/// returns it.
572-
LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const;
564+
/// Computes and returns the embedding for a given basic block in the function
565+
/// F
566+
LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const {
567+
return computeEmbeddings(BB);
568+
}
573569

574570
/// Computes and returns the embedding for the current function.
575-
LLVM_ABI const Embedding &getFunctionVector() const;
571+
LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }
576572
};
577573

578574
/// Class for computing the Symbolic embeddings of IR2Vec.
579575
/// Symbolic embeddings are constructed based on the entity-level
580576
/// representations obtained from the Vocabulary.
581577
class LLVM_ABI SymbolicEmbedder : public Embedder {
582578
private:
583-
void computeEmbeddings(const BasicBlock &BB) const override;
579+
Embedding computeEmbeddings(const Instruction &I) const override;
584580

585581
public:
586582
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
@@ -592,7 +588,10 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
592588
/// embeddings, and additionally capture the flow information in the IR.
593589
class LLVM_ABI FlowAwareEmbedder : public Embedder {
594590
private:
595-
void computeEmbeddings(const BasicBlock &BB) const override;
591+
// FlowAware embeddings would benefit from caching instruction embeddings as
592+
// they are reused while computing the embeddings of other instructions.
593+
mutable InstEmbeddingsMap InstVecMap;
594+
Embedding computeEmbeddings(const Instruction &I) const override;
596595

597596
public:
598597
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 71 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
153153
// Embedder and its subclasses
154154
//===----------------------------------------------------------------------===//
155155

156-
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
157-
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
158-
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
159-
FuncVector(Embedding(Dimension)) {}
160-
161156
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
162157
const Vocabulary &Vocab) {
163158
switch (Mode) {
@@ -169,110 +164,83 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
169164
return nullptr;
170165
}
171166

172-
const InstEmbeddingsMap &Embedder::getInstVecMap() const {
173-
if (InstVecMap.empty())
174-
computeEmbeddings();
175-
return InstVecMap;
176-
}
177-
178-
const BBEmbeddingsMap &Embedder::getBBVecMap() const {
179-
if (BBVecMap.empty())
180-
computeEmbeddings();
181-
return BBVecMap;
182-
}
183-
184-
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
185-
auto It = BBVecMap.find(&BB);
186-
if (It != BBVecMap.end())
187-
return It->second;
188-
computeEmbeddings(BB);
189-
return BBVecMap[&BB];
190-
}
167+
Embedding Embedder::computeEmbeddings() const {
168+
Embedding FuncVector(Dimension, 0);
191169

192-
const Embedding &Embedder::getFunctionVector() const {
193-
// Currently, we always (re)compute the embeddings for the function.
194-
// This is cheaper than caching the vector.
195-
computeEmbeddings();
196-
return FuncVector;
197-
}
198-
199-
void Embedder::computeEmbeddings() const {
200170
if (F.isDeclaration())
201-
return;
202-
203-
FuncVector = Embedding(Dimension, 0.0);
171+
return FuncVector;
204172

205173
// Consider only the basic blocks that are reachable from entry
206-
for (const BasicBlock *BB : depth_first(&F)) {
207-
computeEmbeddings(*BB);
208-
FuncVector += BBVecMap[BB];
209-
}
210-
}
211-
212-
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
213-
Embedding BBVector(Dimension, 0);
214-
215-
// We consider only the non-debug and non-pseudo instructions
216-
for (const auto &I : BB.instructionsWithoutDebug()) {
217-
Embedding ArgEmb(Dimension, 0);
218-
for (const auto &Op : I.operands())
219-
ArgEmb += Vocab[*Op];
220-
auto InstVector =
221-
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
222-
if (const auto *IC = dyn_cast<CmpInst>(&I))
223-
InstVector += Vocab[IC->getPredicate()];
224-
InstVecMap[&I] = InstVector;
225-
BBVector += InstVector;
226-
}
227-
BBVecMap[&BB] = BBVector;
174+
for (const BasicBlock *BB : depth_first(&F))
175+
FuncVector += computeEmbeddings(*BB);
176+
return FuncVector;
228177
}
229178

230-
void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
179+
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
231180
Embedding BBVector(Dimension, 0);
181+
for (const Instruction &I : BB.instructionsWithoutDebug())
182+
BBVector += computeEmbeddings(I);
183+
return BBVector;
184+
}
185+
186+
Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
187+
// Currently, we always (re)compute the embeddings for symbolic embedder.
188+
// This is cheaper than caching the vectors.
189+
Embedding ArgEmb(Dimension, 0);
190+
for (const auto &Op : I.operands())
191+
ArgEmb += Vocab[*Op];
192+
auto InstVector =
193+
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
194+
if (const auto *IC = dyn_cast<CmpInst>(&I))
195+
InstVector += Vocab[IC->getPredicate()];
196+
return InstVector;
197+
}
198+
199+
Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
200+
// If we have already computed the embedding for this instruction, return it
201+
auto It = InstVecMap.find(&I);
202+
if (It != InstVecMap.end())
203+
return It->second;
232204

233-
// We consider only the non-debug and non-pseudo instructions
234-
for (const auto &I : BB.instructionsWithoutDebug()) {
235-
// TODO: Handle call instructions differently.
236-
// For now, we treat them like other instructions
237-
Embedding ArgEmb(Dimension, 0);
238-
for (const auto &Op : I.operands()) {
239-
// If the operand is defined elsewhere, we use its embedding
240-
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241-
auto DefIt = InstVecMap.find(DefInst);
242-
// Fixme (#159171): Ideally we should never miss an instruction
243-
// embedding here.
244-
// But when we have cyclic dependencies (e.g., phi
245-
// nodes), we might miss the embedding. In such cases, we fall back to
246-
// using the vocabulary embedding. This can be fixed by iterating to a
247-
// fixed-point, or by using a simple solver for the set of simultaneous
248-
// equations.
249-
// Another case when we might miss an instruction embedding is when
250-
// the operand instruction is in a different basic block that has not
251-
// been processed yet. This can be fixed by processing the basic blocks
252-
// in a topological order.
253-
if (DefIt != InstVecMap.end())
254-
ArgEmb += DefIt->second;
255-
else
256-
ArgEmb += Vocab[*Op];
257-
}
258-
// If the operand is not defined by an instruction, we use the vocabulary
259-
else {
260-
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
261-
<< *Op << "=" << Vocab[*Op][0] << "\n");
205+
// TODO: Handle call instructions differently.
206+
// For now, we treat them like other instructions
207+
Embedding ArgEmb(Dimension, 0);
208+
for (const auto &Op : I.operands()) {
209+
// If the operand is defined elsewhere, we use its embedding
210+
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
211+
auto DefIt = InstVecMap.find(DefInst);
212+
// Fixme (#159171): Ideally we should never miss an instruction
213+
// embedding here.
214+
// But when we have cyclic dependencies (e.g., phi
215+
// nodes), we might miss the embedding. In such cases, we fall back to
216+
// using the vocabulary embedding. This can be fixed by iterating to a
217+
// fixed-point, or by using a simple solver for the set of simultaneous
218+
// equations.
219+
// Another case when we might miss an instruction embedding is when
220+
// the operand instruction is in a different basic block that has not
221+
// been processed yet. This can be fixed by processing the basic blocks
222+
// in a topological order.
223+
if (DefIt != InstVecMap.end())
224+
ArgEmb += DefIt->second;
225+
else
262226
ArgEmb += Vocab[*Op];
263-
}
264227
}
265-
// Create the instruction vector by combining opcode, type, and arguments
266-
// embeddings
267-
auto InstVector =
268-
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
269-
// Add compare predicate embedding as an additional operand if applicable
270-
if (const auto *IC = dyn_cast<CmpInst>(&I))
271-
InstVector += Vocab[IC->getPredicate()];
272-
InstVecMap[&I] = InstVector;
273-
BBVector += InstVector;
228+
// If the operand is not defined by an instruction, we use the
229+
// vocabulary
230+
else {
231+
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
232+
<< *Op << "=" << Vocab[*Op][0] << "\n");
233+
ArgEmb += Vocab[*Op];
234+
}
274235
}
275-
BBVecMap[&BB] = BBVector;
236+
// Create the instruction vector by combining opcode, type, and arguments
237+
// embeddings
238+
auto InstVector =
239+
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
240+
if (const auto *IC = dyn_cast<CmpInst>(&I))
241+
InstVector += Vocab[IC->getPredicate()];
242+
InstVecMap[&I] = InstVector;
243+
return InstVector;
276244
}
277245

278246
// ==----------------------------------------------------------------------===//
@@ -695,25 +663,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
695663
Emb->getFunctionVector().print(OS);
696664

697665
OS << "Basic block vectors:\n";
698-
const auto &BBMap = Emb->getBBVecMap();
699666
for (const BasicBlock &BB : F) {
700-
auto It = BBMap.find(&BB);
701-
if (It != BBMap.end()) {
702-
OS << "Basic block: " << BB.getName() << ":\n";
703-
It->second.print(OS);
704-
}
667+
OS << "Basic block: " << BB.getName() << ":\n";
668+
Emb->getBBVector(BB).print(OS);
705669
}
706670

707671
OS << "Instruction vectors:\n";
708-
const auto &InstMap = Emb->getInstVecMap();
709672
for (const BasicBlock &BB : F) {
710673
for (const Instruction &I : BB) {
711-
auto It = InstMap.find(&I);
712-
if (It != InstMap.end()) {
713-
OS << "Instruction: ";
714-
I.print(OS);
715-
It->second.print(OS);
716-
}
674+
OS << "Instruction: ";
675+
I.print(OS);
676+
Emb->getInstVector(I).print(OS);
717677
}
718678
}
719679
}

llvm/test/Analysis/IR2Vec/unreachable.ll

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,17 @@ return: ; preds = %if.else, %if.then
3030
%4 = load i32, ptr %retval, align 4
3131
ret i32 %4
3232
}
33-
34-
; CHECK: Basic block vectors:
33+
; We'll get individual basic block embeddings for all blocks in the function.
34+
; But unreachable blocks are not counted for computing the function embedding.
35+
; CHECK: Function vector: [ 1301.20 1318.20 1335.20 ]
36+
; CHECK-NEXT: Basic block vectors:
3537
; CHECK-NEXT: Basic block: entry:
3638
; CHECK-NEXT: [ 816.20 825.20 834.20 ]
3739
; CHECK-NEXT: Basic block: if.then:
3840
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
3941
; CHECK-NEXT: Basic block: if.else:
4042
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
43+
; CHECK-NEXT: Basic block: unreachable:
44+
; CHECK-NEXT: [ 101.00 103.00 105.00 ]
4145
; CHECK-NEXT: Basic block: return:
4246
; CHECK-NEXT: [ 95.00 97.00 99.00 ]

llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,25 +253,17 @@ class IR2VecTool {
253253
break;
254254
}
255255
case BasicBlockLevel: {
256-
const auto &BBVecMap = Emb->getBBVecMap();
257256
for (const BasicBlock &BB : F) {
258-
auto It = BBVecMap.find(&BB);
259-
if (It != BBVecMap.end()) {
260-
OS << BB.getName() << ":";
261-
It->second.print(OS);
262-
}
257+
OS << BB.getName() << ":";
258+
Emb->getBBVector(BB).print(OS);
263259
}
264260
break;
265261
}
266262
case InstructionLevel: {
267-
const auto &InstMap = Emb->getInstVecMap();
268263
for (const BasicBlock &BB : F) {
269264
for (const Instruction &I : BB) {
270-
auto It = InstMap.find(&I);
271-
if (It != InstMap.end()) {
272-
I.print(OS);
273-
It->second.print(OS);
274-
}
265+
I.print(OS);
266+
Emb->getInstVector(I).print(OS);
275267
}
276268
}
277269
break;

0 commit comments

Comments
 (0)