Skip to content

Commit c937cb5

Browse files
committed
Lazy BBEmbeddings
1 parent 893ef7f commit c937cb5

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class Embedder {
9191
/// the embeddings is specific to the kind of embeddings being computed.
9292
virtual void computeEmbeddings() const = 0;
9393

94+
/// Helper function to compute the embedding for a given basic block.
95+
/// Specific to the kind of embeddings being computed.
96+
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
97+
9498
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
9599
/// zero vector.
96100
Embedding lookupVocab(const std::string &Key) const;
@@ -121,6 +125,11 @@ class Embedder {
121125
/// for the function and returns the map.
122126
const BBEmbeddingsMap &getBBVecMap() const;
123127

128+
/// Returns the embedding for a given basic block in the function F if it has
129+
/// been computed. If not, it computes the embedding for the basic block and
130+
/// returns it.
131+
const Embedding &getBBVector(const BasicBlock &BB) const;
132+
124133
/// Computes and returns the embedding for the current function.
125134
const Embedding &getFunctionVector() const;
126135
};
@@ -130,16 +139,14 @@ class Embedder {
130139
/// representations obtained from the Vocabulary.
131140
class SymbolicEmbedder : public Embedder {
132141
private:
133-
/// Utility function to compute the embedding for a given basic block.
134-
Embedding computeBB2Vec(const BasicBlock &BB) const;
135-
136142
/// Utility function to compute the embedding for a given type.
137143
Embedding getTypeEmbedding(const Type *Ty) const;
138144

139145
/// Utility function to compute the embedding for a given operand.
140146
Embedding getOperandEmbedding(const Value *Op) const;
141147

142148
void computeEmbeddings() const override;
149+
void computeEmbeddings(const BasicBlock &BB) const override;
143150

144151
public:
145152
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ const BBEmbeddingsMap &Embedder::getBBVecMap() const {
115115
return BBVecMap;
116116
}
117117

118+
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
119+
auto It = BBVecMap.find(&BB);
120+
if (It == BBVecMap.end())
121+
computeEmbeddings(BB);
122+
return It->second;
123+
}
124+
118125
const Embedding &Embedder::getFunctionVector() const {
119126
// Currently, we always (re)compute the embeddings for the function.
120127
// This is cheaper than caching the vector.
@@ -151,17 +158,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
151158

152159
#undef RETURN_LOOKUP_IF
153160

154-
void SymbolicEmbedder::computeEmbeddings() const {
155-
if (F.isDeclaration())
156-
return;
157-
for (const auto &BB : F) {
158-
auto [It, WasInserted] = BBVecMap.try_emplace(&BB, computeBB2Vec(BB));
159-
assert(WasInserted && "Basic block already exists in the map");
160-
addVectors(FuncVector, It->second);
161-
}
162-
}
163-
164-
Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
161+
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
165162
Embedding BBVector(Dimension, 0);
166163

167164
for (const auto &I : BB) {
@@ -183,7 +180,16 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
183180
InstVecMap[&I] = InstVector;
184181
addVectors(BBVector, InstVector);
185182
}
186-
return BBVector;
183+
BBVecMap[&BB] = BBVector;
184+
}
185+
186+
void SymbolicEmbedder::computeEmbeddings() const {
187+
if (F.isDeclaration())
188+
return;
189+
for (const auto &BB : F) {
190+
computeEmbeddings(BB);
191+
addVectors(FuncVector, BBVecMap[&BB]);
192+
}
187193
}
188194

189195
// ==----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)