Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions llvm/include/llvm/CodeGen/MIR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
///
/// \file
/// This file defines the MIR2Vec vocabulary
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
/// for generating Machine IR embeddings, and related utilities.
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
/// interface for generating Machine IR embeddings, and related utilities.
///
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
/// LLVM Machine IR as embeddings which can be used as input to machine learning
Expand Down Expand Up @@ -71,25 +71,31 @@ class MIRVocabulary {
size_t TotalEntries = 0;
} Layout;

enum class Section : unsigned { Opcodes = 0, MaxSections };

ir2vec::VocabStorage Storage;
mutable std::set<std::string> UniqueBaseOpcodeNames;
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
const TargetInstrInfo &TII;
void generateStorage(const VocabMap &OpcodeMap);
void buildCanonicalOpcodeMapping();

/// Get canonical index for a machine opcode
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;

public:
/// Static helper method for extracting base opcode names (public for testing)
/// Static method for extracting base opcode names (public for testing)
static std::string extractBaseOpcodeName(StringRef InstrName);

/// Helper method for getting canonical index for base name (public for
/// testing)
/// Get canonical index for base name (public for testing)
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;

/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;

MIRVocabulary() = default;
MIRVocabulary() = delete;
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
: Storage(std::move(Storage)), TII(TII) {}

bool isValid() const {
return UniqueBaseOpcodeNames.size() > 0 &&
Expand All @@ -103,11 +109,10 @@ class MIRVocabulary {
}

// Accessor methods
const Embedding &operator[](unsigned Index) const {
const Embedding &operator[](unsigned Opcode) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Index < Layout.TotalEntries && "Index out of bounds");
// Fixme: For now, use section 0 for all entries
return Storage[0][Index];
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
}

// Iterator access
Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/CodeGen/MIR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,21 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
//===----------------------------------------------------------------------===//

MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
const TargetInstrInfo *TII) {
const TargetInstrInfo *TII)
: TII(*TII) {
// Fixme: Use static factory methods for creating vocabularies instead of
// public constructors
// Early return for invalid inputs - creates empty/invalid vocabulary
if (!TII || OpcodeEntries.empty())
return;

buildCanonicalOpcodeMapping(*TII);
buildCanonicalOpcodeMapping();

unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
assert(CanonicalOpcodeCount > 0 &&
"No canonical opcodes found for target - invalid vocabulary");
Layout.OperandBase = CanonicalOpcodeCount;
generateStorage(OpcodeEntries, *TII);
generateStorage(OpcodeEntries);
Layout.TotalEntries = Storage.size();
}

Expand Down Expand Up @@ -105,6 +106,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
return std::distance(UniqueBaseOpcodeNames.begin(), It);
}

unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
return getCanonicalIndexForBaseName(BaseOpcode);
}

std::string MIRVocabulary::getStringKey(unsigned Pos) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
Expand All @@ -121,8 +128,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const {
return "";
}

void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
const TargetInstrInfo &TII) {
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {

// Helper for handling missing entities in the vocabulary.
// Currently, we use a zero vector. In the future, we will throw an error to
Expand Down Expand Up @@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
Storage = ir2vec::VocabStorage(std::move(Sections));
}

void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) {
void MIRVocabulary::buildCanonicalOpcodeMapping() {
// Check if already built
if (!UniqueBaseOpcodeNames.empty())
return;
Expand Down
50 changes: 32 additions & 18 deletions llvm/unittests/CodeGen/MIR2VecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
}
};

// Function to find an opcode by name
static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
if (TII->getName(Opcode) == Name)
return Opcode;
}
return -1; // Not found
}

TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
Expand All @@ -106,10 +115,10 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {

// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VM;
VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
VM["ADD"] = Val;
MIRVocabulary TestVocab(std::move(VM), TII);
VMap["ADD"] = Val;
MIRVocabulary TestVocab(std::move(VMap), TII);

unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
Expand Down Expand Up @@ -140,9 +149,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
6880u); // X86 has >6880 unique base opcodes

// Check that the embeddings for opcodes not in the vocab are zero vectors
EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val));
EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f)));
EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f)));
int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));

int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
EXPECT_TRUE(
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));

int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
EXPECT_TRUE(
TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
}

// Test deterministic mapping
Expand All @@ -152,9 +171,9 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {

// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VM;
VM["ADD"] = Embedding(64, 1.0f);
MIRVocabulary TestVocab(std::move(VM), TII);
VocabMap VMap;
VMap["ADD"] = Embedding(64, 1.0f);
MIRVocabulary TestVocab(std::move(VMap), TII);

unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
Expand All @@ -172,16 +191,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {

// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
// Test empty MIRVocabulary
MIRVocabulary EmptyVocab;
EXPECT_FALSE(EmptyVocab.isValid());

// Test MIRVocabulary with embeddings via VocabMap
VocabMap VM;
VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
VocabMap VMap;
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0

MIRVocabulary Vocab(std::move(VM), TII);
MIRVocabulary Vocab(std::move(VMap), TII);
EXPECT_TRUE(Vocab.isValid());
EXPECT_EQ(Vocab.getDimension(), 128u);

Expand Down