Skip to content

Conversation

svkeerthy
Copy link
Contributor

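Added factory methods for vocabulary creation. This also fixes the UB introduced by #161713: the old `MIRVocabulary(VocabMap &&, const TargetInstrInfo *)` constructor dereferenced `TII` in its member initializer (`TII(*TII)`) before checking it for null, and the error paths called it with `nullptr`.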

@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2025

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

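Added factory methods for vocabulary creation. This also fixes the UB introduced by #161713: the old `MIRVocabulary(VocabMap &&, const TargetInstrInfo *)` constructor dereferenced `TII` in its member initializer (`TII(*TII)`) before checking it for null, and the error paths called it with `nullptr`.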


Full diff: https://github.com/llvm/llvm-project/pull/162569.diff

4 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/MIR2Vec.h (+15-18)
  • (modified) llvm/lib/CodeGen/MIR2Vec.cpp (+31-26)
  • (modified) llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll (+8-8)
  • (modified) llvm/unittests/CodeGen/MIR2VecTest.cpp (+31-4)
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index ea68b4594a2ad..dbffede50df81 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -38,6 +38,7 @@
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorOr.h"
 #include <map>
 #include <set>
@@ -92,25 +93,12 @@ class MIRVocabulary {
   /// Get the string key for a vocabulary entry at the given position
   std::string getStringKey(unsigned Pos) const;
 
-  MIRVocabulary() = delete;
-  MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
-  MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
-      : Storage(std::move(Storage)), TII(TII) {}
-
-  bool isValid() const {
-    return UniqueBaseOpcodeNames.size() > 0 &&
-           Layout.TotalEntries == Storage.size() && Storage.isValid();
-  }
-
   unsigned getDimension() const {
-    if (!isValid())
-      return 0;
     return Storage.getDimension();
   }
 
   // Accessor methods
   const Embedding &operator[](unsigned Opcode) const {
-    assert(isValid() && "MIR2Vec Vocabulary is invalid");
     unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
     return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
   }
@@ -118,20 +106,30 @@ class MIRVocabulary {
   // Iterator access
   using const_iterator = ir2vec::VocabStorage::const_iterator;
   const_iterator begin() const {
-    assert(isValid() && "MIR2Vec Vocabulary is invalid");
     return Storage.begin();
   }
 
   const_iterator end() const {
-    assert(isValid() && "MIR2Vec Vocabulary is invalid");
     return Storage.end();
   }
 
   /// Total number of entries in the vocabulary
   size_t getCanonicalSize() const {
-    assert(isValid() && "Invalid vocabulary");
     return Storage.size();
   }
+
+  MIRVocabulary() = delete;
+
+  /// Factory method to create MIRVocabulary from vocabulary map
+  static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
+  
+  /// Factory method to create MIRVocabulary from existing storage
+  static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
+
+private:
+  MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
+  MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+      : Storage(std::move(Storage)), TII(TII) {}
 };
 
 } // namespace mir2vec
@@ -145,7 +143,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
 
   StringRef getPassName() const override;
   Error readVocabulary();
-  void emitError(Error Err, LLVMContext &Ctx);
 
 protected:
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -156,7 +153,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
 public:
   static char ID;
   MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
-  mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
+  Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
 };
 
 /// This pass prints the embeddings in the MIR2Vec vocabulary
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 87565c0c77115..669c11d5f739c 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
 //===----------------------------------------------------------------------===//
 
 MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
-                             const TargetInstrInfo *TII)
-    : TII(*TII) {
-  // Fixme: Use static factory methods for creating vocabularies instead of
-  // public constructors
-  // Early return for invalid inputs - creates empty/invalid vocabulary
-  if (!TII || OpcodeEntries.empty())
-    return;
-
+                             const TargetInstrInfo &TII)
+    : TII(TII) {
   buildCanonicalOpcodeMapping();
 
   unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
@@ -67,6 +61,24 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
   Layout.TotalEntries = Storage.size();
 }
 
+Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
+                                              const TargetInstrInfo &TII) {
+  if (Entries.empty())
+    return createStringError(errc::invalid_argument,
+                             "Empty vocabulary entries provided");
+
+  return MIRVocabulary(std::move(Entries), TII);
+}
+
+Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
+                                              const TargetInstrInfo &TII) {
+  if (!Storage.isValid())
+    return createStringError(errc::invalid_argument,
+                             "Invalid vocabulary storage provided");
+
+  return MIRVocabulary(std::move(Storage), TII);
+}
+
 std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
   // Extract base instruction name using regex to capture letters and
   // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
@@ -107,13 +119,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
 }
 
 unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
-  assert(isValid() && "MIR2Vec Vocabulary is invalid");
   auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
   return getCanonicalIndexForBaseName(BaseOpcode);
 }
 
 std::string MIRVocabulary::getStringKey(unsigned Pos) const {
-  assert(isValid() && "MIR2Vec Vocabulary is invalid");
   assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
 
   // For now, all entries are opcodes since we only have one section
@@ -232,16 +242,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
   return Error::success();
 }
 
-void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
-  Ctx.emitError(toString(std::move(Err)));
-}
-
-mir2vec::MIRVocabulary
+Expected<mir2vec::MIRVocabulary>
 MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
   if (StrVocabMap.empty()) {
     if (Error Err = readVocabulary()) {
-      emitError(std::move(Err), M.getContext());
-      return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+      return std::move(Err);
     }
   }
 
@@ -255,15 +260,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
 
     if (auto *MF = MMI.getMachineFunction(F)) {
       const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
-      return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
+      return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
     }
   }
 
-  // No machine functions available - return invalid vocabulary
-  emitError(make_error<StringError>("No machine functions found in module",
-                                    inconvertibleErrorCode()),
-            M.getContext());
-  return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+  // No machine functions available - return error
+  return createStringError(errc::invalid_argument,
+                           "No machine functions found in module");
 }
 
 //===----------------------------------------------------------------------===//
@@ -284,13 +287,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
 
 bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
   auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
-  auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
+  auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
 
-  if (!MIR2VecVocab.isValid()) {
-    OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
+  if (!MIR2VecVocabOrErr) {
+    OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
+       << toString(MIR2VecVocabOrErr.takeError()) << "\n";
     return false;
   }
 
+  auto &MIR2VecVocab = *MIR2VecVocabOrErr;
   unsigned Pos = 0;
   for (const auto &Entry : MIR2VecVocab) {
     OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
diff --git a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
index 1da516a6cd3b9..80b4048cea0c3 100644
--- a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
+++ b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
@@ -1,15 +1,15 @@
 ; REQUIRES: x86_64-linux
-; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
+; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
 
 define dso_local void @test() {
   entry:
     ret void
 }
 
-; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
-; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
-; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
-; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
+; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
+; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
+; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
+; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index d243d82c73fc7..269e3b515c6fc 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
 #include "llvm/TargetParser/Triple.h"
@@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
   VocabMap VMap;
   Embedding Val = Embedding(64, 1.0f);
   VMap["ADD"] = Val;
-  MIRVocabulary TestVocab(std::move(VMap), TII);
+  auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+      << "Failed to create vocabulary: "
+      << toString(TestVocabOrErr.takeError());
+  auto &TestVocab = *TestVocabOrErr;
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
   // Use a minimal MIRVocabulary to trigger canonical mapping construction
   VocabMap VMap;
   VMap["ADD"] = Embedding(64, 1.0f);
-  MIRVocabulary TestVocab(std::move(VMap), TII);
+  auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+      << "Failed to create vocabulary: "
+      << toString(TestVocabOrErr.takeError());
+  auto &TestVocab = *TestVocabOrErr;
 
   unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
   unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
   VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
   VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
 
-  MIRVocabulary Vocab(std::move(VMap), TII);
-  EXPECT_TRUE(Vocab.isValid());
+  auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+  auto &Vocab = *VocabOrErr;
   EXPECT_EQ(Vocab.getDimension(), 128u);
 
   // Test iterator - iterates over individual embeddings
@@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
   EXPECT_GT(Count, 0u);
 }
 
+// Test factory method with empty vocabulary
+TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
+  VocabMap EmptyVMap;
+
+  auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
+  EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+      << "Factory method should fail with empty vocabulary";
+
+  // Consume the error
+  if (!VocabOrErr) {
+    auto Err = VocabOrErr.takeError();
+    std::string ErrorMsg = toString(std::move(Err));
+    EXPECT_FALSE(ErrorMsg.empty());
+  }
+}
+
 } // namespace
\ No newline at end of file

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds factory methods for MIRVocabulary creation and improves error handling by replacing direct constructor usage with factory methods that return Expected. It also fixes undefined behavior issues introduced by a previous PR.

  • Factory methods MIRVocabulary::create() replace direct constructor calls
  • Error handling now uses Expected<> return types instead of constructing invalid vocabulary objects
  • Test output format is updated to reflect new error handling patterns

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| llvm/include/llvm/CodeGen/MIR2Vec.h | Added factory method declarations and made constructors private |
| llvm/lib/CodeGen/MIR2Vec.cpp | Implemented factory methods and updated error handling logic |
| llvm/unittests/CodeGen/MIR2VecTest.cpp | Updated tests to use factory methods and added error handling validation |
| llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll | Updated test expectations to match new error message format |

Copy link

github-actions bot commented Oct 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

std::unique_ptr<LLVMContext> Ctx;
std::unique_ptr<Module> M;
std::unique_ptr<TargetMachine> TM;
const TargetInstrInfo *TII;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initialize this to nullptr.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add a TearDown() where you reset it to nullptr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SetUp() method either assigns a valid value to TII or skips the test entirely. Also, TII will be destroyed automatically when TM is destroyed after each test, right? I'm trying to understand whether explicitly setting it to nullptr makes a difference.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

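It's just simpler state to track and maintain: the pointer is never left indeterminate before SetUp() runs, and never left dangling after TM is destroyed.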

@svkeerthy svkeerthy requested review from jyknight and mtrofin October 8, 2025 23:31
Copy link
Member

@mtrofin mtrofin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM after addressing the comments.

@svkeerthy svkeerthy merged commit b32710a into main Oct 9, 2025
5 of 8 checks passed
@svkeerthy svkeerthy deleted the users/svkeerthy/10-08-ub-fix branch October 9, 2025 07:12
svkeerthy added a commit that referenced this pull request Oct 9, 2025
Added factory methods for vocabulary creation. This also would fix UB
issue introduced by #161713
clingfei pushed a commit to clingfei/llvm-project that referenced this pull request Oct 10, 2025
Added factory methods for vocabulary creation. This also would fix UB
issue introduced by llvm#161713
DharuniRAcharya pushed a commit to DharuniRAcharya/llvm-project that referenced this pull request Oct 13, 2025
Added factory methods for vocabulary creation. This also would fix UB
issue introduced by llvm#161713
akadutta pushed a commit to akadutta/llvm-project that referenced this pull request Oct 14, 2025
Added factory methods for vocabulary creation. This also would fix UB
issue introduced by llvm#161713
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants