3131
3232#include " llvm/ADT/DenseMap.h"
3333#include " llvm/IR/PassManager.h"
34+ #include " llvm/IR/Type.h"
3435#include " llvm/Support/CommandLine.h"
3536#include " llvm/Support/Compiler.h"
3637#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
4344class BasicBlock ;
4445class Instruction ;
4546class Function ;
46- class Type ;
4747class Value ;
4848class raw_ostream ;
4949class LLVMContext ;
50+ class IR2VecVocabAnalysis ;
5051
5152// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
5253// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -128,9 +129,73 @@ struct Embedding {
128129
129130using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
130131using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
131- // FIXME: Current the keys are strings. This can be changed to
132- // use integers for cheaper lookups.
133- using Vocab = std::map<std::string, Embedding>;
132+
133+ // / Class for storing and accessing the IR2Vec vocabulary.
134+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
135+ class Vocabulary {
136+ friend class llvm ::IR2VecVocabAnalysis;
137+ using VocabVector = std::vector<ir2vec::Embedding>;
138+ VocabVector Vocab;
139+ bool Valid = false ;
140+
141+ // / Operand kinds supported by IR2Vec Vocabulary
142+ #define OPERAND_KINDS \
143+ OPERAND_KIND (FunctionID, " Function" ) \
144+ OPERAND_KIND (PointerID, " Pointer" ) \
145+ OPERAND_KIND (ConstantID, " Constant" ) \
146+ OPERAND_KIND (VariableID, " Variable" )
147+
148+ enum class OperandKind : unsigned {
149+ #define OPERAND_KIND (Name, Str ) Name,
150+ OPERAND_KINDS
151+ #undef OPERAND_KIND
152+ MaxOperandKind
153+ };
154+
155+ #undef OPERAND_KINDS
156+
157+ // / Vocabulary layout constants
158+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
159+ #include " llvm/IR/Instruction.def"
160+ #undef LAST_OTHER_INST
161+
162+ static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1 ;
163+ static constexpr unsigned MaxOperandKinds =
164+ static_cast <unsigned >(OperandKind::MaxOperandKind);
165+
166+ // / Helper function to get vocabulary key for a given OperandKind
167+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
168+
169+ // / Helper function to classify an operand into OperandKind
170+ static OperandKind getOperandKind (const Value *Op);
171+
172+ // / Helper function to get vocabulary key for a given TypeID
173+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
174+
175+ public:
176+ Vocabulary () = default ;
177+ Vocabulary (VocabVector &&Vocab);
178+
179+ bool isValid () const ;
180+ unsigned getDimension () const ;
181+ size_t size () const ;
182+
183+ const ir2vec::Embedding &at (unsigned Position) const ;
184+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
185+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
186+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
187+
188+ // / Returns the string key for a given index position in the vocabulary.
189+ // / This is useful for debugging or printing the vocabulary. Do not use this
190+ // / for embedding generation as string based lookups are inefficient.
191+ static StringRef getStringKey (unsigned Pos);
192+
193+ // / Create a dummy vocabulary for testing purposes.
194+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
195+
196+ bool invalidate (Module &M, const PreservedAnalyses &PA,
197+ ModuleAnalysisManager::Invalidator &Inv) const ;
198+ };
134199
135200// / Embedder provides the interface to generate embeddings (vector
136201// / representations) for instructions, basic blocks, and functions. The
@@ -141,7 +206,7 @@ using Vocab = std::map<std::string, Embedding>;
141206class Embedder {
142207protected:
143208 const Function &F;
144- const Vocab &Vocabulary ;
209+ const Vocabulary &Vocab ;
145210
146211 // / Dimension of the vector representation; captured from the input vocabulary
147212 const unsigned Dimension;
@@ -156,7 +221,7 @@ class Embedder {
156221 mutable BBEmbeddingsMap BBVecMap;
157222 mutable InstEmbeddingsMap InstVecMap;
158223
159- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
224+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
160225
161226 // / Helper function to compute embeddings. It generates embeddings for all
162227 // / the instructions and basic blocks in the function F. Logic of computing
@@ -167,16 +232,12 @@ class Embedder {
167232 // / Specific to the kind of embeddings being computed.
168233 virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
169234
170- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
171- // / zero vector.
172- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
173-
174235public:
175236 virtual ~Embedder () = default ;
176237
177238 // / Factory method to create an Embedder object.
178239 LLVM_ABI static std::unique_ptr<Embedder>
179- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
240+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
180241
181242 // / Returns a map containing instructions and the corresponding embeddings for
182243 // / the function F if it has been computed. If not, it computes the embeddings
@@ -202,56 +263,40 @@ class Embedder {
202263// / representations obtained from the Vocabulary.
// Concrete Embedder subclass. In this diff view, lines marked '-' are the
// pre-change declarations being removed and lines marked '+' are their
// replacements; unmarked lines are unchanged context.
203264class LLVM_ABI SymbolicEmbedder : public Embedder {
204265private:
205- // / Utility function to compute the embedding for a given type.
206- Embedding getTypeEmbedding (const Type *Ty) const ;
207-
208- // / Utility function to compute the embedding for a given operand.
209- Embedding getOperandEmbedding (const Value *Op) const ;
210-
// Overrides of the two computeEmbeddings hooks declared on Embedder (the
// BasicBlock overload is pure virtual there).
211266 void computeEmbeddings () const override ;
212267 void computeEmbeddings (const BasicBlock &BB) const override ;
213268
214269public:
// Forwards F and the vocabulary to the Embedder base, then sets FuncVector to
// an Embedding constructed from (Dimension, 0). The change only swaps the
// vocabulary parameter from `const Vocab &Vocabulary` to
// `const Vocabulary &Vocab`.
215- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
216- : Embedder(F, Vocabulary ) {
270+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
271+ : Embedder(F, Vocab ) {
217272 FuncVector = Embedding (Dimension, 0 );
218273 }
219274};
220275
221276} // namespace ir2vec
222277
223- // / Class for storing the result of the IR2VecVocabAnalysis.
// NOTE: every line of this class carries the '-' marker, i.e. the class is
// deleted by this change. IR2VecVocabAnalysis::Result, which used to alias
// IR2VecVocabResult, aliases ir2vec::Vocabulary after the change.
224- class IR2VecVocabResult {
225- ir2vec::Vocab Vocabulary;
226- bool Valid = false ;
227-
228- public:
229- IR2VecVocabResult () = default ;
230- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
231-
232- bool isValid () const { return Valid; }
233- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
234- LLVM_ABI unsigned getDimension () const ;
235- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236- ModuleAnalysisManager::Invalidator &Inv) const ;
237- };
238-
239278// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
240279// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
241280// / its corresponding embedding.
// Module analysis whose Result is the IR2Vec vocabulary. Lines marked '-' are
// removed by this change and lines marked '+' are added.
242281class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
243- ir2vec::Vocab Vocabulary;
// VocabMap is the string-keyed form; VocabVector is the flat positional form.
// Three separate maps are kept (opcode, type, argument) alongside one vector.
282+ using VocabVector = std::vector<ir2vec::Embedding>;
283+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
284+ VocabMap OpcVocab, TypeVocab, ArgVocab;
285+ VocabVector Vocab;
286+
// Initialised to 0. Presumably the embedding dimension of the loaded
// vocabulary; verify against the definitions, which are not visible here.
287+ unsigned Dim = 0 ;
244288 Error readVocabulary ();
// NOTE(review): the `Dim` out-parameter below shadows the `Dim` member added
// above; inside the definition the unqualified name refers to the parameter.
245289 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
246- ir2vec::Vocab &TargetVocab, unsigned &Dim);
290+ VocabMap &TargetVocab, unsigned &Dim);
// Presumably converts the string-keyed maps into the positional `Vocab`
// vector; confirm against the definition.
291+ void generateNumMappedVocab ();
247292 void emitError (Error Err, LLVMContext &Ctx);
248293
249294public:
250295 LLVM_ABI static AnalysisKey Key;
251296 IR2VecVocabAnalysis () = default ;
// The explicit constructors accept an already-built vocabulary, by copy or by
// move. After this change they take the vector form instead of ir2vec::Vocab.
252- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
253- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
254- using Result = IR2VecVocabResult ;
297+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
298+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
// The analysis result type is now ir2vec::Vocabulary itself.
299+ using Result = ir2vec::Vocabulary ;
255300 LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
256301};
257302
0 commit comments