Skip to content

Commit 9e0d3bc

Browse files
authored
[IR2Vec] Restrict caching only to Flow-Aware computation (#162559)
Removed all the caching maps (BB, Inst) in `Embedder` as we don't want to cache embeddings in general. Our earlier experiments on Symbolic embeddings show recomputation of embeddings is cheaper than cache lookups. OTOH, Flow-Aware embeddings would benefit from instruction level caching, as computing the embedding for an instruction would depend on the embeddings of other instructions in a function. So, retained instruction embedding caching logic only for Flow-Aware computation. This also necessitates an `invalidate` method that would clean up the cache when the embeddings would become invalid due to transformations.
1 parent e9205ca commit 9e0d3bc

File tree

6 files changed

+139
-220
lines changed

6 files changed

+139
-220
lines changed

llvm/docs/MLGO.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
508508

509509
.. code-block:: c++
510510

511-
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
511+
ir2vec::Embedding FuncVector = Emb->getFunctionVector();
512512

513513
Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
514514
Basic Blocks, and Functions. Appropriate getters are provided to access the

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -533,21 +533,20 @@ class Embedder {
533533
/// in the IR instructions to generate the vector representation.
534534
const float OpcWeight, TypeWeight, ArgWeight;
535535

536-
// Utility maps - these are used to store the vector representations of
537-
// instructions, basic blocks and functions.
538-
mutable Embedding FuncVector;
539-
mutable BBEmbeddingsMap BBVecMap;
540-
mutable InstEmbeddingsMap InstVecMap;
541-
542-
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
536+
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
537+
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
538+
OpcWeight(ir2vec::OpcWeight), TypeWeight(ir2vec::TypeWeight),
539+
ArgWeight(ir2vec::ArgWeight) {}
543540

544-
/// Function to compute embeddings. It generates embeddings for all
545-
/// the instructions and basic blocks in the function F.
546-
void computeEmbeddings() const;
541+
/// Function to compute embeddings.
542+
Embedding computeEmbeddings() const;
547543

548544
/// Function to compute the embedding for a given basic block.
545+
Embedding computeEmbeddings(const BasicBlock &BB) const;
546+
547+
/// Function to compute the embedding for a given instruction.
549548
/// Specific to the kind of embeddings being computed.
550-
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
549+
virtual Embedding computeEmbeddings(const Instruction &I) const = 0;
551550

552551
public:
553552
virtual ~Embedder() = default;
@@ -556,31 +555,35 @@ class Embedder {
556555
LLVM_ABI static std::unique_ptr<Embedder>
557556
create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);
558557

559-
/// Returns a map containing instructions and the corresponding embeddings for
560-
/// the function F if it has been computed. If not, it computes the embeddings
561-
/// for the function and returns the map.
562-
LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const;
563-
564-
/// Returns a map containing basic block and the corresponding embeddings for
565-
/// the function F if it has been computed. If not, it computes the embeddings
566-
/// for the function and returns the map.
567-
LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const;
558+
/// Computes and returns the embedding for a given instruction in the function
559+
/// F
560+
LLVM_ABI Embedding getInstVector(const Instruction &I) const {
561+
return computeEmbeddings(I);
562+
}
568563

569-
/// Returns the embedding for a given basic block in the function F if it has
570-
/// been computed. If not, it computes the embedding for the basic block and
571-
/// returns it.
572-
LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const;
564+
/// Computes and returns the embedding for a given basic block in the function
565+
/// F
566+
LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const {
567+
return computeEmbeddings(BB);
568+
}
573569

574570
/// Computes and returns the embedding for the current function.
575-
LLVM_ABI const Embedding &getFunctionVector() const;
571+
LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }
572+
573+
/// Invalidate embeddings if cached. The embeddings may not be relevant
574+
/// anymore when the IR changes due to transformations. In such cases, the
575+
/// cached embeddings should be invalidated to ensure
576+
/// correctness/recomputation. This is a no-op for SymbolicEmbedder but
577+
/// removes all the cached entries in FlowAwareEmbedder.
578+
virtual void invalidateEmbeddings() { return; }
576579
};
577580

578581
/// Class for computing the Symbolic embeddings of IR2Vec.
579582
/// Symbolic embeddings are constructed based on the entity-level
580583
/// representations obtained from the Vocabulary.
581584
class LLVM_ABI SymbolicEmbedder : public Embedder {
582585
private:
583-
void computeEmbeddings(const BasicBlock &BB) const override;
586+
Embedding computeEmbeddings(const Instruction &I) const override;
584587

585588
public:
586589
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
@@ -592,11 +595,15 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
592595
/// embeddings, and additionally capture the flow information in the IR.
593596
class LLVM_ABI FlowAwareEmbedder : public Embedder {
594597
private:
595-
void computeEmbeddings(const BasicBlock &BB) const override;
598+
// FlowAware embeddings would benefit from caching instruction embeddings as
599+
// they are reused while computing the embeddings of other instructions.
600+
mutable InstEmbeddingsMap InstVecMap;
601+
Embedding computeEmbeddings(const Instruction &I) const override;
596602

597603
public:
598604
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
599605
: Embedder(F, Vocab) {}
606+
void invalidateEmbeddings() override { InstVecMap.clear(); }
600607
};
601608

602609
} // namespace ir2vec

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 71 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
153153
// Embedder and its subclasses
154154
//===----------------------------------------------------------------------===//
155155

156-
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
157-
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
158-
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
159-
FuncVector(Embedding(Dimension)) {}
160-
161156
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
162157
const Vocabulary &Vocab) {
163158
switch (Mode) {
@@ -169,110 +164,85 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
169164
return nullptr;
170165
}
171166

172-
const InstEmbeddingsMap &Embedder::getInstVecMap() const {
173-
if (InstVecMap.empty())
174-
computeEmbeddings();
175-
return InstVecMap;
176-
}
177-
178-
const BBEmbeddingsMap &Embedder::getBBVecMap() const {
179-
if (BBVecMap.empty())
180-
computeEmbeddings();
181-
return BBVecMap;
182-
}
183-
184-
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
185-
auto It = BBVecMap.find(&BB);
186-
if (It != BBVecMap.end())
187-
return It->second;
188-
computeEmbeddings(BB);
189-
return BBVecMap[&BB];
190-
}
167+
Embedding Embedder::computeEmbeddings() const {
168+
Embedding FuncVector(Dimension, 0.0);
191169

192-
const Embedding &Embedder::getFunctionVector() const {
193-
// Currently, we always (re)compute the embeddings for the function.
194-
// This is cheaper than caching the vector.
195-
computeEmbeddings();
196-
return FuncVector;
197-
}
198-
199-
void Embedder::computeEmbeddings() const {
200170
if (F.isDeclaration())
201-
return;
202-
203-
FuncVector = Embedding(Dimension, 0.0);
171+
return FuncVector;
204172

205173
// Consider only the basic blocks that are reachable from entry
206-
for (const BasicBlock *BB : depth_first(&F)) {
207-
computeEmbeddings(*BB);
208-
FuncVector += BBVecMap[BB];
209-
}
174+
for (const BasicBlock *BB : depth_first(&F))
175+
FuncVector += computeEmbeddings(*BB);
176+
return FuncVector;
210177
}
211178

212-
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
179+
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
213180
Embedding BBVector(Dimension, 0);
214181

215182
// We consider only the non-debug and non-pseudo instructions
216-
for (const auto &I : BB.instructionsWithoutDebug()) {
217-
Embedding ArgEmb(Dimension, 0);
218-
for (const auto &Op : I.operands())
219-
ArgEmb += Vocab[*Op];
220-
auto InstVector =
221-
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
222-
if (const auto *IC = dyn_cast<CmpInst>(&I))
223-
InstVector += Vocab[IC->getPredicate()];
224-
InstVecMap[&I] = InstVector;
225-
BBVector += InstVector;
226-
}
227-
BBVecMap[&BB] = BBVector;
228-
}
229-
230-
void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
231-
Embedding BBVector(Dimension, 0);
183+
for (const auto &I : BB.instructionsWithoutDebug())
184+
BBVector += computeEmbeddings(I);
185+
return BBVector;
186+
}
187+
188+
Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
189+
// Currently, we always (re)compute the embeddings for symbolic embedder.
190+
// This is cheaper than caching the vectors.
191+
Embedding ArgEmb(Dimension, 0);
192+
for (const auto &Op : I.operands())
193+
ArgEmb += Vocab[*Op];
194+
auto InstVector =
195+
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
196+
if (const auto *IC = dyn_cast<CmpInst>(&I))
197+
InstVector += Vocab[IC->getPredicate()];
198+
return InstVector;
199+
}
200+
201+
Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
202+
// If we have already computed the embedding for this instruction, return it
203+
auto It = InstVecMap.find(&I);
204+
if (It != InstVecMap.end())
205+
return It->second;
232206

233-
// We consider only the non-debug and non-pseudo instructions
234-
for (const auto &I : BB.instructionsWithoutDebug()) {
235-
// TODO: Handle call instructions differently.
236-
// For now, we treat them like other instructions
237-
Embedding ArgEmb(Dimension, 0);
238-
for (const auto &Op : I.operands()) {
239-
// If the operand is defined elsewhere, we use its embedding
240-
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241-
auto DefIt = InstVecMap.find(DefInst);
242-
// Fixme (#159171): Ideally we should never miss an instruction
243-
// embedding here.
244-
// But when we have cyclic dependencies (e.g., phi
245-
// nodes), we might miss the embedding. In such cases, we fall back to
246-
// using the vocabulary embedding. This can be fixed by iterating to a
247-
// fixed-point, or by using a simple solver for the set of simultaneous
248-
// equations.
249-
// Another case when we might miss an instruction embedding is when
250-
// the operand instruction is in a different basic block that has not
251-
// been processed yet. This can be fixed by processing the basic blocks
252-
// in a topological order.
253-
if (DefIt != InstVecMap.end())
254-
ArgEmb += DefIt->second;
255-
else
256-
ArgEmb += Vocab[*Op];
257-
}
258-
// If the operand is not defined by an instruction, we use the vocabulary
259-
else {
260-
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
261-
<< *Op << "=" << Vocab[*Op][0] << "\n");
207+
// TODO: Handle call instructions differently.
208+
// For now, we treat them like other instructions
209+
Embedding ArgEmb(Dimension, 0);
210+
for (const auto &Op : I.operands()) {
211+
// If the operand is defined elsewhere, we use its embedding
212+
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
213+
auto DefIt = InstVecMap.find(DefInst);
214+
// Fixme (#159171): Ideally we should never miss an instruction
215+
// embedding here.
216+
// But when we have cyclic dependencies (e.g., phi
217+
// nodes), we might miss the embedding. In such cases, we fall back to
218+
// using the vocabulary embedding. This can be fixed by iterating to a
219+
// fixed-point, or by using a simple solver for the set of simultaneous
220+
// equations.
221+
// Another case when we might miss an instruction embedding is when
222+
// the operand instruction is in a different basic block that has not
223+
// been processed yet. This can be fixed by processing the basic blocks
224+
// in a topological order.
225+
if (DefIt != InstVecMap.end())
226+
ArgEmb += DefIt->second;
227+
else
262228
ArgEmb += Vocab[*Op];
263-
}
264229
}
265-
// Create the instruction vector by combining opcode, type, and arguments
266-
// embeddings
267-
auto InstVector =
268-
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
269-
// Add compare predicate embedding as an additional operand if applicable
270-
if (const auto *IC = dyn_cast<CmpInst>(&I))
271-
InstVector += Vocab[IC->getPredicate()];
272-
InstVecMap[&I] = InstVector;
273-
BBVector += InstVector;
230+
// If the operand is not defined by an instruction, we use the
231+
// vocabulary
232+
else {
233+
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
234+
<< *Op << "=" << Vocab[*Op][0] << "\n");
235+
ArgEmb += Vocab[*Op];
236+
}
274237
}
275-
BBVecMap[&BB] = BBVector;
238+
// Create the instruction vector by combining opcode, type, and arguments
239+
// embeddings
240+
auto InstVector =
241+
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
242+
if (const auto *IC = dyn_cast<CmpInst>(&I))
243+
InstVector += Vocab[IC->getPredicate()];
244+
InstVecMap[&I] = InstVector;
245+
return InstVector;
276246
}
277247

278248
// ==----------------------------------------------------------------------===//
@@ -695,25 +665,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
695665
Emb->getFunctionVector().print(OS);
696666

697667
OS << "Basic block vectors:\n";
698-
const auto &BBMap = Emb->getBBVecMap();
699668
for (const BasicBlock &BB : F) {
700-
auto It = BBMap.find(&BB);
701-
if (It != BBMap.end()) {
702-
OS << "Basic block: " << BB.getName() << ":\n";
703-
It->second.print(OS);
704-
}
669+
OS << "Basic block: " << BB.getName() << ":\n";
670+
Emb->getBBVector(BB).print(OS);
705671
}
706672

707673
OS << "Instruction vectors:\n";
708-
const auto &InstMap = Emb->getInstVecMap();
709674
for (const BasicBlock &BB : F) {
710675
for (const Instruction &I : BB) {
711-
auto It = InstMap.find(&I);
712-
if (It != InstMap.end()) {
713-
OS << "Instruction: ";
714-
I.print(OS);
715-
It->second.print(OS);
716-
}
676+
OS << "Instruction: ";
677+
I.print(OS);
678+
Emb->getInstVector(I).print(OS);
717679
}
718680
}
719681
}

llvm/test/Analysis/IR2Vec/unreachable.ll

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,17 @@ return: ; preds = %if.else, %if.then
3030
%4 = load i32, ptr %retval, align 4
3131
ret i32 %4
3232
}
33-
34-
; CHECK: Basic block vectors:
33+
; We'll get individual basic block embeddings for all blocks in the function.
34+
; But unreachable blocks are not counted for computing the function embedding.
35+
; CHECK: Function vector: [ 1301.20 1318.20 1335.20 ]
36+
; CHECK-NEXT: Basic block vectors:
3537
; CHECK-NEXT: Basic block: entry:
3638
; CHECK-NEXT: [ 816.20 825.20 834.20 ]
3739
; CHECK-NEXT: Basic block: if.then:
3840
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
3941
; CHECK-NEXT: Basic block: if.else:
4042
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
43+
; CHECK-NEXT: Basic block: unreachable:
44+
; CHECK-NEXT: [ 101.00 103.00 105.00 ]
4145
; CHECK-NEXT: Basic block: return:
4246
; CHECK-NEXT: [ 95.00 97.00 99.00 ]

llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,25 +253,17 @@ class IR2VecTool {
253253
break;
254254
}
255255
case BasicBlockLevel: {
256-
const auto &BBVecMap = Emb->getBBVecMap();
257256
for (const BasicBlock &BB : F) {
258-
auto It = BBVecMap.find(&BB);
259-
if (It != BBVecMap.end()) {
260-
OS << BB.getName() << ":";
261-
It->second.print(OS);
262-
}
257+
OS << BB.getName() << ":";
258+
Emb->getBBVector(BB).print(OS);
263259
}
264260
break;
265261
}
266262
case InstructionLevel: {
267-
const auto &InstMap = Emb->getInstVecMap();
268263
for (const BasicBlock &BB : F) {
269264
for (const Instruction &I : BB) {
270-
auto It = InstMap.find(&I);
271-
if (It != InstMap.end()) {
272-
I.print(OS);
273-
It->second.print(OS);
274-
}
265+
I.print(OS);
266+
Emb->getInstVector(I).print(OS);
275267
}
276268
}
277269
break;

0 commit comments

Comments
 (0)