diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 5f3114dcdeeaa..683f05d5beb04 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -77,6 +77,7 @@ Expected> Embedder::create(IR2VecKind Mode, } void Embedder::addVectors(Embedding &Dst, const Embedding &Src) { + assert(Dst.size() == Src.size() && "Vectors must have the same dimension"); std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(), std::plus()); } diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt index 67f0b043e4f68..cd04a779b9467 100644 --- a/llvm/unittests/Analysis/CMakeLists.txt +++ b/llvm/unittests/Analysis/CMakeLists.txt @@ -32,6 +32,7 @@ set(ANALYSIS_TEST_SOURCES GlobalsModRefTest.cpp FunctionPropertiesAnalysisTest.cpp InlineCostTest.cpp + IR2VecTest.cpp IRSimilarityIdentifierTest.cpp IVDescriptorsTest.cpp LastRunTrackingAnalysisTest.cpp diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp new file mode 100644 index 0000000000000..5fb4da9f5fe20 --- /dev/null +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -0,0 +1,243 @@ +//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IR2Vec.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Error.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include + +using namespace llvm; +using namespace ir2vec; +using namespace ::testing; + +namespace { + +class TestableEmbedder : public Embedder { +public: + TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim) + : Embedder(F, V, Dim) {} + void computeEmbeddings() const override {} + using Embedder::lookupVocab; + static void addVectors(Embedding &Dst, const Embedding &Src) { + Embedder::addVectors(Dst, Src); + } + static void addScaledVector(Embedding &Dst, const Embedding &Src, + float Factor) { + Embedder::addScaledVector(Dst, Src, Factor); + } +}; + +TEST(IR2VecTest, CreateSymbolicEmbedder) { + Vocab V = {{"foo", {1.0, 2.0}}}; + + LLVMContext Ctx; + Module M("M", Ctx); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2); + EXPECT_TRUE(static_cast(Result)); + + auto *Emb = Result->get(); + EXPECT_NE(Emb, nullptr); +} + +TEST(IR2VecTest, CreateInvalidMode) { + Vocab V = {{"foo", {1.0, 2.0}}}; + + LLVMContext Ctx; + Module M("M", Ctx); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + + // static_cast an invalid int to IR2VecKind + auto Result = Embedder::create(static_cast(-1), *F, V, 2); + EXPECT_FALSE(static_cast(Result)); + + std::string ErrMsg; + llvm::handleAllErrors( + Result.takeError(), + [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); }); + EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos); +} + +TEST(IR2VecTest, AddVectors) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {0.5, 1.5, -1.0}; + + TestableEmbedder::addVectors(E1, E2); + EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0)); + + // Check that E2 is unchanged + EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); +} + +TEST(IR2VecTest, AddScaledVector) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {2.0, 0.5, -1.0}; + + TestableEmbedder::addScaledVector(E1, E2, 0.5f); + EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5)); + + // Check that E2 is unchanged + EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0)); +} + +#if GTEST_HAS_DEATH_TEST +#ifndef NDEBUG +TEST(IR2VecTest, MismatchedDimensionsAddVectors) { + Embedding E1 = {1.0, 2.0}; + Embedding E2 = {1.0}; + EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2), + "Vectors must have the same dimension"); +} + +TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) { + Embedding E1 = {1.0, 2.0}; + Embedding E2 = {1.0}; + EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f), + "Vectors must have the same dimension"); +} +#endif // NDEBUG +#endif // GTEST_HAS_DEATH_TEST + +TEST(IR2VecTest, LookupVocab) { + Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}}; + LLVMContext Ctx; + Module M("M", Ctx); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + + TestableEmbedder E(*F, V, 2); + auto V_foo = E.lookupVocab("foo"); + EXPECT_EQ(V_foo.size(), 2u); + EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0)); + + auto V_missing = E.lookupVocab("missing"); + EXPECT_EQ(V_missing.size(), 2u); + EXPECT_THAT(V_missing, ElementsAre(0.0, 0.0)); +} + +TEST(IR2VecTest, ZeroDimensionEmbedding) { + Embedding E1; + Embedding E2; + // Should be no-op, but not crash + TestableEmbedder::addVectors(E1, E2); + TestableEmbedder::addScaledVector(E1, E2, 1.0f); + EXPECT_TRUE(E1.empty()); +} + +TEST(IR2VecTest, IR2VecVocabResultValidity) { + // Default constructed is invalid + IR2VecVocabResult invalidResult; + EXPECT_FALSE(invalidResult.isValid()); +#if GTEST_HAS_DEATH_TEST +#ifndef NDEBUG + EXPECT_DEATH(invalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid"); + EXPECT_DEATH(invalidResult.getDimension(), "IR2Vec Vocabulary is invalid"); +#endif // NDEBUG +#endif // GTEST_HAS_DEATH_TEST + + // Valid vocab + Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}}; + IR2VecVocabResult validResult(std::move(V)); + EXPECT_TRUE(validResult.isValid()); + EXPECT_EQ(validResult.getDimension(), 2u); +} + +// Helper to create a minimal function and embedder for getter tests +struct GetterTestEnv { + Vocab V = {}; + LLVMContext Ctx; + std::unique_ptr M = nullptr; + Function *F = nullptr; + BasicBlock *BB = nullptr; + Instruction *Add = nullptr; + Instruction *Ret = nullptr; + std::unique_ptr Emb = nullptr; + + GetterTestEnv() { + V = {{"add", {1.0, 2.0}}, + {"integerTy", {0.5, 0.5}}, + {"constant", {0.2, 0.3}}, + {"variable", {0.0, 0.0}}, + {"unknownTy", {0.0, 0.0}}}; + + M = std::make_unique("M", Ctx); + FunctionType *FTy = FunctionType::get( + Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)}, + false); + F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get()); + BB = BasicBlock::Create(Ctx, "entry", F); + Argument *Arg = F->getArg(0); + llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42); + + Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB); + Ret = ReturnInst::Create(Ctx, Add, BB); + + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2); + EXPECT_TRUE(static_cast(Result)); + Emb = std::move(*Result); + } +}; + +TEST(IR2VecTest, GetInstVecMap) { + GetterTestEnv Env; + const auto &InstMap = Env.Emb->getInstVecMap(); + + EXPECT_EQ(InstMap.size(), 2u); + EXPECT_TRUE(InstMap.count(Env.Add)); + EXPECT_TRUE(InstMap.count(Env.Ret)); + + EXPECT_EQ(InstMap.at(Env.Add).size(), 2u); + EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u); + + // Check values for add: {1.29, 2.31} + EXPECT_THAT(InstMap.at(Env.Add), + ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); + + // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in + // vocab + EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0)); +} + +TEST(IR2VecTest, GetBBVecMap) { + GetterTestEnv Env; + const auto &BBMap = Env.Emb->getBBVecMap(); + + EXPECT_EQ(BBMap.size(), 1u); + EXPECT_TRUE(BBMap.count(Env.BB)); + EXPECT_EQ(BBMap.at(Env.BB).size(), 2u); + + // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} = + // {1.29, 2.31} + EXPECT_THAT(BBMap.at(Env.BB), + ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); +} + +TEST(IR2VecTest, GetFunctionVector) { + GetterTestEnv Env; + const auto &FuncVec = Env.Emb->getFunctionVector(); + + EXPECT_EQ(FuncVec.size(), 2u); + + // Function vector should match BB vector (only one BB): {1.29, 2.31} + EXPECT_THAT(FuncVec, + ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); +} + +} // end anonymous namespace