diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 17f41129fd4fa..3cfc206c94788 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -54,14 +54,19 @@ class IR2VecVocabAnalysis; /// of the IR entities. Flow-aware embeddings build on top of symbolic /// embeddings and additionally capture the flow information in the IR. /// IR2VecKind is used to specify the type of embeddings to generate. -/// Currently, only Symbolic embeddings are supported. -enum class IR2VecKind { Symbolic }; +/// Note: The implementation of FlowAware embeddings is not the same as the one +/// described in the paper. The current implementation is a simplified version +/// that captures the flow information (SSA-based use-defs) without tracing +/// through memory-level use-defs in the embedding computation described in the +/// paper. +enum class IR2VecKind { Symbolic, FlowAware }; namespace ir2vec { LLVM_ABI extern cl::opt OpcWeight; LLVM_ABI extern cl::opt TypeWeight; LLVM_ABI extern cl::opt ArgWeight; +LLVM_ABI extern cl::opt IR2VecEmbeddingKind; /// Embedding is a datatype that wraps std::vector. It provides /// additional functionality for arithmetic and comparison operations. @@ -257,9 +262,8 @@ class Embedder { LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab); /// Helper function to compute embeddings. It generates embeddings for all - /// the instructions and basic blocks in the function F. Logic of computing - /// the embeddings is specific to the kind of embeddings being computed. - virtual void computeEmbeddings() const = 0; + /// the instructions and basic blocks in the function F. + void computeEmbeddings() const; /// Helper function to compute the embedding for a given basic block. /// Specific to the kind of embeddings being computed. @@ -296,7 +300,6 @@ class Embedder { /// representations obtained from the Vocabulary. 
class LLVM_ABI SymbolicEmbedder : public Embedder { private: - void computeEmbeddings() const override; void computeEmbeddings(const BasicBlock &BB) const override; public: @@ -306,6 +309,20 @@ class LLVM_ABI SymbolicEmbedder : public Embedder { } }; +/// Class for computing the Flow-aware embeddings of IR2Vec. +/// Flow-aware embeddings build on the vocabulary, just like Symbolic +/// embeddings, and additionally capture the flow information in the IR. +class LLVM_ABI FlowAwareEmbedder : public Embedder { +private: + void computeEmbeddings(const BasicBlock &BB) const override; + +public: + FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab) + : Embedder(F, Vocab) { + FuncVector = Embedding(Dimension, 0); + } +}; + } // namespace ir2vec /// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 95f30fd3f4275..081a4d073b65f 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -52,6 +52,15 @@ cl::opt TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5), cl::opt ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2), cl::desc("Weight for argument embeddings"), cl::cat(IR2VecCategory)); +cl::opt IR2VecEmbeddingKind( + "ir2vec-kind", cl::Optional, + cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic", + "Generate symbolic embeddings"), + clEnumValN(IR2VecKind::FlowAware, "flow-aware", + "Generate flow-aware embeddings")), + cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"), + cl::cat(IR2VecCategory)); + } // namespace ir2vec } // namespace llvm @@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS, double Tolerance) const { assert(this->size() == RHS.size() && "Vectors must have the same dimension"); for (size_t Itr = 0; Itr < this->size(); ++Itr) - if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) + if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) { + LLVM_DEBUG(errs() << 
"Embedding mismatch at index " << Itr << ": " + << (*this)[Itr] << " vs " << RHS[Itr] + << "; Tolerance: " << Tolerance << "\n"); return false; + } return true; } @@ -149,6 +162,8 @@ std::unique_ptr Embedder::create(IR2VecKind Mode, const Function &F, switch (Mode) { case IR2VecKind::Symbolic: return std::make_unique(F, Vocab); + case IR2VecKind::FlowAware: + return std::make_unique(F, Vocab); } return nullptr; } @@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const { return FuncVector; } +void Embedder::computeEmbeddings() const { + if (F.isDeclaration()) + return; + + // Consider only the basic blocks that are reachable from entry + for (const BasicBlock *BB : depth_first(&F)) { + computeEmbeddings(*BB); + FuncVector += BBVecMap[BB]; + } +} + void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { Embedding BBVector(Dimension, 0); @@ -196,15 +222,38 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { BBVecMap[&BB] = BBVector; } -void SymbolicEmbedder::computeEmbeddings() const { - if (F.isDeclaration()) - return; +void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { + Embedding BBVector(Dimension, 0); - // Consider only the basic blocks that are reachable from entry - for (const BasicBlock *BB : depth_first(&F)) { - computeEmbeddings(*BB); - FuncVector += BBVecMap[BB]; + // We consider only the non-debug and non-pseudo instructions + for (const auto &I : BB.instructionsWithoutDebug()) { + // TODO: Handle call instructions differently. 
+ // For now, we treat them like other instructions + Embedding ArgEmb(Dimension, 0); + for (const auto &Op : I.operands()) { + // If the operand is defined elsewhere, we use its embedding + if (const auto *DefInst = dyn_cast(Op)) { + auto DefIt = InstVecMap.find(DefInst); + // The definition may not be embedded yet (e.g. a PHI operand that + // arrives over a back edge); use the vocabulary entry in that case. + ArgEmb += DefIt != InstVecMap.end() ? DefIt->second : Vocab[Op]; + continue; + } + // If the operand is not defined by an instruction, we use the vocabulary + else { + LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: " + << *Op << "=" << Vocab[Op][0] << "\n"); + ArgEmb += Vocab[Op]; + } + } + // Create the instruction vector by combining opcode, type, and arguments + // embeddings + auto InstVector = + Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + InstVecMap[&I] = InstVector; + BBVector += InstVector; } + BBVecMap[&BB] = BBVector; } // ==----------------------------------------------------------------------===// @@ -552,8 +601,8 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid"); for (Function &F : M) { std::unique_ptr Emb = - Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); + Embedder::create(IR2VecEmbeddingKind, F, Vocabulary); if (!Emb) { OS << "Error creating IR2Vec embeddings \n"; continue; diff --git a/llvm/test/Analysis/IR2Vec/basic-flowaware.ll b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll new file mode 100644 index 0000000000000..4a7f970a9cf91 --- /dev/null +++ b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll @@ -0,0 +1,72 @@ +; RUN: opt -passes='print' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC +; RUN: opt -passes='print' -ir2vec-kind=flow-aware -o /dev/null 
-ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE +; RUN: opt -passes='print' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG + +define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 { +entry: + %a.addr = alloca i32, align 4 + %b.addr = alloca float, align 4 + store i32 %a, ptr %a.addr, align 4 + store float %b, ptr %b.addr, align 4 + %0 = load i32, ptr %a.addr, align 4 + %1 = load i32, ptr %a.addr, align 4 + %mul = mul nsw i32 %0, %1 + %conv = sitofp i32 %mul to float + %2 = load float, ptr %b.addr, align 4 + %add = fadd float %conv, %2 + ret float %add +} + +; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif: +; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ] +; 3D-CHECK-OPC-NEXT: Basic block vectors: +; 3D-CHECK-OPC-NEXT: Basic block: entry: +; 3D-CHECK-OPC-NEXT: [ 3630.00 3672.00 3714.00 ] +; 3D-CHECK-OPC-NEXT: Instruction vectors: +; 3D-CHECK-OPC-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 419.00 424.00 429.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 549.00 555.00 561.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 185.00 187.00 189.00 ] +; 3D-CHECK-OPC-NEXT: Instruction: %add = fadd float %conv, %2 [ 774.00 783.00 792.00 ] +; 
3D-CHECK-OPC-NEXT: Instruction: ret float %add [ 775.00 785.00 795.00 ] + +; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif: +; 3D-CHECK-TYPE-NEXT: Function vector: [ 355.50 376.50 397.50 ] +; 3D-CHECK-TYPE-NEXT: Basic block vectors: +; 3D-CHECK-TYPE-NEXT: Basic block: entry: +; 3D-CHECK-TYPE-NEXT: [ 355.50 376.50 397.50 ] +; 3D-CHECK-TYPE-NEXT: Instruction vectors: +; 3D-CHECK-TYPE-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 12.50 13.00 13.50 ] +; 3D-CHECK-TYPE-NEXT: Instruction: %b.addr = alloca float, align 4 [ 12.50 13.00 13.50 ] +; 3D-CHECK-TYPE-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 14.50 15.50 16.50 ] +; 3D-CHECK-TYPE-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 14.50 15.50 16.50 ] +; 3D-CHECK-TYPE-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ] +; 3D-CHECK-TYPE-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ] +; 3D-CHECK-TYPE-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 53.50 56.00 58.50 ] +; 3D-CHECK-TYPE-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 54.00 57.00 60.00 ] +; 3D-CHECK-TYPE-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 13.00 14.00 15.00 ] +; 3D-CHECK-TYPE-NEXT: Instruction: %add = fadd float %conv, %2 [ 67.50 72.00 76.50 ] +; 3D-CHECK-TYPE-NEXT: Instruction: ret float %add [ 69.50 74.50 79.50 ] + +; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif: +; 3D-CHECK-ARG-NEXT: Function vector: [ 27.80 31.60 35.40 ] +; 3D-CHECK-ARG-NEXT: Basic block vectors: +; 3D-CHECK-ARG-NEXT: Basic block: entry: +; 3D-CHECK-ARG-NEXT: [ 27.80 31.60 35.40 ] +; 3D-CHECK-ARG-NEXT: Instruction vectors: +; 3D-CHECK-ARG-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 1.40 1.60 1.80 ] +; 3D-CHECK-ARG-NEXT: Instruction: %b.addr = alloca float, align 4 [ 1.40 1.60 1.80 ] +; 3D-CHECK-ARG-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 3.40 3.80 4.20 ] +; 3D-CHECK-ARG-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 3.40 3.80 
4.20 ] +; 3D-CHECK-ARG-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ] +; 3D-CHECK-ARG-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ] +; 3D-CHECK-ARG-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 2.80 3.20 3.60 ] +; 3D-CHECK-ARG-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 2.80 3.20 3.60 ] +; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 1.40 1.60 1.80 ] +; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.20 4.80 5.40 ] +; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 4.20 4.80 5.40 ] diff --git a/llvm/test/Analysis/IR2Vec/basic.ll b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll similarity index 81% rename from llvm/test/Analysis/IR2Vec/basic.ll rename to llvm/test/Analysis/IR2Vec/basic-symbolic.ll index cb0544fb19860..35abd3c7fa269 100644 --- a/llvm/test/Analysis/IR2Vec/basic.ll +++ b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll @@ -1,11 +1,7 @@ ; RUN: opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC ; RUN: opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE ; RUN: opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG -; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK -; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK -; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK -; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | 
FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK - + define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 { entry: %a.addr = alloca i32, align 4 @@ -74,11 +70,3 @@ entry: ; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 0.80 1.00 1.20 ] ; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.00 4.40 4.80 ] ; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 2.00 2.20 2.40 ] - -; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file - -; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file - -; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file - -; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions diff --git a/llvm/test/Analysis/IR2Vec/basic-vocab.ll b/llvm/test/Analysis/IR2Vec/basic-vocab.ll new file mode 100644 index 0000000000000..eeeee831814a8 --- /dev/null +++ b/llvm/test/Analysis/IR2Vec/basic-vocab.ll @@ -0,0 +1,27 @@ +; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK +; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK +; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK +; RUN: not opt -passes='print' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK + +define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 { +entry: + %a.addr = alloca i32, align 4 + %b.addr = alloca float, align 4 + store i32 %a, ptr %a.addr, align 4 + store float %b, ptr %b.addr, align 4 + %0 = load i32, ptr %a.addr, align 4 + %1 = 
load i32, ptr %a.addr, align 4 + %mul = mul nsw i32 %0, %1 + %conv = sitofp i32 %mul to float + %2 = load float, ptr %b.addr, align 4 + %add = fadd float %conv, %2 + ret float %add +} + +; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file + +; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file + +; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file + +; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index f7838cc4068ce..f0c81e160ca15 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -30,7 +30,6 @@ namespace { class TestableEmbedder : public Embedder { public: TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {} - void computeEmbeddings() const override {} void computeEmbeddings(const BasicBlock &BB) const override {} }; @@ -258,6 +257,18 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) { EXPECT_NE(Emb, nullptr); } +TEST(IR2VecTest, CreateFlowAwareEmbedder) { + Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest()); + + LLVMContext Ctx; + Module M("M", Ctx); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + + auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V); + EXPECT_NE(Emb, nullptr); +} + TEST(IR2VecTest, CreateInvalidMode) { Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest()); @@ -310,7 +321,7 @@ class IR2VecTestFixture : public ::testing::Test { } }; -TEST_F(IR2VecTestFixture, GetInstVecMap) { +TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) { auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); ASSERT_TRUE(static_cast(Emb)); @@ -329,7 +340,24 @@ 
TEST_F(IR2VecTestFixture, GetInstVecMap) { EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 17.0))); } -TEST_F(IR2VecTestFixture, GetBBVecMap) { +TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) { + auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V); + ASSERT_TRUE(static_cast(Emb)); + + const auto &InstMap = Emb->getInstVecMap(); + + EXPECT_EQ(InstMap.size(), 2u); + EXPECT_TRUE(InstMap.count(AddInst)); + EXPECT_TRUE(InstMap.count(RetInst)); + + EXPECT_EQ(InstMap.at(AddInst).size(), 2u); + EXPECT_EQ(InstMap.at(RetInst).size(), 2u); + + EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.9))); + EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.6))); +} + +TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) { auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); ASSERT_TRUE(static_cast(Emb)); @@ -344,7 +372,22 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) { EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.9))); } -TEST_F(IR2VecTestFixture, GetBBVector) { +TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) { + auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V); + ASSERT_TRUE(static_cast(Emb)); + + const auto &BBMap = Emb->getBBVecMap(); + + EXPECT_EQ(BBMap.size(), 1u); + EXPECT_TRUE(BBMap.count(BB)); + EXPECT_EQ(BBMap.at(BB).size(), 2u); + + // BB vector should be sum of add and ret: {27.9, 27.9} + {35.6, 35.6} = + // {63.5, 63.5} + EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 63.5))); +} + +TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); ASSERT_TRUE(static_cast(Emb)); @@ -354,7 +397,17 @@ TEST_F(IR2VecTestFixture, GetBBVector) { EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.9))); } -TEST_F(IR2VecTestFixture, GetFunctionVector) { +TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) { + auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V); + ASSERT_TRUE(static_cast(Emb)); + + const auto &BBVec = Emb->getBBVector(*BB); + + 
EXPECT_EQ(BBVec.size(), 2u); + EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 63.5))); +} + +TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) { auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V); ASSERT_TRUE(static_cast(Emb)); @@ -366,6 +419,17 @@ TEST_F(IR2VecTestFixture, GetFunctionVector) { EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 44.9))); } +TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) { + auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V); + ASSERT_TRUE(static_cast(Emb)); + + const auto &FuncVec = Emb->getFunctionVector(); + + EXPECT_EQ(FuncVec.size(), 2u); + // Function vector should match BB vector (only one BB): {63.5, 63.5} + EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 63.5))); +} + static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes; static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs; static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;