Skip to content

Commit 1b9aab1

Browse files
committed
Added create factory methods for MIR2Vec Vocabulary
1 parent f1eb7e5 commit 1b9aab1

File tree

4 files changed

+85
-56
lines changed

4 files changed

+85
-56
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "llvm/IR/PassManager.h"
3939
#include "llvm/Pass.h"
4040
#include "llvm/Support/CommandLine.h"
41+
#include "llvm/Support/Error.h"
4142
#include "llvm/Support/ErrorOr.h"
4243
#include <map>
4344
#include <set>
@@ -92,46 +93,43 @@ class MIRVocabulary {
9293
/// Get the string key for a vocabulary entry at the given position
9394
std::string getStringKey(unsigned Pos) const;
9495

95-
MIRVocabulary() = delete;
96-
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
97-
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
98-
: Storage(std::move(Storage)), TII(TII) {}
99-
100-
bool isValid() const {
101-
return UniqueBaseOpcodeNames.size() > 0 &&
102-
Layout.TotalEntries == Storage.size() && Storage.isValid();
103-
}
104-
10596
unsigned getDimension() const {
106-
if (!isValid())
107-
return 0;
10897
return Storage.getDimension();
10998
}
11099

111100
// Accessor methods
112101
const Embedding &operator[](unsigned Opcode) const {
113-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
114102
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
115103
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
116104
}
117105

118106
// Iterator access
119107
using const_iterator = ir2vec::VocabStorage::const_iterator;
120108
const_iterator begin() const {
121-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
122109
return Storage.begin();
123110
}
124111

125112
const_iterator end() const {
126-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
127113
return Storage.end();
128114
}
129115

130116
/// Total number of entries in the vocabulary
131117
size_t getCanonicalSize() const {
132-
assert(isValid() && "Invalid vocabulary");
133118
return Storage.size();
134119
}
120+
121+
MIRVocabulary() = delete;
122+
123+
/// Factory method to create MIRVocabulary from vocabulary map
124+
static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
125+
126+
/// Factory method to create MIRVocabulary from existing storage
127+
static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
128+
129+
private:
130+
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
131+
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
132+
: Storage(std::move(Storage)), TII(TII) {}
135133
};
136134

137135
} // namespace mir2vec
@@ -145,7 +143,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
145143

146144
StringRef getPassName() const override;
147145
Error readVocabulary();
148-
void emitError(Error Err, LLVMContext &Ctx);
149146

150147
protected:
151148
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -156,7 +153,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
156153
public:
157154
static char ID;
158155
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
159-
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
156+
Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
160157
};
161158

162159
/// This pass prints the embeddings in the MIR2Vec vocabulary

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
4949
//===----------------------------------------------------------------------===//
5050

5151
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
52-
const TargetInstrInfo *TII)
53-
: TII(*TII) {
54-
// Fixme: Use static factory methods for creating vocabularies instead of
55-
// public constructors
56-
// Early return for invalid inputs - creates empty/invalid vocabulary
57-
if (!TII || OpcodeEntries.empty())
58-
return;
59-
52+
const TargetInstrInfo &TII)
53+
: TII(TII) {
6054
buildCanonicalOpcodeMapping();
6155

6256
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
@@ -67,6 +61,24 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
6761
Layout.TotalEntries = Storage.size();
6862
}
6963

64+
Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
65+
const TargetInstrInfo &TII) {
66+
if (Entries.empty())
67+
return createStringError(errc::invalid_argument,
68+
"Empty vocabulary entries provided");
69+
70+
return MIRVocabulary(std::move(Entries), TII);
71+
}
72+
73+
Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
74+
const TargetInstrInfo &TII) {
75+
if (!Storage.isValid())
76+
return createStringError(errc::invalid_argument,
77+
"Invalid vocabulary storage provided");
78+
79+
return MIRVocabulary(std::move(Storage), TII);
80+
}
81+
7082
std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
7183
// Extract base instruction name using regex to capture letters and
7284
// underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
@@ -107,13 +119,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
107119
}
108120

109121
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
110-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
111122
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
112123
return getCanonicalIndexForBaseName(BaseOpcode);
113124
}
114125

115126
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
116-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
117127
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
118128

119129
// For now, all entries are opcodes since we only have one section
@@ -232,16 +242,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
232242
return Error::success();
233243
}
234244

235-
void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
236-
Ctx.emitError(toString(std::move(Err)));
237-
}
238-
239-
mir2vec::MIRVocabulary
245+
Expected<mir2vec::MIRVocabulary>
240246
MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
241247
if (StrVocabMap.empty()) {
242248
if (Error Err = readVocabulary()) {
243-
emitError(std::move(Err), M.getContext());
244-
return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
249+
return std::move(Err);
245250
}
246251
}
247252

@@ -255,15 +260,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
255260

256261
if (auto *MF = MMI.getMachineFunction(F)) {
257262
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
258-
return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
263+
return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
259264
}
260265
}
261266

262-
// No machine functions available - return invalid vocabulary
263-
emitError(make_error<StringError>("No machine functions found in module",
264-
inconvertibleErrorCode()),
265-
M.getContext());
266-
return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
267+
// No machine functions available - return error
268+
return createStringError(errc::invalid_argument,
269+
"No machine functions found in module");
267270
}
268271

269272
//===----------------------------------------------------------------------===//
@@ -284,13 +287,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
284287

285288
bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
286289
auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
287-
auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
290+
auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
288291

289-
if (!MIR2VecVocab.isValid()) {
290-
OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
292+
if (!MIR2VecVocabOrErr) {
293+
OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
294+
<< toString(MIR2VecVocabOrErr.takeError()) << "\n";
291295
return false;
292296
}
293297

298+
auto &MIR2VecVocab = *MIR2VecVocabOrErr;
294299
unsigned Pos = 0;
295300
for (const auto &Entry : MIR2VecVocab) {
296301
OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
; REQUIRES: x86_64-linux
2-
; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
3-
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
4-
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
5-
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
2+
; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
3+
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
4+
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
5+
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
66

77
define dso_local void @test() {
88
entry:
99
ret void
1010
}
1111

12-
; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
13-
; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
14-
; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
15-
; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
12+
; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
13+
; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
14+
; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
15+
; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension

llvm/unittests/CodeGen/MIR2VecTest.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/IR/Module.h"
1818
#include "llvm/MC/TargetRegistry.h"
1919
#include "llvm/Support/TargetSelect.h"
20+
#include "llvm/Support/raw_ostream.h"
2021
#include "llvm/Target/TargetMachine.h"
2122
#include "llvm/Target/TargetOptions.h"
2223
#include "llvm/TargetParser/Triple.h"
@@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
118119
VocabMap VMap;
119120
Embedding Val = Embedding(64, 1.0f);
120121
VMap["ADD"] = Val;
121-
MIRVocabulary TestVocab(std::move(VMap), TII);
122+
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
123+
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
124+
<< "Failed to create vocabulary: "
125+
<< toString(TestVocabOrErr.takeError());
126+
auto &TestVocab = *TestVocabOrErr;
122127

123128
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
124129
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
173178
// Use a minimal MIRVocabulary to trigger canonical mapping construction
174179
VocabMap VMap;
175180
VMap["ADD"] = Embedding(64, 1.0f);
176-
MIRVocabulary TestVocab(std::move(VMap), TII);
181+
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
182+
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
183+
<< "Failed to create vocabulary: "
184+
<< toString(TestVocabOrErr.takeError());
185+
auto &TestVocab = *TestVocabOrErr;
177186

178187
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
179188
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
195204
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
196205
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
197206

198-
MIRVocabulary Vocab(std::move(VMap), TII);
199-
EXPECT_TRUE(Vocab.isValid());
207+
auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
208+
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
209+
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
210+
auto &Vocab = *VocabOrErr;
200211
EXPECT_EQ(Vocab.getDimension(), 128u);
201212

202213
// Test iterator - iterates over individual embeddings
@@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
214225
EXPECT_GT(Count, 0u);
215226
}
216227

228+
// Test factory method with empty vocabulary
229+
TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
230+
VocabMap EmptyVMap;
231+
232+
auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
233+
EXPECT_FALSE(static_cast<bool>(VocabOrErr))
234+
<< "Factory method should fail with empty vocabulary";
235+
236+
// Consume the error
237+
if (!VocabOrErr) {
238+
auto Err = VocabOrErr.takeError();
239+
std::string ErrorMsg = toString(std::move(Err));
240+
EXPECT_FALSE(ErrorMsg.empty());
241+
}
242+
}
243+
217244
} // namespace

0 commit comments

Comments
 (0)