@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243243 return Vocab[MaxOpcodes + MaxTypeIDs + static_cast <unsigned >(ArgKind)];
244244}
245245
246+ StringRef Vocabulary::getVocabKeyForOpcode (unsigned Opcode) {
247+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
248+ #define HANDLE_INST (NUM, OPCODE, CLASS ) \
249+ if (Opcode == NUM) { \
250+ return #OPCODE; \
251+ }
252+ #include " llvm/IR/Instruction.def"
253+ #undef HANDLE_INST
254+ return " UnknownOpcode" ;
255+ }
256+
246257StringRef Vocabulary::getVocabKeyForTypeID (Type::TypeID TypeID) {
247258 switch (TypeID) {
248259 case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
280291 default :
281292 return " UnknownTy" ;
282293 }
294+ return " UnknownTy" ;
283295}
284296
285297// Operand kinds supported by IR2Vec - string mappings
@@ -297,9 +309,9 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
297309 OPERAND_KINDS
298310#undef OPERAND_KIND
299311 case Vocabulary::OperandKind::MaxOperandKind:
300- llvm_unreachable ( " Invalid OperandKind " ) ;
312+ return " UnknownOperand " ;
301313 }
302- llvm_unreachable ( " Unknown OperandKind " ) ;
314+ return " UnknownOperand " ;
303315}
304316
305317#undef OPERAND_KINDS
@@ -332,14 +344,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
332344 assert (Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
333345 " Position out of bounds in vocabulary" );
334346 // Opcode
335- if (Pos < MaxOpcodes) {
336- #define HANDLE_INST (NUM, OPCODE, CLASS ) \
337- if (Pos == NUM - 1 ) { \
338- return #OPCODE; \
339- }
340- #include " llvm/IR/Instruction.def"
341- #undef HANDLE_INST
342- }
347+ if (Pos < MaxOpcodes)
348+ return getVocabKeyForOpcode (Pos + 1 );
343349 // Type
344350 if (Pos < MaxOpcodes + MaxTypeIDs)
345351 return getVocabKeyForTypeID (static_cast <Type::TypeID>(Pos - MaxOpcodes));
@@ -447,21 +453,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
447453 // Handle Opcodes
448454 std::vector<Embedding> NumericOpcodeEmbeddings (Vocabulary::MaxOpcodes,
449455 Embedding (Dim, 0 ));
450- # define HANDLE_INST ( NUM, OPCODE, CLASS ) \
451- { \
452- auto It = OpcVocab.find (#OPCODE); \
453- if (It != OpcVocab.end ()) \
454- NumericOpcodeEmbeddings[NUM - 1 ] = It->second ; \
455- else \
456- handleMissingEntity (#OPCODE); \
456+ for ( unsigned Opcode : seq ( 0u , Vocabulary::MaxOpcodes)) {
457+ StringRef VocabKey = Vocabulary::getVocabKeyForOpcode (Opcode + 1 );
458+ auto It = OpcVocab.find (VocabKey. str ());
459+ if (It != OpcVocab.end ())
460+ NumericOpcodeEmbeddings[Opcode ] = It->second ;
461+ else
462+ handleMissingEntity (VocabKey. str ());
457463 }
458- #include " llvm/IR/Instruction.def"
459- #undef HANDLE_INST
460464 Vocab.insert (Vocab.end (), NumericOpcodeEmbeddings.begin (),
461465 NumericOpcodeEmbeddings.end ());
462466
463- // Handle Types using direct iteration through TypeID enum
464- // We iterate through all possible TypeID values and map them to embeddings
467+ // Handle Types
465468 std::vector<Embedding> NumericTypeEmbeddings (Vocabulary::MaxTypeIDs,
466469 Embedding (Dim, 0 ));
467470 for (unsigned TypeID : seq (0u , Vocabulary::MaxTypeIDs)) {
0 commit comments