- 
                Notifications
    
You must be signed in to change notification settings  - Fork 15.1k
 
[IR2Vec][NFC] Add helper methods for numeric ID mapping in Vocabulary #149212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| 
          
 @llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) ChangesAdd helper methods to IR2Vec's Vocabulary class for numeric ID mapping and vocabulary size calculation. These APIs will be useful in triplet generation for  (Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/149212.diff 3 Files Affected: 
 diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3d7edf08c8807..d87457cac7642 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,6 +170,10 @@ class Vocabulary {
   unsigned getDimension() const;
   size_t size() const;
 
+  static size_t expectedSize() {
+    return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
+  }
+
   /// Helper function to get vocabulary key for a given Opcode
   static StringRef getVocabKeyForOpcode(unsigned Opcode);
 
@@ -182,6 +186,11 @@ class Vocabulary {
   /// Helper function to classify an operand into OperandKind
   static OperandKind getOperandKind(const Value *Op);
 
+  /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
+  static unsigned getNumericID(unsigned Opcode);
+  static unsigned getNumericID(Type::TypeID TypeID);
+  static unsigned getNumericID(const Value *Op);
+
   /// Accessors to get the embedding for a given entity.
   const ir2vec::Embedding &operator[](unsigned Opcode) const;
   const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 898bf5b202feb..95f30fd3f4275 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -215,7 +215,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
     : Vocab(std::move(Vocab)), Valid(true) {}
 
 bool Vocabulary::isValid() const {
-  return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
+  return Vocab.size() == Vocabulary::expectedSize() && Valid;
 }
 
 size_t Vocabulary::size() const {
@@ -324,8 +324,24 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
+unsigned Vocabulary::getNumericID(unsigned Opcode) {
+  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+  return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
+  assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+  return MaxOpcodes + static_cast<unsigned>(TypeID);
+}
+
+unsigned Vocabulary::getNumericID(const Value *Op) {
+  unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+  assert(Index < MaxOperandKinds && "Invalid OperandKind");
+  return MaxOpcodes + MaxTypeIDs + Index;
+}
+
 StringRef Vocabulary::getStringKey(unsigned Pos) {
-  assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
+  assert(Pos < Vocabulary::expectedSize() &&
          "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index cb6d633306a81..7c9a5464bfe1d 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
   }
 }
 
+TEST(IR2VecVocabularyTest, NumericIDMap) {
+  // Test getNumericID for opcodes
+  EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
+  EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
+  EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
+
+  // Test getNumericID for Type IDs
+  EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
+  EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
+            MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
+
+  // Test getNumericID for Value operands
+  LLVMContext Ctx;
+  Module M("TestM", Ctx);
+  FunctionType *FTy =
+      FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
+
+  // Test Function operand
+  EXPECT_EQ(Vocabulary::getNumericID(F),
+            MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+
+  // Test Constant operand
+  Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
+  EXPECT_EQ(Vocabulary::getNumericID(C),
+            MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+
+  // Test Pointer operand
+  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
+  AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
+  EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
+            MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+
+  // Test Variable operand (function argument)
+  Argument *Arg = F->getArg(0);
+  EXPECT_EQ(Vocabulary::getNumericID(Arg),
+            MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
+  // Test invalid opcode IDs
+  EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
+  EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+
+  // Test invalid type IDs
+  EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+               "Invalid type ID");
+  EXPECT_DEATH(
+      Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+      "Invalid type ID");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
 TEST(IR2VecVocabularyTest, StringKeyGeneration) {
   EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
   EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
 | 
    
6ae5021    to
    3ad45e3      
    Compare
  
    bc03736    to
    68ae9f5      
    Compare
  
    42671b8    to
    a395af5      
    Compare
  
    68ae9f5    to
    1d7ca80      
    Compare
  
    a395af5    to
    586947a      
    Compare
  
    01c6091    to
    f24c6f1      
    Compare
  
    
          Merge activity
  | 
    
f24c6f1    to
    faf9baa      
    Compare
  
    | 
           @svkeerthy This didn't get reviewed at all?  | 
    
| 
           Right. Pushed it as it was a minor refactoring. Feel free to add any comments. Will fix it.  | 
    
| 
           LLVM Buildbot has detected a new failure on builder  Full details are available at: https://lab.llvm.org/buildbot/#/builders/162/builds/27073 Here is the relevant piece of the build log for the reference | 
    

Add helper methods to IR2Vec's Vocabulary class for numeric ID mapping and vocabulary size calculation. These APIs will be useful in triplet generation for
llvm-ir2vectool (See #149214).(Tracking issue - #141817)