Skip to content

Commit 566040e

Browse files
authored
[MIR2Vec] Refactor MIR vocabulary to use opcode-based indexing (#161713)
Refactor MIRVocabulary to improve opcode lookup and add Section enum for better organization. This is useful for embedder lookups (next patches) (Tracking issue - #141817)
1 parent 289e85b commit 566040e

File tree

3 files changed

+62
-37
lines changed

3 files changed

+62
-37
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
///
99
/// \file
1010
/// This file defines the MIR2Vec vocabulary
11-
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
12-
/// for generating Machine IR embeddings, and related utilities.
11+
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
12+
/// interface for generating Machine IR embeddings, and related utilities.
1313
///
1414
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
1515
/// LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -71,25 +71,31 @@ class MIRVocabulary {
7171
size_t TotalEntries = 0;
7272
} Layout;
7373

74+
enum class Section : unsigned { Opcodes = 0, MaxSections };
75+
7476
ir2vec::VocabStorage Storage;
7577
mutable std::set<std::string> UniqueBaseOpcodeNames;
76-
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
77-
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
78+
const TargetInstrInfo &TII;
79+
void generateStorage(const VocabMap &OpcodeMap);
80+
void buildCanonicalOpcodeMapping();
81+
82+
/// Get canonical index for a machine opcode
83+
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
7884

7985
public:
80-
/// Static helper method for extracting base opcode names (public for testing)
86+
/// Static method for extracting base opcode names (public for testing)
8187
static std::string extractBaseOpcodeName(StringRef InstrName);
8288

83-
/// Helper method for getting canonical index for base name (public for
84-
/// testing)
89+
/// Get canonical index for base name (public for testing)
8590
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
8691

8792
/// Get the string key for a vocabulary entry at the given position
8893
std::string getStringKey(unsigned Pos) const;
8994

90-
MIRVocabulary() = default;
95+
MIRVocabulary() = delete;
9196
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
92-
MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
97+
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
98+
: Storage(std::move(Storage)), TII(TII) {}
9399

94100
bool isValid() const {
95101
return UniqueBaseOpcodeNames.size() > 0 &&
@@ -103,11 +109,10 @@ class MIRVocabulary {
103109
}
104110

105111
// Accessor methods
106-
const Embedding &operator[](unsigned Index) const {
112+
const Embedding &operator[](unsigned Opcode) const {
107113
assert(isValid() && "MIR2Vec Vocabulary is invalid");
108-
assert(Index < Layout.TotalEntries && "Index out of bounds");
109-
// Fixme: For now, use section 0 for all entries
110-
return Storage[0][Index];
114+
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
115+
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
111116
}
112117

113118
// Iterator access

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,21 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
4949
//===----------------------------------------------------------------------===//
5050

5151
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
52-
const TargetInstrInfo *TII) {
52+
const TargetInstrInfo *TII)
53+
: TII(*TII) {
5354
// Fixme: Use static factory methods for creating vocabularies instead of
5455
// public constructors
5556
// Early return for invalid inputs - creates empty/invalid vocabulary
5657
if (!TII || OpcodeEntries.empty())
5758
return;
5859

59-
buildCanonicalOpcodeMapping(*TII);
60+
buildCanonicalOpcodeMapping();
6061

6162
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
6263
assert(CanonicalOpcodeCount > 0 &&
6364
"No canonical opcodes found for target - invalid vocabulary");
6465
Layout.OperandBase = CanonicalOpcodeCount;
65-
generateStorage(OpcodeEntries, *TII);
66+
generateStorage(OpcodeEntries);
6667
Layout.TotalEntries = Storage.size();
6768
}
6869

@@ -105,6 +106,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
105106
return std::distance(UniqueBaseOpcodeNames.begin(), It);
106107
}
107108

109+
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
110+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
111+
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
112+
return getCanonicalIndexForBaseName(BaseOpcode);
113+
}
114+
108115
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
109116
assert(isValid() && "MIR2Vec Vocabulary is invalid");
110117
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
@@ -121,8 +128,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const {
121128
return "";
122129
}
123130

124-
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
125-
const TargetInstrInfo &TII) {
131+
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
126132

127133
// Helper for handling missing entities in the vocabulary.
128134
// Currently, we use a zero vector. In the future, we will throw an error to
@@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
168174
Storage = ir2vec::VocabStorage(std::move(Sections));
169175
}
170176

171-
void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) {
177+
void MIRVocabulary::buildCanonicalOpcodeMapping() {
172178
// Check if already built
173179
if (!UniqueBaseOpcodeNames.empty())
174180
return;

llvm/unittests/CodeGen/MIR2VecTest.cpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
9595
}
9696
};
9797

98+
// Function to find an opcode by name
99+
static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
100+
for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
101+
if (TII->getName(Opcode) == Name)
102+
return Opcode;
103+
}
104+
return -1; // Not found
105+
}
106+
98107
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
99108
// Test that same base opcodes get same canonical indices
100109
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -106,10 +115,10 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
106115

107116
// Create a MIRVocabulary instance to test the mapping
108117
// Use a minimal MIRVocabulary to trigger canonical mapping construction
109-
VocabMap VM;
118+
VocabMap VMap;
110119
Embedding Val = Embedding(64, 1.0f);
111-
VM["ADD"] = Val;
112-
MIRVocabulary TestVocab(std::move(VM), TII);
120+
VMap["ADD"] = Val;
121+
MIRVocabulary TestVocab(std::move(VMap), TII);
113122

114123
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
115124
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -140,9 +149,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
140149
6880u); // X86 has >6880 unique base opcodes
141150

142151
// Check that the embeddings for opcodes not in the vocab are zero vectors
143-
EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val));
144-
EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f)));
145-
EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f)));
152+
int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
153+
ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
154+
EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
155+
156+
int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
157+
ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
158+
EXPECT_TRUE(
159+
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
160+
161+
int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
162+
ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
163+
EXPECT_TRUE(
164+
TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
146165
}
147166

148167
// Test deterministic mapping
@@ -152,9 +171,9 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
152171

153172
// Create a MIRVocabulary instance to test deterministic mapping
154173
// Use a minimal MIRVocabulary to trigger canonical mapping construction
155-
VocabMap VM;
156-
VM["ADD"] = Embedding(64, 1.0f);
157-
MIRVocabulary TestVocab(std::move(VM), TII);
174+
VocabMap VMap;
175+
VMap["ADD"] = Embedding(64, 1.0f);
176+
MIRVocabulary TestVocab(std::move(VMap), TII);
158177

159178
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
160179
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -172,16 +191,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
172191

173192
// Test MIRVocabulary construction
174193
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
175-
// Test empty MIRVocabulary
176-
MIRVocabulary EmptyVocab;
177-
EXPECT_FALSE(EmptyVocab.isValid());
178-
179-
// Test MIRVocabulary with embeddings via VocabMap
180-
VocabMap VM;
181-
VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
182-
VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
194+
VocabMap VMap;
195+
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
196+
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
183197

184-
MIRVocabulary Vocab(std::move(VM), TII);
198+
MIRVocabulary Vocab(std::move(VMap), TII);
185199
EXPECT_TRUE(Vocab.isValid());
186200
EXPECT_EQ(Vocab.getDimension(), 128u);
187201

0 commit comments

Comments
 (0)