3131
3232#include " llvm/ADT/DenseMap.h"
3333#include " llvm/IR/PassManager.h"
34+ #include " llvm/IR/Type.h"
3435#include " llvm/Support/CommandLine.h"
3536#include " llvm/Support/Compiler.h"
3637#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
4344class BasicBlock ;
4445class Instruction ;
4546class Function ;
46- class Type ;
4747class Value ;
4848class raw_ostream ;
4949class LLVMContext ;
50+ class IR2VecVocabAnalysis ;
5051
5152// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
5253// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -126,9 +127,73 @@ struct Embedding {
126127
127128using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
128129using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
129- // FIXME: Current the keys are strings. This can be changed to
130- // use integers for cheaper lookups.
131- using Vocab = std::map<std::string, Embedding>;
130+
131+ // / Class for storing and accessing the IR2Vec vocabulary.
132+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
133+ class Vocabulary {
134+ friend class llvm ::IR2VecVocabAnalysis;
135+ using VocabVector = std::vector<ir2vec::Embedding>;
136+ VocabVector Vocab;
137+ bool Valid = false ;
138+
139+ // / Operand kinds supported by IR2Vec Vocabulary
140+ #define OPERAND_KINDS \
141+ OPERAND_KIND (FunctionID, " Function" ) \
142+ OPERAND_KIND (PointerID, " Pointer" ) \
143+ OPERAND_KIND (ConstantID, " Constant" ) \
144+ OPERAND_KIND (VariableID, " Variable" )
145+
146+ enum class OperandKind : unsigned {
147+ #define OPERAND_KIND (Name, Str ) Name,
148+ OPERAND_KINDS
149+ #undef OPERAND_KIND
150+ MaxOperandKind
151+ };
152+
153+ #undef OPERAND_KINDS
154+
155+ // / Vocabulary layout constants
156+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
157+ #include " llvm/IR/Instruction.def"
158+ #undef LAST_OTHER_INST
159+
160+ static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1 ;
161+ static constexpr unsigned MaxOperandKinds =
162+ static_cast <unsigned >(OperandKind::MaxOperandKind);
163+
164+ // / Helper function to get vocabulary key for a given OperandKind
165+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
166+
167+ // / Helper function to classify an operand into OperandKind
168+ static OperandKind getOperandKind (const Value *Op);
169+
170+ // / Helper function to get vocabulary key for a given TypeID
171+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
172+
173+ public:
174+ Vocabulary () = default ;
175+ Vocabulary (VocabVector &&Vocab);
176+
177+ bool isValid () const ;
178+ unsigned getDimension () const ;
179+ unsigned size () const ;
180+
181+ const ir2vec::Embedding &at (unsigned Position) const ;
182+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
183+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
184+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
185+
186+ // / Returns the string key for a given index position in the vocabulary.
187+ // / This is useful for debugging or printing the vocabulary. Do not use this
188+ // / for embedding generation as string based lookups are inefficient.
189+ static StringRef getStringKey (unsigned Pos);
190+
191+ // / Create a dummy vocabulary for testing purposes.
192+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
193+
194+ bool invalidate (Module &M, const PreservedAnalyses &PA,
195+ ModuleAnalysisManager::Invalidator &Inv) const ;
196+ };
132197
133198// / Embedder provides the interface to generate embeddings (vector
134199// / representations) for instructions, basic blocks, and functions. The
@@ -139,7 +204,7 @@ using Vocab = std::map<std::string, Embedding>;
139204class Embedder {
140205protected:
141206 const Function &F;
142- const Vocab &Vocabulary ;
207+ const Vocabulary &Vocab ;
143208
144209 // / Dimension of the vector representation; captured from the input vocabulary
145210 const unsigned Dimension;
@@ -154,7 +219,7 @@ class Embedder {
154219 mutable BBEmbeddingsMap BBVecMap;
155220 mutable InstEmbeddingsMap InstVecMap;
156221
157- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
222+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
158223
159224 // / Helper function to compute embeddings. It generates embeddings for all
160225 // / the instructions and basic blocks in the function F. Logic of computing
@@ -165,16 +230,12 @@ class Embedder {
165230 // / Specific to the kind of embeddings being computed.
166231 virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
167232
168- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
169- // / zero vector.
170- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
171-
172233public:
173234 virtual ~Embedder () = default ;
174235
175236 // / Factory method to create an Embedder object.
176237 LLVM_ABI static std::unique_ptr<Embedder>
177- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
238+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
178239
179240 // / Returns a map containing instructions and the corresponding embeddings for
180241 // / the function F if it has been computed. If not, it computes the embeddings
@@ -200,56 +261,40 @@ class Embedder {
200261// / representations obtained from the Vocabulary.
201262class LLVM_ABI SymbolicEmbedder : public Embedder {
202263private:
203- // / Utility function to compute the embedding for a given type.
204- Embedding getTypeEmbedding (const Type *Ty) const ;
205-
206- // / Utility function to compute the embedding for a given operand.
207- Embedding getOperandEmbedding (const Value *Op) const ;
208-
209264 void computeEmbeddings () const override ;
210265 void computeEmbeddings (const BasicBlock &BB) const override ;
211266
212267public:
213- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
214- : Embedder(F, Vocabulary ) {
268+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
269+ : Embedder(F, Vocab ) {
215270 FuncVector = Embedding (Dimension, 0 );
216271 }
217272};
218273
219274} // namespace ir2vec
220275
221- // / Class for storing the result of the IR2VecVocabAnalysis.
222- class IR2VecVocabResult {
223- ir2vec::Vocab Vocabulary;
224- bool Valid = false ;
225-
226- public:
227- IR2VecVocabResult () = default ;
228- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
229-
230- bool isValid () const { return Valid; }
231- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
232- LLVM_ABI unsigned getDimension () const ;
233- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
234- ModuleAnalysisManager::Invalidator &Inv) const ;
235- };
236-
237276// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
238277// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
239278// / its corresponding embedding.
240279class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
241- ir2vec::Vocab Vocabulary;
280+ using VocabVector = std::vector<ir2vec::Embedding>;
281+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
282+ VocabMap OpcVocab, TypeVocab, ArgVocab;
283+ VocabVector Vocab;
284+
285+ unsigned Dim = 0 ;
242286 Error readVocabulary ();
243287 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
244- ir2vec::Vocab &TargetVocab, unsigned &Dim);
288+ VocabMap &TargetVocab, unsigned &Dim);
289+ void generateNumMappedVocab ();
245290 void emitError (Error Err, LLVMContext &Ctx);
246291
247292public:
248293 LLVM_ABI static AnalysisKey Key;
249294 IR2VecVocabAnalysis () = default ;
250- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
251- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
252- using Result = IR2VecVocabResult ;
295+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
296+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
297+ using Result = ir2vec::Vocabulary ;
253298 LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
254299};
255300
0 commit comments