Skip to content

Commit 741136a

Browse files
authored
[NFC][IR2Vec] Removing Dimension from Embedder::create (llvm#142486)
This PR removes the need to know the dimension of the embeddings when invoking `Embedder::create`. Having the `Dimension` parameter introduces complexities in downstream consumers, so the embedder now derives the dimension from the size of the first entry in the vocabulary. (Tracking issue - llvm#141817)
1 parent c80c452 commit 741136a

File tree

4 files changed

+19
-28
lines changed

4 files changed

+19
-28
lines changed

llvm/docs/MLGO.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,6 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
469469
return;
470470
}
471471
const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
472-
unsigned Dimension = VocabRes.getDimension();
473472

474473
Note that ``IR2VecVocabAnalysis`` pass is immutable.
475474

@@ -481,7 +480,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
481480
// Assuming F is an llvm::Function&
482481
// For example, using IR2VecKind::Symbolic:
483482
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
484-
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
483+
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
485484

486485
if (auto Err = EmbOrErr.takeError()) {
487486
// Handle error in embedder creation

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class Embedder {
8484
mutable BBEmbeddingsMap BBVecMap;
8585
mutable InstEmbeddingsMap InstVecMap;
8686

87-
Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
87+
Embedder(const Function &F, const Vocab &Vocabulary);
8888

8989
/// Helper function to compute embeddings. It generates embeddings for all
9090
/// the instructions and basic blocks in the function F. Logic of computing
@@ -110,10 +110,8 @@ class Embedder {
110110
virtual ~Embedder() = default;
111111

112112
/// Factory method to create an Embedder object.
113-
static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
114-
const Function &F,
115-
const Vocab &Vocabulary,
116-
unsigned Dimension);
113+
static Expected<std::unique_ptr<Embedder>>
114+
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
117115

118116
/// Returns a map containing instructions and the corresponding embeddings for
119117
/// the function F if it has been computed. If not, it computes the embeddings
@@ -149,9 +147,8 @@ class SymbolicEmbedder : public Embedder {
149147
void computeEmbeddings(const BasicBlock &BB) const override;
150148

151149
public:
152-
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
153-
unsigned Dimension)
154-
: Embedder(F, Vocabulary, Dimension) {
150+
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
151+
: Embedder(F, Vocabulary) {
155152
FuncVector = Embedding(Dimension, 0);
156153
}
157154
};

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,16 @@ AnalysisKey IR2VecVocabAnalysis::Key;
5959
// Embedder and its subclasses
6060
//===----------------------------------------------------------------------===//
6161

62-
Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
63-
unsigned Dimension)
64-
: F(F), Vocabulary(Vocabulary), Dimension(Dimension),
65-
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
66-
}
62+
Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
63+
: F(F), Vocabulary(Vocabulary),
64+
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
65+
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
6766

68-
Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
69-
const Function &F,
70-
const Vocab &Vocabulary,
71-
unsigned Dimension) {
67+
Expected<std::unique_ptr<Embedder>>
68+
Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
7269
switch (Mode) {
7370
case IR2VecKind::Symbolic:
74-
return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
71+
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
7572
}
7673
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
7774
}
@@ -286,10 +283,9 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
286283
assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
287284

288285
auto Vocab = IR2VecVocabResult.getVocabulary();
289-
auto Dim = IR2VecVocabResult.getDimension();
290286
for (Function &F : M) {
291287
Expected<std::unique_ptr<Embedder>> EmbOrErr =
292-
Embedder::create(IR2VecKind::Symbolic, F, Vocab, Dim);
288+
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
293289
if (auto Err = EmbOrErr.takeError()) {
294290
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
295291
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ namespace {
2828

2929
class TestableEmbedder : public Embedder {
3030
public:
31-
TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
32-
: Embedder(F, V, Dim) {}
31+
TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
3332
void computeEmbeddings() const override {}
3433
void computeEmbeddings(const BasicBlock &BB) const override {}
3534
using Embedder::lookupVocab;
@@ -50,7 +49,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
5049
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
5150
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
5251

53-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
52+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
5453
EXPECT_TRUE(static_cast<bool>(Result));
5554

5655
auto *Emb = Result->get();
@@ -66,7 +65,7 @@ TEST(IR2VecTest, CreateInvalidMode) {
6665
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
6766

6867
// static_cast an invalid int to IR2VecKind
69-
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
68+
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
7069
EXPECT_FALSE(static_cast<bool>(Result));
7170

7271
std::string ErrMsg;
@@ -123,7 +122,7 @@ TEST(IR2VecTest, LookupVocab) {
123122
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
124123
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
125124

126-
TestableEmbedder E(*F, V, 2);
125+
TestableEmbedder E(*F, V);
127126
auto V_foo = E.lookupVocab("foo");
128127
EXPECT_EQ(V_foo.size(), 2u);
129128
EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
@@ -190,7 +189,7 @@ struct GetterTestEnv {
190189
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
191190
Ret = ReturnInst::Create(Ctx, Add, BB);
192191

193-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
192+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
194193
EXPECT_TRUE(static_cast<bool>(Result));
195194
Emb = std::move(*Result);
196195
}

0 commit comments

Comments
 (0)