Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ void Embedder::computeEmbeddings() const {
if (F.isDeclaration())
return;

FuncVector = Embedding(Dimension, 0.0);

// Consider only the basic blocks that are reachable from entry
for (const BasicBlock *BB : depth_first(&F)) {
computeEmbeddings(*BB);
Expand Down
54 changes: 54 additions & 0 deletions llvm/unittests/Analysis/IR2VecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,60 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 58.1)));
}

TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));

// Get initial function vector
const auto &FuncVec1 = Emb->getFunctionVector();

// Compute embeddings again by calling getFunctionVector multiple times
const auto &FuncVec2 = Emb->getFunctionVector();
const auto &FuncVec3 = Emb->getFunctionVector();

// All function vectors should be identical
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));

// Also check that instruction vectors remain consistent
const auto &InstMap1 = Emb->getInstVecMap();
const auto &InstMap2 = Emb->getInstVecMap();

EXPECT_EQ(InstMap1.size(), InstMap2.size());
for (const auto &[Inst, Vec1] : InstMap1) {
ASSERT_TRUE(InstMap2.count(Inst));
EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
}
}

TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));

// Get initial function vector
const auto &FuncVec1 = Emb->getFunctionVector();

// Compute embeddings again by calling getFunctionVector multiple times
const auto &FuncVec2 = Emb->getFunctionVector();
const auto &FuncVec3 = Emb->getFunctionVector();

// All function vectors should be identical
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec2));
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));

// Also check that instruction vectors remain consistent
const auto &InstMap1 = Emb->getInstVecMap();
const auto &InstMap2 = Emb->getInstVecMap();

EXPECT_EQ(InstMap1.size(), InstMap2.size());
for (const auto &[Inst, Vec1] : InstMap1) {
ASSERT_TRUE(InstMap2.count(Inst));
EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
}
}

static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
[[maybe_unused]]
static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
Expand Down
Loading