diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst index 377c2aec44475..4f8fb3f59ca19 100644 --- a/llvm/docs/MLGO.rst +++ b/llvm/docs/MLGO.rst @@ -469,7 +469,6 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance. return; } const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary(); - unsigned Dimension = VocabRes.getDimension(); Note that ``IR2VecVocabAnalysis`` pass is immutable. @@ -481,7 +480,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance. // Assuming F is an llvm::Function& // For example, using IR2VecKind::Symbolic: Expected> EmbOrErr = - ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension); + ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary); if (auto Err = EmbOrErr.takeError()) { // Handle error in embedder creation diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 288753b3b3b8f..9fd1b0ae8e248 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -84,7 +84,7 @@ class Embedder { mutable BBEmbeddingsMap BBVecMap; mutable InstEmbeddingsMap InstVecMap; - Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension); + Embedder(const Function &F, const Vocab &Vocabulary); /// Helper function to compute embeddings. It generates embeddings for all /// the instructions and basic blocks in the function F. Logic of computing @@ -110,10 +110,8 @@ class Embedder { virtual ~Embedder() = default; /// Factory method to create an Embedder object. - static Expected> create(IR2VecKind Mode, - const Function &F, - const Vocab &Vocabulary, - unsigned Dimension); + static Expected> + create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary); /// Returns a map containing instructions and the corresponding embeddings for /// the function F if it has been computed. If not, it computes the embeddings @@ -149,9 +147,8 @@ class SymbolicEmbedder : public Embedder { void computeEmbeddings(const BasicBlock &BB) const override; public: - SymbolicEmbedder(const Function &F, const Vocab &Vocabulary, - unsigned Dimension) - : Embedder(F, Vocabulary, Dimension) { + SymbolicEmbedder(const Function &F, const Vocab &Vocabulary) + : Embedder(F, Vocabulary) { FuncVector = Embedding(Dimension, 0); } }; diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 67af44dcac424..490db5fdcdf99 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -59,19 +59,16 @@ AnalysisKey IR2VecVocabAnalysis::Key; // Embedder and its subclasses //===----------------------------------------------------------------------===// -Embedder::Embedder(const Function &F, const Vocab &Vocabulary, - unsigned Dimension) - : F(F), Vocabulary(Vocabulary), Dimension(Dimension), - OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) { -} +Embedder::Embedder(const Function &F, const Vocab &Vocabulary) + : F(F), Vocabulary(Vocabulary), + Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight), + TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {} -Expected> Embedder::create(IR2VecKind Mode, - const Function &F, - const Vocab &Vocabulary, - unsigned Dimension) { +Expected> +Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) { switch (Mode) { case IR2VecKind::Symbolic: - return std::make_unique(F, Vocabulary, Dimension); + return std::make_unique(F, Vocabulary); } return make_error("Unknown IR2VecKind", errc::invalid_argument); } @@ -286,10 +283,9 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid"); auto Vocab = IR2VecVocabResult.getVocabulary(); - auto Dim = IR2VecVocabResult.getDimension(); for (Function &F : M) { Expected> EmbOrErr = - Embedder::create(IR2VecKind::Symbolic, F, Vocab, Dim); + Embedder::create(IR2VecKind::Symbolic, F, Vocab); if (auto Err = EmbOrErr.takeError()) { handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n"; diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 0158038b59b6c..9e47b2cd8bedd 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -28,8 +28,7 @@ namespace { class TestableEmbedder : public Embedder { public: - TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim) - : Embedder(F, V, Dim) {} + TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {} void computeEmbeddings() const override {} void computeEmbeddings(const BasicBlock &BB) const override {} using Embedder::lookupVocab; @@ -50,7 +49,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) { FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2); + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); EXPECT_TRUE(static_cast(Result)); auto *Emb = Result->get(); @@ -66,7 +65,7 @@ TEST(IR2VecTest, CreateInvalidMode) { Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); // static_cast an invalid int to IR2VecKind - auto Result = Embedder::create(static_cast(-1), *F, V, 2); + auto Result = Embedder::create(static_cast(-1), *F, V); EXPECT_FALSE(static_cast(Result)); std::string ErrMsg; @@ -123,7 +122,7 @@ TEST(IR2VecTest, LookupVocab) { FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); - TestableEmbedder E(*F, V, 2); + TestableEmbedder E(*F, V); auto V_foo = E.lookupVocab("foo"); EXPECT_EQ(V_foo.size(), 2u); EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0)); @@ -190,7 +189,7 @@ struct GetterTestEnv { Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB); Ret = ReturnInst::Create(Ctx, Add, BB); - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2); + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); EXPECT_TRUE(static_cast(Result)); Emb = std::move(*Result); }