@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243243 return Vocab[MaxOpcodes + MaxTypeIDs + static_cast <unsigned >(ArgKind)];
244244}
245245
246+ StringRef Vocabulary::getVocabKeyForOpcode (unsigned Opcode) {
247+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
248+ #define HANDLE_INST (NUM, OPCODE, CLASS ) \
249+ if (Opcode == NUM) { \
250+ return #OPCODE; \
251+ }
252+ #include " llvm/IR/Instruction.def"
253+ #undef HANDLE_INST
254+ return " UnknownOpcode" ;
255+ }
256+
246257StringRef Vocabulary::getVocabKeyForTypeID (Type::TypeID TypeID) {
247258 switch (TypeID) {
248259 case Type::VoidTyID:
@@ -279,6 +290,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
279290 case Type::TargetExtTyID:
280291 return " UnknownTy" ;
281292 }
293+ return " UnknownTy" ;
282294}
283295
284296StringRef Vocabulary::getVocabKeyForOperandKind (Vocabulary::OperandKind Kind) {
@@ -315,14 +327,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
315327 assert (Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
316328 " Position out of bounds in vocabulary" );
317329 // Opcode
318- if (Pos < MaxOpcodes) {
319- #define HANDLE_INST (NUM, OPCODE, CLASS ) \
320- if (Pos == NUM - 1 ) { \
321- return #OPCODE; \
322- }
323- #include " llvm/IR/Instruction.def"
324- #undef HANDLE_INST
325- }
330+ if (Pos < MaxOpcodes)
331+ return getVocabKeyForOpcode (Pos + 1 );
326332 // Type
327333 if (Pos < MaxOpcodes + MaxTypeIDs)
328334 return getVocabKeyForTypeID (static_cast <Type::TypeID>(Pos - MaxOpcodes));
@@ -430,21 +436,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
430436 // Handle Opcodes
431437 std::vector<Embedding> NumericOpcodeEmbeddings (Vocabulary::MaxOpcodes,
432438 Embedding (Dim, 0 ));
433- # define HANDLE_INST ( NUM, OPCODE, CLASS ) \
434- { \
435- auto It = OpcVocab.find (#OPCODE); \
436- if (It != OpcVocab.end ()) \
437- NumericOpcodeEmbeddings[NUM - 1 ] = It->second ; \
438- else \
439- handleMissingEntity (#OPCODE); \
439+ for ( unsigned Opcode : seq ( 0u , Vocabulary::MaxOpcodes)) {
440+ StringRef VocabKey = Vocabulary::getVocabKeyForOpcode (Opcode + 1 );
441+ auto It = OpcVocab.find (VocabKey. str ());
442+ if (It != OpcVocab.end ())
443+ NumericOpcodeEmbeddings[Opcode ] = It->second ;
444+ else
445+ handleMissingEntity (VocabKey. str ());
440446 }
441- #include " llvm/IR/Instruction.def"
442- #undef HANDLE_INST
443447 Vocab.insert (Vocab.end (), NumericOpcodeEmbeddings.begin (),
444448 NumericOpcodeEmbeddings.end ());
445449
446- // Handle Types using direct iteration through TypeID enum
447- // We iterate through all possible TypeID values and map them to embeddings
450+ // Handle Types
448451 std::vector<Embedding> NumericTypeEmbeddings (Vocabulary::MaxTypeIDs,
449452 Embedding (Dim, 0 ));
450453 for (unsigned TypeID : seq (0u , Vocabulary::MaxTypeIDs)) {
0 commit comments