Skip to content

Commit d3f8c84

Browse files
committed
Added invalidate()
1 parent db4eb0d commit d3f8c84

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,13 @@ class Embedder {
569569

570570
/// Computes and returns the embedding for the current function.
571571
LLVM_ABI Embedding getFunctionVector() const { return computeEmbeddings(); }
572+
573+
/// Invalidate embeddings if cached. The embeddings may not be relevant
574+
/// anymore when the IR changes due to transformations. In such cases, the
575+
/// cached embeddings should be invalidated to ensure
576+
/// correctness/recomputation. This is a no-op for SymbolicEmbedder but
577+
/// removes all the cached entries in FlowAwareEmbedder.
578+
virtual void invalidateEmbeddings() { return; }
572579
};
573580

574581
/// Class for computing the Symbolic embeddings of IR2Vec.
@@ -596,6 +603,7 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
596603
public:
597604
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
598605
: Embedder(F, Vocab) {}
606+
void invalidateEmbeddings() override { InstVecMap.clear(); }
599607
};
600608

601609
} // namespace ir2vec

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
165165
}
166166

167167
Embedding Embedder::computeEmbeddings() const {
168-
Embedding FuncVector(Dimension, 0);
168+
Embedding FuncVector(Dimension, 0.0);
169169

170170
if (F.isDeclaration())
171171
return FuncVector;
@@ -178,7 +178,9 @@ Embedding Embedder::computeEmbeddings() const {
178178

179179
Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
180180
Embedding BBVector(Dimension, 0);
181-
for (const Instruction &I : BB.instructionsWithoutDebug())
181+
182+
// We consider only the non-debug and non-pseudo instructions
183+
for (const auto &I : BB.instructionsWithoutDebug())
182184
BBVector += computeEmbeddings(I);
183185
return BBVector;
184186
}

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,10 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) {
409409
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
410410
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
411411
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
412+
413+
Emb->invalidateEmbeddings();
414+
const auto &FuncVec4 = Emb->getFunctionVector();
415+
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
412416
}
413417

414418
TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
@@ -426,6 +430,10 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
426430
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
427431
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
428432
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
433+
434+
Emb->invalidateEmbeddings();
435+
const auto &FuncVec4 = Emb->getFunctionVector();
436+
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
429437
}
430438

431439
static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;

0 commit comments

Comments
 (0)