diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 295b6d33525d9..688535161d4b9 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -200,6 +200,8 @@ void Embedder::computeEmbeddings() const { if (F.isDeclaration()) return; + FuncVector = Embedding(Dimension, 0.0); + // Consider only the basic blocks that are reachable from entry for (const BasicBlock *BB : depth_first(&F)) { computeEmbeddings(*BB); diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index d136cb6a316b1..40b4aa21f2b46 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -430,6 +430,60 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) { EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 58.1))); } +TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) { + auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); + ASSERT_TRUE(static_cast(Emb)); + + // Get initial function vector + const auto &FuncVec1 = Emb->getFunctionVector(); + + // Compute embeddings again by calling getFunctionVector multiple times + const auto &FuncVec2 = Emb->getFunctionVector(); + const auto &FuncVec3 = Emb->getFunctionVector(); + + // All function vectors should be identical + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2)); + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); + EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); + + // Also check that instruction vectors remain consistent + const auto &InstMap1 = Emb->getInstVecMap(); + const auto &InstMap2 = Emb->getInstVecMap(); + + EXPECT_EQ(InstMap1.size(), InstMap2.size()); + for (const auto &[Inst, Vec1] : InstMap1) { + ASSERT_TRUE(InstMap2.count(Inst)); + EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); + } +} + +TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) { + auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V); + ASSERT_TRUE(static_cast(Emb)); + + // Get initial function vector + const auto &FuncVec1 = Emb->getFunctionVector(); + + // Compute embeddings again by calling getFunctionVector multiple times + const auto &FuncVec2 = Emb->getFunctionVector(); + const auto &FuncVec3 = Emb->getFunctionVector(); + + // All function vectors should be identical + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2)); + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); + EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); + + // Also check that instruction vectors remain consistent + const auto &InstMap1 = Emb->getInstVecMap(); + const auto &InstMap2 = Emb->getInstVecMap(); + + EXPECT_EQ(InstMap1.size(), InstMap2.size()); + for (const auto &[Inst, Vec1] : InstMap1) { + ASSERT_TRUE(InstMap2.count(Inst)); + EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); + } +} + static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes; [[maybe_unused]] static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;