Skip to content

Commit fc83ba0

Browse files
committed
Minor changes to address review comments
1 parent 6bb850a commit fc83ba0

File tree

2 files changed

+35
-34
lines changed

2 files changed

+35
-34
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ class Embedder {
7171
const Function &F;
7272
const Vocab &Vocabulary;
7373

74+
/// Dimension of the vector representation; captured from the input vocabulary
75+
const unsigned Dimension;
76+
7477
/// Weights for different entities (like opcode, arguments, types)
7578
/// in the IR instructions to generate the vector representation.
7679
const float OpcWeight, TypeWeight, ArgWeight;
7780

78-
/// Dimension of the vector representation; captured from the input vocabulary
79-
const unsigned Dimension;
80-
8181
// Utility maps - these are used to store the vector representations of
8282
// instructions, basic blocks and functions.
8383
Embedding FuncVector;
@@ -88,21 +88,21 @@ class Embedder {
8888

8989
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
9090
/// zero vector.
91-
Embedding lookupVocab(const std::string &Key);
91+
Embedding lookupVocab(const std::string &Key) const;
9292

9393
/// Adds two vectors: Dst += Src
94-
void addVectors(Embedding &Dst, const Embedding &Src);
94+
static void addVectors(Embedding &Dst, const Embedding &Src);
9595

9696
/// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
97-
void addScaledVector(Embedding &Dst, const Embedding &Src, float Factor);
97+
static void addScaledVector(Embedding &Dst, const Embedding &Src,
98+
float Factor);
9899

99100
public:
100101
virtual ~Embedder() = default;
101102

102-
/// Top level function to compute embeddings. Given a function, it
103-
/// generates embeddings for all the instructions and basic blocks in that
104-
/// function. Logic of computing the embeddings is specific to the kind of
105-
/// embeddings being computed.
103+
/// Top level function to compute embeddings. It generates embeddings for all
104+
/// the instructions and basic blocks in the function F. Logic of computing
105+
/// the embeddings is specific to the kind of embeddings being computed.
106106
virtual void computeEmbeddings() = 0;
107107

108108
/// Factory method to create an Embedder object.
@@ -126,23 +126,19 @@ class Embedder {
126126
const Embedding &getFunctionVector() const { return FuncVector; }
127127
};
128128

129-
/// Class for computing the Symbolic embeddings of IR2Vec
129+
/// Class for computing the Symbolic embeddings of IR2Vec.
130130
class SymbolicEmbedder : public Embedder {
131131
private:
132132
/// Utility function to compute the vector representation for a given basic
133133
/// block.
134134
Embedding computeBB2Vec(const BasicBlock &BB);
135135

136-
/// Utility function to compute the vector representation for a given
137-
/// function.
138-
Embedding computeFunc2Vec();
139-
140136
/// Utility function to compute the vector representation for a given type.
141-
Embedding getTypeEmbedding(const Type *Ty);
137+
Embedding getTypeEmbedding(const Type *Ty) const;
142138

143139
/// Utility function to compute the vector representation for a given
144140
/// operand.
145-
Embedding getOperandEmbedding(const Value *Op);
141+
Embedding getOperandEmbedding(const Value *Op) const;
146142

147143
public:
148144
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
@@ -168,7 +164,7 @@ class IR2VecVocabResult {
168164
const ir2vec::Vocab &getVocabulary() const;
169165
unsigned getDimension() const;
170166
bool invalidate(Module &M, const PreservedAnalyses &PA,
171-
ModuleAnalysisManager::Invalidator &Inv);
167+
ModuleAnalysisManager::Invalidator &Inv) const;
172168
};
173169

174170
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,11 @@ AnalysisKey IR2VecVocabAnalysis::Key;
6060
// Embedder and its subclasses
6161
//===----------------------------------------------------------------------===//
6262

63-
#define RETURN_LOOKUP_IF(CONDITION, KEY_STR) \
64-
if (CONDITION) \
65-
return lookupVocab(KEY_STR);
66-
6763
Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
6864
unsigned Dimension)
69-
: F(F), Vocabulary(Vocabulary), Dimension(Dimension), OpcWeight(OpcWeight),
70-
TypeWeight(TypeWeight), ArgWeight(ArgWeight) {}
65+
: F(F), Vocabulary(Vocabulary), Dimension(Dimension),
66+
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
67+
}
7168

7269
ErrorOr<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
7370
const Function &F,
@@ -77,7 +74,8 @@ ErrorOr<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
7774
case IR2VecKind::Symbolic:
7875
return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
7976
default:
80-
return errc::invalid_argument;
77+
return errorToErrorCode(
78+
make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument));
8179
}
8280
}
8381

@@ -96,7 +94,7 @@ void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
9694

9795
// FIXME: Currently lookups are string based. Use numeric Keys
9896
// for efficiency
99-
Embedding Embedder::lookupVocab(const std::string &Key) {
97+
Embedding Embedder::lookupVocab(const std::string &Key) const {
10098
Embedding Vec(Dimension, 0);
10199
// FIXME: Use zero vectors in vocab and assert failure for
102100
// unknown entities rather than silently returning zeroes here.
@@ -108,7 +106,11 @@ Embedding Embedder::lookupVocab(const std::string &Key) {
108106
return Vec;
109107
}
110108

111-
Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) {
109+
#define RETURN_LOOKUP_IF(CONDITION, KEY_STR) \
110+
if (CONDITION) \
111+
return lookupVocab(KEY_STR);
112+
113+
Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) const {
112114
RETURN_LOOKUP_IF(Ty->isVoidTy(), "voidTy");
113115
RETURN_LOOKUP_IF(Ty->isFloatingPointTy(), "floatTy");
114116
RETURN_LOOKUP_IF(Ty->isIntegerTy(), "integerTy");
@@ -124,13 +126,15 @@ Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) {
124126
return lookupVocab("unknownTy");
125127
}
126128

127-
Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) {
129+
Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
128130
RETURN_LOOKUP_IF(isa<Function>(Op), "function");
129131
RETURN_LOOKUP_IF(isa<PointerType>(Op->getType()), "pointer");
130132
RETURN_LOOKUP_IF(isa<Constant>(Op), "constant");
131133
return lookupVocab("variable");
132134
}
133135

136+
#undef RETURN_LOOKUP_IF
137+
134138
void SymbolicEmbedder::computeEmbeddings() {
135139
if (F.isDeclaration())
136140
return;
@@ -147,17 +151,17 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
147151
for (const auto &I : BB) {
148152
Embedding InstVector(Dimension, 0);
149153

150-
auto OpcVec = lookupVocab(I.getOpcodeName());
154+
const auto OpcVec = lookupVocab(I.getOpcodeName());
151155
addScaledVector(InstVector, OpcVec, OpcWeight);
152156

153157
// FIXME: Currently lookups are string based. Use numeric Keys
154158
// for efficiency.
155-
auto Type = I.getType();
156-
auto TypeVec = getTypeEmbedding(Type);
159+
const auto Type = I.getType();
160+
const auto TypeVec = getTypeEmbedding(Type);
157161
addScaledVector(InstVector, TypeVec, TypeWeight);
158162

159163
for (const auto &Op : I.operands()) {
160-
auto OperandVec = getOperandEmbedding(Op.get());
164+
const auto OperandVec = getOperandEmbedding(Op.get());
161165
addScaledVector(InstVector, OperandVec, ArgWeight);
162166
}
163167
InstVecMap[&I] = InstVector;
@@ -184,8 +188,9 @@ unsigned IR2VecVocabResult::getDimension() const {
184188
}
185189

186190
// For now, assume vocabulary is stable unless explicitly invalidated.
187-
bool IR2VecVocabResult::invalidate(Module &M, const PreservedAnalyses &PA,
188-
ModuleAnalysisManager::Invalidator &Inv) {
191+
bool IR2VecVocabResult::invalidate(
192+
Module &M, const PreservedAnalyses &PA,
193+
ModuleAnalysisManager::Invalidator &Inv) const {
189194
auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
190195
return !(PAC.preservedWhenStateless());
191196
}

0 commit comments

Comments
 (0)