@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243243  return  Vocab[MaxOpcodes + MaxTypeIDs + static_cast <unsigned >(ArgKind)];
244244}
245245
246+ StringRef Vocabulary::getVocabKeyForOpcode (unsigned  Opcode) {
247+   assert (Opcode >= 1  && Opcode <= MaxOpcodes && " Invalid opcode"  );
248+ #define  HANDLE_INST (NUM, OPCODE, CLASS )                                        \
249+   if  (Opcode == NUM) {                                                         \
250+     return  #OPCODE;                                                            \
251+   }
252+ #include  " llvm/IR/Instruction.def" 
253+ #undef  HANDLE_INST
254+   return  " UnknownOpcode"  ;
255+ }
256+ 
246257StringRef Vocabulary::getVocabKeyForTypeID (Type::TypeID TypeID) {
247258  switch  (TypeID) {
248259  case  Type::VoidTyID:
@@ -279,6 +290,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
279290  case  Type::TargetExtTyID:
280291    return  " UnknownTy"  ;
281292  }
293+   return  " UnknownTy"  ;
282294}
283295
284296StringRef Vocabulary::getVocabKeyForOperandKind (Vocabulary::OperandKind Kind) {
@@ -315,14 +327,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
315327  assert (Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
316328         " Position out of bounds in vocabulary"  );
317329  //  Opcode
318-   if  (Pos < MaxOpcodes) {
319- #define  HANDLE_INST (NUM, OPCODE, CLASS )                                        \
320-   if  (Pos == NUM - 1 ) {                                                        \
321-     return  #OPCODE;                                                            \
322-   }
323- #include  " llvm/IR/Instruction.def" 
324- #undef  HANDLE_INST
325-   }
330+   if  (Pos < MaxOpcodes)
331+     return  getVocabKeyForOpcode (Pos + 1 );
326332  //  Type
327333  if  (Pos < MaxOpcodes + MaxTypeIDs)
328334    return  getVocabKeyForTypeID (static_cast <Type::TypeID>(Pos - MaxOpcodes));
@@ -430,21 +436,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
430436  //  Handle Opcodes
431437  std::vector<Embedding> NumericOpcodeEmbeddings (Vocabulary::MaxOpcodes,
432438                                                 Embedding (Dim, 0 ));
433- # define   HANDLE_INST ( NUM, OPCODE, CLASS )                                        \ 
434-   {                                                                            \ 
435-     auto  It = OpcVocab.find (#OPCODE);                                          \ 
436-     if  (It != OpcVocab.end ())                                                  \ 
437-       NumericOpcodeEmbeddings[NUM -  1 ] = It->second ;                           \ 
438-     else                                                                        \ 
439-       handleMissingEntity (#OPCODE);                                            \ 
439+    for  ( unsigned  Opcode :  seq ( 0u , Vocabulary::MaxOpcodes)) { 
440+     StringRef VocabKey =  Vocabulary::getVocabKeyForOpcode (Opcode +  1 ); 
441+     auto  It = OpcVocab.find (VocabKey. str ()); 
442+     if  (It != OpcVocab.end ())
443+       NumericOpcodeEmbeddings[Opcode ] = It->second ;
444+     else 
445+       handleMissingEntity (VocabKey. str ()); 
440446  }
441- #include  " llvm/IR/Instruction.def" 
442- #undef  HANDLE_INST
443447  Vocab.insert (Vocab.end (), NumericOpcodeEmbeddings.begin (),
444448               NumericOpcodeEmbeddings.end ());
445449
446-   //  Handle Types using direct iteration through TypeID enum
447-   //  We iterate through all possible TypeID values and map them to embeddings
450+   //  Handle Types
448451  std::vector<Embedding> NumericTypeEmbeddings (Vocabulary::MaxTypeIDs,
449452                                               Embedding (Dim, 0 ));
450453  for  (unsigned  TypeID : seq (0u , Vocabulary::MaxTypeIDs)) {
0 commit comments