3131
3232#include " llvm/ADT/DenseMap.h"
3333#include " llvm/IR/PassManager.h"
34+ #include " llvm/IR/Type.h"
3435#include " llvm/Support/CommandLine.h"
3536#include " llvm/Support/Compiler.h"
3637#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
4344class BasicBlock ;
4445class Instruction ;
4546class Function ;
46- class Type ;
4747class Value ;
4848class raw_ostream ;
4949class LLVMContext ;
50+ class IR2VecVocabAnalysis ;
5051
5152// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
5253// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -128,9 +129,94 @@ struct Embedding {
128129
129130using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
130131using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
131- // FIXME: Current the keys are strings. This can be changed to
132- // use integers for cheaper lookups.
133- using Vocab = std::map<std::string, Embedding>;
132+
133+ // / Class for storing and accessing the IR2Vec vocabulary.
134+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
135+ class Vocabulary {
136+ friend class llvm ::IR2VecVocabAnalysis;
137+ using VocabVector = std::vector<ir2vec::Embedding>;
138+ VocabVector Vocab;
139+ bool Valid = false ;
140+
141+ // / Operand kinds supported by IR2Vec Vocabulary
142+ enum class OperandKind : unsigned {
143+ FunctionID,
144+ PointerID,
145+ ConstantID,
146+ VariableID,
147+ MaxOperandKind
148+ };
149+ // / String mappings for OperandKind values
150+ static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
151+ " Constant" , " Variable" };
152+ static_assert (std::size(OperandKindNames) ==
153+ static_cast <unsigned >(OperandKind::MaxOperandKind),
154+ " OperandKindNames array size must match MaxOperandKind" );
155+
156+ // / Vocabulary layout constants
157+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
158+ #include " llvm/IR/Instruction.def"
159+ #undef LAST_OTHER_INST
160+
161+ static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1 ;
162+ static constexpr unsigned MaxOperandKinds =
163+ static_cast <unsigned >(OperandKind::MaxOperandKind);
164+
165+ // / Helper function to get vocabulary key for a given OperandKind
166+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
167+
168+ // / Helper function to classify an operand into OperandKind
169+ static OperandKind getOperandKind (const Value *Op);
170+
171+ // / Helper function to get vocabulary key for a given TypeID
172+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
173+
174+ public:
175+ Vocabulary () = default ;
176+ Vocabulary (VocabVector &&Vocab);
177+
178+ bool isValid () const ;
179+ unsigned getDimension () const ;
180+ size_t size () const ;
181+
182+ // / Accessors to get the embedding for a given entity.
183+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
184+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
185+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
186+
187+ // / Const Iterator type aliases
188+ using const_iterator = VocabVector::const_iterator;
189+ const_iterator begin () const {
190+ assert (Valid && " IR2Vec Vocabulary is invalid" );
191+ return Vocab.begin ();
192+ }
193+
194+ const_iterator cbegin () const {
195+ assert (Valid && " IR2Vec Vocabulary is invalid" );
196+ return Vocab.cbegin ();
197+ }
198+
199+ const_iterator end () const {
200+ assert (Valid && " IR2Vec Vocabulary is invalid" );
201+ return Vocab.end ();
202+ }
203+
204+ const_iterator cend () const {
205+ assert (Valid && " IR2Vec Vocabulary is invalid" );
206+ return Vocab.cend ();
207+ }
208+
209+ // / Returns the string key for a given index position in the vocabulary.
210+ // / This is useful for debugging or printing the vocabulary. Do not use this
211+ // / for embedding generation as string based lookups are inefficient.
212+ static StringRef getStringKey (unsigned Pos);
213+
214+ // / Create a dummy vocabulary for testing purposes.
215+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
216+
217+ bool invalidate (Module &M, const PreservedAnalyses &PA,
218+ ModuleAnalysisManager::Invalidator &Inv) const ;
219+ };
134220
135221// / Embedder provides the interface to generate embeddings (vector
136222// / representations) for instructions, basic blocks, and functions. The
@@ -141,7 +227,7 @@ using Vocab = std::map<std::string, Embedding>;
141227class Embedder {
142228protected:
143229 const Function &F;
144- const Vocab &Vocabulary ;
230+ const Vocabulary &Vocab ;
145231
146232 // / Dimension of the vector representation; captured from the input vocabulary
147233 const unsigned Dimension;
@@ -156,7 +242,7 @@ class Embedder {
156242 mutable BBEmbeddingsMap BBVecMap;
157243 mutable InstEmbeddingsMap InstVecMap;
158244
159- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
245+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
160246
161247 // / Helper function to compute embeddings. It generates embeddings for all
162248 // / the instructions and basic blocks in the function F. Logic of computing
@@ -167,16 +253,12 @@ class Embedder {
167253 // / Specific to the kind of embeddings being computed.
168254 virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
169255
170- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
171- // / zero vector.
172- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
173-
174256public:
175257 virtual ~Embedder () = default ;
176258
177259 // / Factory method to create an Embedder object.
178260 LLVM_ABI static std::unique_ptr<Embedder>
179- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
261+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
180262
181263 // / Returns a map containing instructions and the corresponding embeddings for
182264 // / the function F if it has been computed. If not, it computes the embeddings
@@ -202,56 +284,39 @@ class Embedder {
202284// / representations obtained from the Vocabulary.
203285class LLVM_ABI SymbolicEmbedder : public Embedder {
204286private:
205- // / Utility function to compute the embedding for a given type.
206- Embedding getTypeEmbedding (const Type *Ty) const ;
207-
208- // / Utility function to compute the embedding for a given operand.
209- Embedding getOperandEmbedding (const Value *Op) const ;
210-
211287 void computeEmbeddings () const override ;
212288 void computeEmbeddings (const BasicBlock &BB) const override ;
213289
214290public:
215- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
216- : Embedder(F, Vocabulary ) {
291+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
292+ : Embedder(F, Vocab ) {
217293 FuncVector = Embedding (Dimension, 0 );
218294 }
219295};
220296
221297} // namespace ir2vec
222298
223- // / Class for storing the result of the IR2VecVocabAnalysis.
224- class IR2VecVocabResult {
225- ir2vec::Vocab Vocabulary;
226- bool Valid = false ;
227-
228- public:
229- IR2VecVocabResult () = default ;
230- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
231-
232- bool isValid () const { return Valid; }
233- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
234- LLVM_ABI unsigned getDimension () const ;
235- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236- ModuleAnalysisManager::Invalidator &Inv) const ;
237- };
238-
239299// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
240300// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
241301// / its corresponding embedding.
242302class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
243- ir2vec::Vocab Vocabulary;
303+ using VocabVector = std::vector<ir2vec::Embedding>;
304+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
305+ VocabMap OpcVocab, TypeVocab, ArgVocab;
306+ VocabVector Vocab;
307+
244308 Error readVocabulary ();
245309 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
246- ir2vec::Vocab &TargetVocab, unsigned &Dim);
310+ VocabMap &TargetVocab, unsigned &Dim);
311+ void generateNumMappedVocab ();
247312 void emitError (Error Err, LLVMContext &Ctx);
248313
249314public:
250315 LLVM_ABI static AnalysisKey Key;
251316 IR2VecVocabAnalysis () = default ;
252- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
253- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
254- using Result = IR2VecVocabResult ;
317+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
318+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
319+ using Result = ir2vec::Vocabulary ;
255320 LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
256321};
257322
0 commit comments