3131
3232#include " llvm/ADT/DenseMap.h"
3333#include " llvm/IR/PassManager.h"
34+ #include " llvm/IR/Type.h"
3435#include " llvm/Support/CommandLine.h"
3536#include " llvm/Support/ErrorOr.h"
3637#include " llvm/Support/JSON.h"
@@ -42,10 +43,10 @@ class Module;
// Forward declarations of types that this header only names by pointer or
// reference. This change drops the forward declaration of Type (the include
// list now pulls in llvm/IR/Type.h instead) and adds one for
// IR2VecVocabAnalysis, which ir2vec::Vocabulary names as a friend.
4243class BasicBlock ;
4344class Instruction ;
4445class Function ;
45- class Type ;
4646class Value ;
4747class raw_ostream ;
4848class LLVMContext ;
49+ class IR2VecVocabAnalysis ;
4950
5051/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
5152/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -124,9 +125,73 @@ struct Embedding {
124125
/// Embeddings keyed by the instruction they were computed for.
125126using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
/// Embeddings keyed by the basic block they were computed for.
126127using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
// The string-keyed Vocab alias below is removed by this change; the Vocabulary
// class added right after it stores embeddings in a std::vector instead.
127- // FIXME: Currently the keys are strings. This can be changed to
128- // use integers for cheaper lookups.
129- using Vocab = std::map<std::string, Embedding>;
128+
129+ /// Class for storing and accessing the IR2Vec vocabulary.
130+ /// Encapsulates all vocabulary-related constants, logic, and access methods.
/// Embeddings are held in a single std::vector (VocabVector) and are read
/// through at() and the operator[] overloads declared in the public section.
/// All member functions are only declared here; their definitions live
/// elsewhere.
131+ class Vocabulary {
// Friendship gives IR2VecVocabAnalysis access to the private members below.
132+ friend class llvm ::IR2VecVocabAnalysis;
133+ using VocabVector = std::vector<ir2vec::Embedding>;
134+ VocabVector Vocab;
// Defaults to false. NOTE(review): presumably reported by isValid() and set
// when a Vocabulary is built from a VocabVector -- neither definition is
// visible here, confirm.
135+ bool Valid = false ;
136+
137+ /// Operand kinds supported by IR2Vec Vocabulary
// X-macro list: each OPERAND_KIND(Name, Str) entry pairs an enumerator name
// with a string. Only Name is consumed in this class body; Str is presumably
// the key returned by getVocabKeyForOperandKind() -- TODO confirm.
// NOTE(review): in this rendering the #define lines for OPERAND_KIND and
// LAST_OTHER_INST have whitespace between the macro name and '(', which taken
// literally declares an object-like macro, and the string literals and the
// Instruction.def include path start with a space. These look like extraction
// artifacts -- confirm against the real header before relying on them.
138+ #define OPERAND_KINDS \
139+ OPERAND_KIND (FunctionID, " Function" ) \
140+ OPERAND_KIND (PointerID, " Pointer" ) \
141+ OPERAND_KIND (ConstantID, " Constant" ) \
142+ OPERAND_KIND (VariableID, " Variable" )
143+
// Expands the list above into one enumerator per entry (Name only).
144+ enum class OperandKind : unsigned {
145+ #define OPERAND_KIND (Name, Str ) Name,
146+ OPERAND_KINDS
147+ #undef OPERAND_KIND
// Trailing enumerator; it follows the four kinds listed above and is the
// source of MaxOperandKinds below.
148+ MaxOperandKind
149+ };
150+
// The list macro is undefined here, so it is not available to code after
// this point.
151+ #undef OPERAND_KINDS
152+
153+ /// Vocabulary layout constants
// MaxOpcodes is produced by including Instruction.def with LAST_OTHER_INST
// defined to emit the constant. NOTE(review): NUM is presumably the number of
// the last opcode -- confirm against llvm/IR/Instruction.def.
154+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
155+ #include " llvm/IR/Instruction.def"
156+ #undef LAST_OTHER_INST
157+
// One more than TargetExtTyID. NOTE(review): assumes TargetExtTyID is the
// largest Type::TypeID enumerator -- confirm against llvm/IR/Type.h.
158+ static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1 ;
// Numeric value of the trailing OperandKind enumerator.
159+ static constexpr unsigned MaxOperandKinds =
160+ static_cast <unsigned >(OperandKind::MaxOperandKind);
161+
162+ /// Helper function to get vocabulary key for a given OperandKind
163+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
164+
165+ /// Helper function to classify an operand into OperandKind
166+ static OperandKind getOperandKind (const Value *Op);
167+
168+ /// Helper function to get vocabulary key for a given TypeID
169+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
170+
171+ public:
/// A default-constructed Vocabulary starts with an empty vector and
/// Valid == false (see the member initializers above).
172+ Vocabulary () = default ;
/// Constructs from an embedding vector taken by rvalue reference.
173+ Vocabulary (VocabVector &&Vocab);
174+
// State queries. NOTE(review): getDimension() presumably returns the length
// of each embedding and size() the number of entries -- definitions are not
// visible here, confirm.
175+ bool isValid () const ;
176+ unsigned getDimension () const ;
177+ unsigned size () const ;
178+
// Lookups returning a const reference to a stored embedding: by raw
// position, by opcode, by Type::TypeID, and by operand Value. How each key
// is mapped to a position is not visible in this header.
179+ const ir2vec::Embedding &at (unsigned Position) const ;
180+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
181+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
182+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
183+
184+ /// Returns the string key for a given index position in the vocabulary.
185+ /// This is useful for debugging or printing the vocabulary. Do not use this
186+ /// for embedding generation as string based lookups are inefficient.
187+ static StringRef getStringKey (unsigned Pos);
188+
189+ /// Create a dummy vocabulary for testing purposes.
/// Dim defaults to 1.
190+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
191+
// Vocabulary is the Result type of IR2VecVocabAnalysis. NOTE(review): this is
// presumably the hook the analysis manager calls to decide whether the
// result must be recomputed -- confirm against the pass manager docs.
192+ bool invalidate (Module &M, const PreservedAnalyses &PA,
193+ ModuleAnalysisManager::Invalidator &Inv) const ;
194+ };
130195
131196/// Embedder provides the interface to generate embeddings (vector
132197/// representations) for instructions, basic blocks, and functions. The
@@ -137,7 +202,7 @@ using Vocab = std::map<std::string, Embedding>;
137202class Embedder {
138203protected:
139204 const Function &F;
140- const Vocab &Vocabulary ;
205+ const Vocabulary &Vocab ;
141206
142207 // / Dimension of the vector representation; captured from the input vocabulary
143208 const unsigned Dimension;
@@ -152,7 +217,7 @@ class Embedder {
152217 mutable BBEmbeddingsMap BBVecMap;
153218 mutable InstEmbeddingsMap InstVecMap;
154219
155- Embedder (const Function &F, const Vocab &Vocabulary );
220+ Embedder (const Function &F, const Vocabulary &Vocab );
156221
157222 // / Helper function to compute embeddings. It generates embeddings for all
158223 // / the instructions and basic blocks in the function F. Logic of computing
@@ -163,16 +228,12 @@ class Embedder {
163228 // / Specific to the kind of embeddings being computed.
164229 virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
165230
166- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
167- // / zero vector.
168- Embedding lookupVocab (const std::string &Key) const ;
169-
170231public:
171232 virtual ~Embedder () = default ;
172233
173234 // / Factory method to create an Embedder object.
174235 static std::unique_ptr<Embedder> create (IR2VecKind Mode, const Function &F,
175- const Vocab &Vocabulary );
236+ const Vocabulary &Vocab );
176237
177238 // / Returns a map containing instructions and the corresponding embeddings for
178239 // / the function F if it has been computed. If not, it computes the embeddings
@@ -198,56 +259,40 @@ class Embedder {
198259// / representations obtained from the Vocabulary.
// Embedder subclass that overrides both computeEmbeddings() hooks. Its
// constructor now takes a Vocabulary instead of the removed string-keyed
// Vocab map.
199260class SymbolicEmbedder : public Embedder {
200261private:
// The two helpers below are removed by this change. NOTE(review): their role
// is presumably covered by Vocabulary's operator[] overloads for
// Type::TypeID and const Value * -- confirm in the implementation file.
201- /// Utility function to compute the embedding for a given type.
202- Embedding getTypeEmbedding (const Type *Ty) const ;
203-
204- /// Utility function to compute the embedding for a given operand.
205- Embedding getOperandEmbedding (const Value *Op) const ;
206-
// Overrides of Embedder's computeEmbeddings hooks: one without arguments and
// one for a single basic block.
207262 void computeEmbeddings () const override ;
208263 void computeEmbeddings (const BasicBlock &BB) const override ;
209264
210265public:
// Forwards F and the vocabulary to the Embedder base, then assigns
// FuncVector an Embedding built from (Dimension, 0). NOTE(review): presumably
// a zero vector of length Dimension -- confirm against Embedding's
// constructor.
211- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
212- : Embedder(F, Vocabulary ) {
266+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
267+ : Embedder(F, Vocab ) {
213268 FuncVector = Embedding (Dimension, 0 );
214269 }
215270};
216271
217272} // namespace ir2vec
218273
// Removed by this change (every line below carries the '-' marker).
// IR2VecVocabAnalysis::Result now aliases ir2vec::Vocabulary, which declares
// isValid(), getDimension() and invalidate() with matching signatures.
219- /// Class for storing the result of the IR2VecVocabAnalysis.
220- class IR2VecVocabResult {
221- ir2vec::Vocab Vocabulary;
222- bool Valid = false ;
223-
224- public:
225- IR2VecVocabResult () = default ;
226- IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
227-
// The only member defined inline in this class: returns the Valid flag.
228- bool isValid () const { return Valid; }
229- const ir2vec::Vocab &getVocabulary () const ;
230- unsigned getDimension () const ;
231- bool invalidate (Module &M, const PreservedAnalyses &PA,
232- ModuleAnalysisManager::Invalidator &Inv) const ;
233- };
234-
235274/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
236275/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
237276/// its corresponding embedding.
/// With this change run() returns an ir2vec::Vocabulary (see the Result alias
/// below) instead of the removed IR2VecVocabResult.
238277class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
239- ir2vec::Vocab Vocabulary;
278+ using VocabVector = std::vector<ir2vec::Embedding>;
279+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
// Three string-keyed maps. NOTE(review): going by the names, one per
// vocabulary section (opcodes, types, arguments), presumably filled through
// parseVocabSection() -- confirm in the implementation file.
280+ VocabMap OpcVocab, TypeVocab, ArgVocab;
// Vector form of the vocabulary. NOTE(review): presumably built from the
// three maps by generateNumMappedVocab() -- confirm.
281+ VocabVector Vocab;
282+
// Defaults to 0.
283+ unsigned Dim = 0 ;
240284 Error readVocabulary ();
// NOTE(review): the Dim reference parameter below has the same name as the
// Dim member declared above; check which one the definition updates.
241285 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
242- ir2vec::Vocab &TargetVocab, unsigned &Dim);
286+ VocabMap &TargetVocab, unsigned &Dim);
287+ void generateNumMappedVocab ();
243288 void emitError (Error Err, LLVMContext &Ctx);
244289
245290public:
246291 static AnalysisKey Key;
247292 IR2VecVocabAnalysis () = default ;
// The explicit constructors now accept an embedding vector (by const
// reference or by rvalue reference) rather than the removed string-keyed map.
248- explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
249- explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
250- using Result = IR2VecVocabResult ;
293+ explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
294+ explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
295+ using Result = ir2vec::Vocabulary ;
// Analysis entry point: takes the module and its analysis manager and
// returns a Result.
251296 Result run (Module &M, ModuleAnalysisManager &MAM);
252297};
253298
0 commit comments