Skip to content

Commit d82d509

Browse files
committed
Restrict caching
1 parent e4c6990 commit d82d509

File tree

6 files changed

+177
-204
lines changed

6 files changed

+177
-204
lines changed

llvm/docs/MLGO.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
508508

509509
.. code-block:: c++
510510

511-
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
511+
ir2vec::Embedding FuncVector = Emb->getFunctionVector();
512512

513513
Currently, ``Embedder`` can generate embeddings at three levels: Instructions,
514514
Basic Blocks, and Functions. Appropriate getters are provided to access the

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -533,21 +533,18 @@ class Embedder {
533533
/// in the IR instructions to generate the vector representation.
534534
const float OpcWeight, TypeWeight, ArgWeight;
535535

536-
// Utility maps - these are used to store the vector representations of
537-
// instructions, basic blocks and functions.
538-
mutable Embedding FuncVector;
539-
mutable BBEmbeddingsMap BBVecMap;
540-
mutable InstEmbeddingsMap InstVecMap;
541-
542536
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
543537

544538
/// Function to compute embeddings. It generates embeddings for all
545539
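/// the non-debug instructions and basic blocks of the function F that are
/// reachable from the entry block, and returns the sum of the basic block
/// embeddings as the function embedding.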
546-
void computeEmbeddings() const;
540+
Embedding computeEmbeddings() const;
547541

548542
/// Function to compute the embedding for a given basic block.
549-
/// Specific to the kind of embeddings being computed.
550-
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
543+
Embedding computeEmbeddings(const BasicBlock &BB) const;
544+
545+
/// Function to compute the embedding for a given instruction. Specific to the
546+
/// kind of embeddings being computed.
547+
virtual Embedding computeEmbeddings(const Instruction &I) const = 0;
551548

552549
public:
553550
virtual ~Embedder() = default;
@@ -556,31 +553,29 @@ class Embedder {
556553
LLVM_ABI static std::unique_ptr<Embedder>
557554
create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);
558555

559-
/// Returns a map containing instructions and the corresponding embeddings for
560-
/// the function F if it has been computed. If not, it computes the embeddings
561-
/// for the function and returns the map.
562-
LLVM_ABI const InstEmbeddingsMap &getInstVecMap() const;
556+
/// Returns the embedding for a given instruction in the function F.
557+
LLVM_ABI Embedding getInstVector(const Instruction &I) const;
563558

564-
/// Returns a map containing basic block and the corresponding embeddings for
565-
/// the function F if it has been computed. If not, it computes the embeddings
566-
/// for the function and returns the map.
567-
LLVM_ABI const BBEmbeddingsMap &getBBVecMap() const;
559+
/// Returns the embedding for a given basic block in the function F.
560+
LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const;
568561

569-
/// Returns the embedding for a given basic block in the function F if it has
570-
/// been computed. If not, it computes the embedding for the basic block and
571-
/// returns it.
572-
LLVM_ABI const Embedding &getBBVector(const BasicBlock &BB) const;
562+
/// Returns the embedding for the current function.
563+
LLVM_ABI Embedding getFunctionVector() const;
573564

574-
/// Computes and returns the embedding for the current function.
575-
LLVM_ABI const Embedding &getFunctionVector() const;
565+
/// Invalidate embeddings if cached. The embeddings may not be relevant
566+
/// anymore when the IR changes due to transformations. In such cases, the
567+
/// cached embeddings should be invalidated to ensure
568+
/// correctness/recomputation. This is a no-op for SymbolicEmbedder but
569+
/// removes all the cached entries in FlowAwareEmbedder.
570+
virtual void invalidateEmbeddings() {}
576571
};
577572

578573
/// Class for computing the Symbolic embeddings of IR2Vec.
579574
/// Symbolic embeddings are constructed based on the entity-level
580575
/// representations obtained from the Vocabulary.
581576
class LLVM_ABI SymbolicEmbedder : public Embedder {
582577
private:
583-
void computeEmbeddings(const BasicBlock &BB) const override;
578+
Embedding computeEmbeddings(const Instruction &I) const override;
584579

585580
public:
586581
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
@@ -592,11 +587,17 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
592587
/// embeddings, and additionally capture the flow information in the IR.
593588
class LLVM_ABI FlowAwareEmbedder : public Embedder {
594589
private:
595-
void computeEmbeddings(const BasicBlock &BB) const override;
590+
// Utility map for caching - needed for flow-aware dependencies
591+
mutable InstEmbeddingsMap InstVecMap;
592+
593+
Embedding computeEmbeddings(const Instruction &I) const override;
596594

597595
public:
598596
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
599597
: Embedder(F, Vocab) {}
598+
599+
/// Override to invalidate all cached instruction embeddings
600+
void invalidateEmbeddings() override;
600601
};
601602

602603
} // namespace ir2vec

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 78 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ void Embedding::print(raw_ostream &OS) const {
155155

156156
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
157157
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
158-
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
159-
FuncVector(Embedding(Dimension)) {}
158+
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
159+
}
160160

161161
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
162162
const Vocabulary &Vocab) {
@@ -169,112 +169,104 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
169169
return nullptr;
170170
}
171171

172-
const InstEmbeddingsMap &Embedder::getInstVecMap() const {
173-
if (InstVecMap.empty())
174-
computeEmbeddings();
175-
return InstVecMap;
172+
Embedding Embedder::getInstVector(const Instruction &I) const {
173+
return computeEmbeddings(I);
176174
}
177175

178-
const BBEmbeddingsMap &Embedder::getBBVecMap() const {
179-
if (BBVecMap.empty())
180-
computeEmbeddings();
181-
return BBVecMap;
176+
Embedding Embedder::getBBVector(const BasicBlock &BB) const {
177+
return computeEmbeddings(BB);
182178
}
183179

184-
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
185-
auto It = BBVecMap.find(&BB);
186-
if (It != BBVecMap.end())
187-
return It->second;
188-
computeEmbeddings(BB);
189-
return BBVecMap[&BB];
190-
}
191-
192-
const Embedding &Embedder::getFunctionVector() const {
180+
Embedding Embedder::getFunctionVector() const {
193181
// Currently, we always (re)compute the embeddings for the function.
194182
// This is cheaper than caching the vector.
195-
computeEmbeddings();
196-
return FuncVector;
183+
return computeEmbeddings();
197184
}
198185

199-
void Embedder::computeEmbeddings() const {
186+
Embedding Embedder::computeEmbeddings() const {
200187
if (F.isDeclaration())
201-
return;
188+
return Embedding(Dimension, 0.0);
202189

203-
FuncVector = Embedding(Dimension, 0.0);
190+
Embedding FuncVector(Dimension, 0.0);
204191

205192
// Consider only the basic blocks that are reachable from entry
206193
for (const BasicBlock *BB : depth_first(&F)) {
207-
computeEmbeddings(*BB);
208-
FuncVector += BBVecMap[BB];
194+
FuncVector += computeEmbeddings(*BB);
209195
}
196+
return FuncVector;
210197
}
211198

212-
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
199+
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
213200
Embedding BBVector(Dimension, 0);
214201

215202
// We consider only the non-debug and non-pseudo instructions
216203
for (const auto &I : BB.instructionsWithoutDebug()) {
217-
Embedding ArgEmb(Dimension, 0);
218-
for (const auto &Op : I.operands())
219-
ArgEmb += Vocab[*Op];
220-
auto InstVector =
221-
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
222-
if (const auto *IC = dyn_cast<CmpInst>(&I))
223-
InstVector += Vocab[IC->getPredicate()];
224-
InstVecMap[&I] = InstVector;
225-
BBVector += InstVector;
204+
BBVector += computeEmbeddings(I);
226205
}
227-
BBVecMap[&BB] = BBVector;
206+
return BBVector;
228207
}
229208

230-
void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
231-
Embedding BBVector(Dimension, 0);
209+
Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
210+
Embedding ArgEmb(Dimension, 0);
211+
for (const auto &Op : I.operands())
212+
ArgEmb += Vocab[*Op];
213+
auto InstVector =
214+
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
215+
if (const auto *IC = dyn_cast<CmpInst>(&I))
216+
InstVector += Vocab[IC->getPredicate()];
217+
return InstVector;
218+
}
232219

233-
// We consider only the non-debug and non-pseudo instructions
234-
for (const auto &I : BB.instructionsWithoutDebug()) {
235-
// TODO: Handle call instructions differently.
236-
// For now, we treat them like other instructions
237-
Embedding ArgEmb(Dimension, 0);
238-
for (const auto &Op : I.operands()) {
239-
// If the operand is defined elsewhere, we use its embedding
240-
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241-
auto DefIt = InstVecMap.find(DefInst);
242-
// Fixme (#159171): Ideally we should never miss an instruction
243-
// embedding here.
244-
// But when we have cyclic dependencies (e.g., phi
245-
// nodes), we might miss the embedding. In such cases, we fall back to
246-
// using the vocabulary embedding. This can be fixed by iterating to a
247-
// fixed-point, or by using a simple solver for the set of simultaneous
248-
// equations.
249-
// Another case when we might miss an instruction embedding is when
250-
// the operand instruction is in a different basic block that has not
251-
// been processed yet. This can be fixed by processing the basic blocks
252-
// in a topological order.
253-
if (DefIt != InstVecMap.end())
254-
ArgEmb += DefIt->second;
255-
else
256-
ArgEmb += Vocab[*Op];
257-
}
258-
// If the operand is not defined by an instruction, we use the vocabulary
259-
else {
260-
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
261-
<< *Op << "=" << Vocab[*Op][0] << "\n");
220+
Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
221+
// If we have already computed the embedding for this instruction, return it
222+
auto It = InstVecMap.find(&I);
223+
if (It != InstVecMap.end())
224+
return It->second;
225+
226+
// TODO: Handle call instructions differently.
227+
// For now, we treat them like other instructions
228+
Embedding ArgEmb(Dimension, 0);
229+
for (const auto &Op : I.operands()) {
230+
// If the operand is defined elsewhere, we use its embedding
231+
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
232+
auto DefIt = InstVecMap.find(DefInst);
233+
// FIXME (#159171): Ideally we should never miss an instruction
234+
// embedding here.
235+
// But when we have cyclic dependencies (e.g., phi
236+
// nodes), we might miss the embedding. In such cases, we fall back to
237+
// using the vocabulary embedding. This can be fixed by iterating to a
238+
// fixed-point, or by using a simple solver for the set of simultaneous
239+
// equations.
240+
// Another case when we might miss an instruction embedding is when
241+
// the operand instruction is in a different basic block that has not
242+
// been processed yet. This can be fixed by processing the basic blocks
243+
// in a topological order.
244+
if (DefIt != InstVecMap.end())
245+
ArgEmb += DefIt->second;
246+
else
262247
ArgEmb += Vocab[*Op];
263-
}
264248
}
265-
// Create the instruction vector by combining opcode, type, and arguments
266-
// embeddings
267-
auto InstVector =
268-
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
269-
// Add compare predicate embedding as an additional operand if applicable
270-
if (const auto *IC = dyn_cast<CmpInst>(&I))
271-
InstVector += Vocab[IC->getPredicate()];
272-
InstVecMap[&I] = InstVector;
273-
BBVector += InstVector;
249+
// If the operand is not defined by an instruction, we use the
250+
// vocabulary
251+
else {
252+
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
253+
<< *Op << "=" << Vocab[*Op][0] << "\n");
254+
ArgEmb += Vocab[*Op];
255+
}
274256
}
275-
BBVecMap[&BB] = BBVector;
257+
// Create the instruction vector by combining opcode, type, and arguments
258+
// embeddings
259+
auto InstVector =
260+
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
261+
// Add compare predicate embedding as an additional operand if applicable
262+
if (const auto *IC = dyn_cast<CmpInst>(&I))
263+
InstVector += Vocab[IC->getPredicate()];
264+
InstVecMap[&I] = InstVector;
265+
return InstVector;
276266
}
277267

268+
void FlowAwareEmbedder::invalidateEmbeddings() { InstVecMap.clear(); }
269+
278270
//===----------------------------------------------------------------------===//
279271
// VocabStorage
280272
//===----------------------------------------------------------------------===//
@@ -695,25 +687,19 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
695687
Emb->getFunctionVector().print(OS);
696688

697689
OS << "Basic block vectors:\n";
698-
const auto &BBMap = Emb->getBBVecMap();
699690
for (const BasicBlock &BB : F) {
700-
auto It = BBMap.find(&BB);
701-
if (It != BBMap.end()) {
702-
OS << "Basic block: " << BB.getName() << ":\n";
703-
It->second.print(OS);
704-
}
691+
auto BBVector = Emb->getBBVector(BB);
692+
OS << "Basic block: " << BB.getName() << ":\n";
693+
BBVector.print(OS);
705694
}
706695

707696
OS << "Instruction vectors:\n";
708-
const auto &InstMap = Emb->getInstVecMap();
709697
for (const BasicBlock &BB : F) {
710698
for (const Instruction &I : BB) {
711-
auto It = InstMap.find(&I);
712-
if (It != InstMap.end()) {
713-
OS << "Instruction: ";
714-
I.print(OS);
715-
It->second.print(OS);
716-
}
699+
auto InstVector = Emb->getInstVector(I);
700+
OS << "Instruction: ";
701+
I.print(OS);
702+
InstVector.print(OS);
717703
}
718704
}
719705
}

llvm/test/Analysis/IR2Vec/unreachable.ll

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,4 @@ return: ; preds = %if.else, %if.then
3131
ret i32 %4
3232
}
3333

34-
; CHECK: Basic block vectors:
35-
; CHECK-NEXT: Basic block: entry:
36-
; CHECK-NEXT: [ 816.20 825.20 834.20 ]
37-
; CHECK-NEXT: Basic block: if.then:
38-
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
39-
; CHECK-NEXT: Basic block: if.else:
40-
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
41-
; CHECK-NEXT: Basic block: return:
42-
; CHECK-NEXT: [ 95.00 97.00 99.00 ]
34+
; CHECK: Function vector: [ 1301.20 1318.20 1335.20 ]

llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,25 +253,17 @@ class IR2VecTool {
253253
break;
254254
}
255255
case BasicBlockLevel: {
256-
const auto &BBVecMap = Emb->getBBVecMap();
257256
for (const BasicBlock &BB : F) {
258-
auto It = BBVecMap.find(&BB);
259-
if (It != BBVecMap.end()) {
260-
OS << BB.getName() << ":";
261-
It->second.print(OS);
262-
}
257+
OS << BB.getName() << ":";
258+
Emb->getBBVector(BB).print(OS);
263259
}
264260
break;
265261
}
266262
case InstructionLevel: {
267-
const auto &InstMap = Emb->getInstVecMap();
268263
for (const BasicBlock &BB : F) {
269264
for (const Instruction &I : BB) {
270-
auto It = InstMap.find(&I);
271-
if (It != InstMap.end()) {
272-
I.print(OS);
273-
It->second.print(OS);
274-
}
265+
I.print(OS);
266+
Emb->getInstVector(I).print(OS);
275267
}
276268
}
277269
break;

0 commit comments

Comments
 (0)