From 0afa19d128d2cf8679cc56a7c7e9a47d80ba5d26 Mon Sep 17 00:00:00 2001 From: svkeerthy Date: Thu, 2 Oct 2025 18:14:53 +0000 Subject: [PATCH] MIRVocabulary changes --- llvm/include/llvm/CodeGen/MIR2Vec.h | 31 +++++++++------- llvm/lib/CodeGen/MIR2Vec.cpp | 18 ++++++---- llvm/unittests/CodeGen/MIR2VecTest.cpp | 50 ++++++++++++++++---------- 3 files changed, 62 insertions(+), 37 deletions(-) diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h index 0ccb24448a678..ea68b4594a2ad 100644 --- a/llvm/include/llvm/CodeGen/MIR2Vec.h +++ b/llvm/include/llvm/CodeGen/MIR2Vec.h @@ -8,8 +8,8 @@ /// /// \file /// This file defines the MIR2Vec vocabulary -/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface -/// for generating Machine IR embeddings, and related utilities. +/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder +/// interface for generating Machine IR embeddings, and related utilities. /// /// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the /// LLVM Machine IR as embeddings which can be used as input to machine learning @@ -71,25 +71,31 @@ class MIRVocabulary { size_t TotalEntries = 0; } Layout; + enum class Section : unsigned { Opcodes = 0, MaxSections }; + ir2vec::VocabStorage Storage; mutable std::set UniqueBaseOpcodeNames; - void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII); - void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII); + const TargetInstrInfo &TII; + void generateStorage(const VocabMap &OpcodeMap); + void buildCanonicalOpcodeMapping(); + + /// Get canonical index for a machine opcode + unsigned getCanonicalOpcodeIndex(unsigned Opcode) const; public: - /// Static helper method for extracting base opcode names (public for testing) + /// Static method for extracting base opcode names (public for testing) static std::string extractBaseOpcodeName(StringRef InstrName); - /// Helper method for getting canonical index for base name (public for - /// testing) + /// Get canonical index for base name (public for testing) unsigned getCanonicalIndexForBaseName(StringRef BaseName) const; /// Get the string key for a vocabulary entry at the given position std::string getStringKey(unsigned Pos) const; - MIRVocabulary() = default; + MIRVocabulary() = delete; MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII); - MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {} + MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII) + : Storage(std::move(Storage)), TII(TII) {} bool isValid() const { return UniqueBaseOpcodeNames.size() > 0 && @@ -103,11 +109,10 @@ class MIRVocabulary { } // Accessor methods - const Embedding &operator[](unsigned Index) const { + const Embedding &operator[](unsigned Opcode) const { assert(isValid() && "MIR2Vec Vocabulary is invalid"); - assert(Index < Layout.TotalEntries && "Index out of bounds"); - // Fixme: For now, use section 0 for all entries - return Storage[0][Index]; + unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode); + return Storage[static_cast(Section::Opcodes)][LocalIndex]; } // Iterator access diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 83c5646629b48..87565c0c77115 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -49,20 +49,21 @@ cl::opt OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), //===----------------------------------------------------------------------===// MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, - const TargetInstrInfo *TII) { + const TargetInstrInfo *TII) + : TII(*TII) { // Fixme: Use static factory methods for creating vocabularies instead of // public constructors // Early return for invalid inputs - creates empty/invalid vocabulary if (!TII || OpcodeEntries.empty()) return; - buildCanonicalOpcodeMapping(*TII); + buildCanonicalOpcodeMapping(); unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size(); assert(CanonicalOpcodeCount > 0 && "No canonical opcodes found for target - invalid vocabulary"); Layout.OperandBase = CanonicalOpcodeCount; - generateStorage(OpcodeEntries, *TII); + generateStorage(OpcodeEntries); Layout.TotalEntries = Storage.size(); } @@ -105,6 +106,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const { return std::distance(UniqueBaseOpcodeNames.begin(), It); } +unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const { + assert(isValid() && "MIR2Vec Vocabulary is invalid"); + auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); + return getCanonicalIndexForBaseName(BaseOpcode); +} + std::string MIRVocabulary::getStringKey(unsigned Pos) const { assert(isValid() && "MIR2Vec Vocabulary is invalid"); assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary"); @@ -121,8 +128,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const { return ""; } -void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap, - const TargetInstrInfo &TII) { +void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) { // Helper for handling missing entities in the vocabulary. // Currently, we use a zero vector. In the future, we will throw an error to @@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap, Storage = ir2vec::VocabStorage(std::move(Sections)); } -void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) { +void MIRVocabulary::buildCanonicalOpcodeMapping() { // Check if already built if (!UniqueBaseOpcodeNames.empty()) return; diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 8e3154faac8b6..d243d82c73fc7 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -95,6 +95,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test { } }; +// Function to find an opcode by name +static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) { + for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { + if (TII->getName(Opcode) == Name) + return Opcode; + } + return -1; // Not found +} + TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri"); @@ -106,10 +115,10 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VM; + VocabMap VMap; Embedding Val = Embedding(64, 1.0f); - VM["ADD"] = Val; - MIRVocabulary TestVocab(std::move(VM), TII); + VMap["ADD"] = Val; + MIRVocabulary TestVocab(std::move(VMap), TII); unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2); @@ -140,9 +149,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors - EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val)); - EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f))); - EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f))); + int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr"); + ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found"; + EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val)); + + int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr"); + ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found"; + EXPECT_TRUE( + TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); + + int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr"); + ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found"; + EXPECT_TRUE( + TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); } // Test deterministic mapping @@ -152,9 +171,9 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VM; - VM["ADD"] = Embedding(64, 1.0f); - MIRVocabulary TestVocab(std::move(VM), TII); + VocabMap VMap; + VMap["ADD"] = Embedding(64, 1.0f); + MIRVocabulary TestVocab(std::move(VMap), TII); unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); @@ -172,16 +191,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { - // Test empty MIRVocabulary - MIRVocabulary EmptyVocab; - EXPECT_FALSE(EmptyVocab.isValid()); - - // Test MIRVocabulary with embeddings via VocabMap - VocabMap VM; - VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 - VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 + VocabMap VMap; + VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 + VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 - MIRVocabulary Vocab(std::move(VM), TII); + MIRVocabulary Vocab(std::move(VMap), TII); EXPECT_TRUE(Vocab.isValid()); EXPECT_EQ(Vocab.getDimension(), 128u);