From 1b9aab1424abdfaa59d65fefcad0753ad1f66e1a Mon Sep 17 00:00:00 2001 From: svkeerthy Date: Wed, 8 Oct 2025 23:10:15 +0000 Subject: [PATCH 1/3] Added create factory methods for MIR2Vec Vocabulary --- llvm/include/llvm/CodeGen/MIR2Vec.h | 33 +++++------ llvm/lib/CodeGen/MIR2Vec.cpp | 57 ++++++++++--------- .../CodeGen/MIR2Vec/vocab-error-handling.ll | 16 +++--- llvm/unittests/CodeGen/MIR2VecTest.cpp | 35 ++++++++++-- 4 files changed, 85 insertions(+), 56 deletions(-) diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h index ea68b4594a2ad..dbffede50df81 100644 --- a/llvm/include/llvm/CodeGen/MIR2Vec.h +++ b/llvm/include/llvm/CodeGen/MIR2Vec.h @@ -38,6 +38,7 @@ #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" #include "llvm/Support/ErrorOr.h" #include #include @@ -92,25 +93,12 @@ class MIRVocabulary { /// Get the string key for a vocabulary entry at the given position std::string getStringKey(unsigned Pos) const; - MIRVocabulary() = delete; - MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII); - MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII) - : Storage(std::move(Storage)), TII(TII) {} - - bool isValid() const { - return UniqueBaseOpcodeNames.size() > 0 && - Layout.TotalEntries == Storage.size() && Storage.isValid(); - } - unsigned getDimension() const { - if (!isValid()) - return 0; return Storage.getDimension(); } // Accessor methods const Embedding &operator[](unsigned Opcode) const { - assert(isValid() && "MIR2Vec Vocabulary is invalid"); unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode); return Storage[static_cast(Section::Opcodes)][LocalIndex]; } @@ -118,20 +106,30 @@ class MIRVocabulary { // Iterator access using const_iterator = ir2vec::VocabStorage::const_iterator; const_iterator begin() const { - assert(isValid() && "MIR2Vec Vocabulary is invalid"); return Storage.begin(); } const_iterator end() const { - assert(isValid() && "MIR2Vec Vocabulary is invalid"); return Storage.end(); } /// Total number of entries in the vocabulary size_t getCanonicalSize() const { - assert(isValid() && "Invalid vocabulary"); return Storage.size(); } + + MIRVocabulary() = delete; + + /// Factory method to create MIRVocabulary from vocabulary map + static Expected create(VocabMap &&Entries, const TargetInstrInfo &TII); + + /// Factory method to create MIRVocabulary from existing storage + static Expected create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII); + +private: + MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII); + MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII) + : Storage(std::move(Storage)), TII(TII) {} }; } // namespace mir2vec @@ -145,7 +143,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass { StringRef getPassName() const override; Error readVocabulary(); - void emitError(Error Err, LLVMContext &Ctx); protected: void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -156,7 +153,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass { public: static char ID; MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {} - mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M); + Expected getMIR2VecVocabulary(const Module &M); }; /// This pass prints the embeddings in the MIR2Vec vocabulary diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 87565c0c77115..669c11d5f739c 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -49,14 +49,8 @@ cl::opt OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), //===----------------------------------------------------------------------===// MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, - const TargetInstrInfo *TII) - : TII(*TII) { - // Fixme: Use static factory methods for creating vocabularies instead of - // public constructors - // Early return for invalid inputs - creates empty/invalid vocabulary - if (!TII || OpcodeEntries.empty()) - return; - + const TargetInstrInfo &TII) + : TII(TII) { buildCanonicalOpcodeMapping(); unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size(); @@ -67,6 +61,24 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, Layout.TotalEntries = Storage.size(); } +Expected MIRVocabulary::create(VocabMap &&Entries, + const TargetInstrInfo &TII) { + if (Entries.empty()) + return createStringError(errc::invalid_argument, + "Empty vocabulary entries provided"); + + return MIRVocabulary(std::move(Entries), TII); +} + +Expected MIRVocabulary::create(ir2vec::VocabStorage &&Storage, + const TargetInstrInfo &TII) { + if (!Storage.isValid()) + return createStringError(errc::invalid_argument, + "Invalid vocabulary storage provided"); + + return MIRVocabulary(std::move(Storage), TII); +} + std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) { // Extract base instruction name using regex to capture letters and // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE" @@ -107,13 +119,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const { } unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const { - assert(isValid() && "MIR2Vec Vocabulary is invalid"); auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); return getCanonicalIndexForBaseName(BaseOpcode); } std::string MIRVocabulary::getStringKey(unsigned Pos) const { - assert(isValid() && "MIR2Vec Vocabulary is invalid"); assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary"); // For now, all entries are opcodes since we only have one section @@ -232,16 +242,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() { return Error::success(); } -void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) { - Ctx.emitError(toString(std::move(Err))); -} - -mir2vec::MIRVocabulary +Expected MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { if (StrVocabMap.empty()) { if (Error Err = readVocabulary()) { - emitError(std::move(Err), M.getContext()); - return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr); + return std::move(Err); } } @@ -255,15 +260,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { if (auto *MF = MMI.getMachineFunction(F)) { const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo(); - return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII); + return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII); } } - // No machine functions available - return invalid vocabulary - emitError(make_error("No machine functions found in module", - inconvertibleErrorCode()), - M.getContext()); - return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr); + // No machine functions available - return error + return createStringError(errc::invalid_argument, + "No machine functions found in module"); } //===----------------------------------------------------------------------===// @@ -284,13 +287,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) { bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) { auto &Analysis = getAnalysis(); - auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M); + auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M); - if (!MIR2VecVocab.isValid()) { - OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n"; + if (!MIR2VecVocabOrErr) { + OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - " + << toString(MIR2VecVocabOrErr.takeError()) << "\n"; return false; } + auto &MIR2VecVocab = *MIR2VecVocabOrErr; unsigned Pos = 0; for (const auto &Entry : MIR2VecVocab) { OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": "; diff --git a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll index 1da516a6cd3b9..80b4048cea0c3 100644 --- a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll +++ b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll @@ -1,15 +1,15 @@ ; REQUIRES: x86_64-linux -; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID -; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM -; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES -; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS +; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID +; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM +; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES +; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS define dso_local void @test() { entry: ret void } -; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path -; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero -; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file -; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension +; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path +; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero +; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file +; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index d243d82c73fc7..269e3b515c6fc 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/Module.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/TargetParser/Triple.h" @@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { VocabMap VMap; Embedding Val = Embedding(64, 1.0f); VMap["ADD"] = Val; - MIRVocabulary TestVocab(std::move(VMap), TII); + auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + ASSERT_TRUE(static_cast(TestVocabOrErr)) + << "Failed to create vocabulary: " + << toString(TestVocabOrErr.takeError()); + auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2); @@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Use a minimal MIRVocabulary to trigger canonical mapping construction VocabMap VMap; VMap["ADD"] = Embedding(64, 1.0f); - MIRVocabulary TestVocab(std::move(VMap), TII); + auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + ASSERT_TRUE(static_cast(TestVocabOrErr)) + << "Failed to create vocabulary: " + << toString(TestVocabOrErr.takeError()); + auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); @@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 - MIRVocabulary Vocab(std::move(VMap), TII); - EXPECT_TRUE(Vocab.isValid()); + auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + ASSERT_TRUE(static_cast(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; EXPECT_EQ(Vocab.getDimension(), 128u); // Test iterator - iterates over individual embeddings @@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { EXPECT_GT(Count, 0u); } +// Test factory method with empty vocabulary +TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) { + VocabMap EmptyVMap; + + auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII); + EXPECT_FALSE(static_cast(VocabOrErr)) + << "Factory method should fail with empty vocabulary"; + + // Consume the error + if (!VocabOrErr) { + auto Err = VocabOrErr.takeError(); + std::string ErrorMsg = toString(std::move(Err)); + EXPECT_FALSE(ErrorMsg.empty()); + } +} + } // namespace \ No newline at end of file From 7214a4794cf3a4ec7fa508b1bb4757940c2ccb8b Mon Sep 17 00:00:00 2001 From: svkeerthy Date: Wed, 8 Oct 2025 23:27:34 +0000 Subject: [PATCH 2/3] Removed unused create and formatting fixes --- llvm/include/llvm/CodeGen/MIR2Vec.h | 24 ++++++------------------ llvm/lib/CodeGen/MIR2Vec.cpp | 9 --------- llvm/unittests/CodeGen/MIR2VecTest.cpp | 2 +- 3 files changed, 7 insertions(+), 28 deletions(-) diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h index dbffede50df81..7b1b5d9aee15d 100644 --- a/llvm/include/llvm/CodeGen/MIR2Vec.h +++ b/llvm/include/llvm/CodeGen/MIR2Vec.h @@ -93,9 +93,7 @@ class MIRVocabulary { /// Get the string key for a vocabulary entry at the given position std::string getStringKey(unsigned Pos) const; - unsigned getDimension() const { - return Storage.getDimension(); - } + unsigned getDimension() const { return Storage.getDimension(); } // Accessor methods const Embedding &operator[](unsigned Opcode) const { @@ -105,31 +103,21 @@ class MIRVocabulary { // Iterator access using const_iterator = ir2vec::VocabStorage::const_iterator; - const_iterator begin() const { - return Storage.begin(); - } + const_iterator begin() const { return Storage.begin(); } - const_iterator end() const { - return Storage.end(); - } + const_iterator end() const { return Storage.end(); } /// Total number of entries in the vocabulary - size_t getCanonicalSize() const { - return Storage.size(); - } + size_t getCanonicalSize() const { return Storage.size(); } MIRVocabulary() = delete; /// Factory method to create MIRVocabulary from vocabulary map - static Expected create(VocabMap &&Entries, const TargetInstrInfo &TII); - - /// Factory method to create MIRVocabulary from existing storage - static Expected create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII); + static Expected create(VocabMap &&Entries, + const TargetInstrInfo &TII); private: MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII); - MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII) - : Storage(std::move(Storage)), TII(TII) {} }; } // namespace mir2vec diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 669c11d5f739c..e85976547a2c2 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -70,15 +70,6 @@ Expected MIRVocabulary::create(VocabMap &&Entries, return MIRVocabulary(std::move(Entries), TII); } -Expected MIRVocabulary::create(ir2vec::VocabStorage &&Storage, - const TargetInstrInfo &TII) { - if (!Storage.isValid()) - return createStringError(errc::invalid_argument, - "Invalid vocabulary storage provided"); - - return MIRVocabulary(std::move(Storage), TII); -} - std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) { // Extract base instruction name using regex to capture letters and // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE" diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 269e3b515c6fc..6ce791695e3e4 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -241,4 +241,4 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) { } } -} // namespace \ No newline at end of file +} // namespace From 43e7b9a0f29e6b7aa74f219b83c54f0d1cf8171d Mon Sep 17 00:00:00 2001 From: svkeerthy Date: Thu, 9 Oct 2025 06:59:28 +0000 Subject: [PATCH 3/3] Addressing review comments --- llvm/unittests/CodeGen/MIR2VecTest.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 6ce791695e3e4..11222b4d02fa3 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -53,7 +53,7 @@ class MIR2VecVocabTestFixture : public ::testing::Test { std::unique_ptr Ctx; std::unique_ptr M; std::unique_ptr TM; - const TargetInstrInfo *TII; + const TargetInstrInfo *TII = nullptr; static void SetUpTestCase() { InitializeAllTargets(); @@ -94,6 +94,8 @@ class MIR2VecVocabTestFixture : public ::testing::Test { return; } } + + void TearDown() override { TII = nullptr; } }; // Function to find an opcode by name