Skip to content

Commit b32710a

Browse files
authored
[MIR2Vec] Added create factory methods for Vocabulary (#162569)
Added factory methods for vocabulary creation. This also fixes the UB issue introduced by #161713.
1 parent cc1ca59 commit b32710a

File tree

4 files changed

+80
-70
lines changed

4 files changed

+80
-70
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "llvm/IR/PassManager.h"
3939
#include "llvm/Pass.h"
4040
#include "llvm/Support/CommandLine.h"
41+
#include "llvm/Support/Error.h"
4142
#include "llvm/Support/ErrorOr.h"
4243
#include <map>
4344
#include <set>
@@ -92,46 +93,31 @@ class MIRVocabulary {
9293
/// Get the string key for a vocabulary entry at the given position
9394
std::string getStringKey(unsigned Pos) const;
9495

95-
MIRVocabulary() = delete;
96-
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
97-
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
98-
: Storage(std::move(Storage)), TII(TII) {}
99-
100-
bool isValid() const {
101-
return UniqueBaseOpcodeNames.size() > 0 &&
102-
Layout.TotalEntries == Storage.size() && Storage.isValid();
103-
}
104-
105-
unsigned getDimension() const {
106-
if (!isValid())
107-
return 0;
108-
return Storage.getDimension();
109-
}
96+
unsigned getDimension() const { return Storage.getDimension(); }
11097

11198
// Accessor methods
11299
const Embedding &operator[](unsigned Opcode) const {
113-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
114100
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
115101
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
116102
}
117103

118104
// Iterator access
119105
using const_iterator = ir2vec::VocabStorage::const_iterator;
120-
const_iterator begin() const {
121-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
122-
return Storage.begin();
123-
}
106+
const_iterator begin() const { return Storage.begin(); }
124107

125-
const_iterator end() const {
126-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
127-
return Storage.end();
128-
}
108+
const_iterator end() const { return Storage.end(); }
129109

130110
/// Total number of entries in the vocabulary
131-
size_t getCanonicalSize() const {
132-
assert(isValid() && "Invalid vocabulary");
133-
return Storage.size();
134-
}
111+
size_t getCanonicalSize() const { return Storage.size(); }
112+
113+
MIRVocabulary() = delete;
114+
115+
/// Factory method to create MIRVocabulary from vocabulary map
116+
static Expected<MIRVocabulary> create(VocabMap &&Entries,
117+
const TargetInstrInfo &TII);
118+
119+
private:
120+
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
135121
};
136122

137123
} // namespace mir2vec
@@ -145,7 +131,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
145131

146132
StringRef getPassName() const override;
147133
Error readVocabulary();
148-
void emitError(Error Err, LLVMContext &Ctx);
149134

150135
protected:
151136
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -156,7 +141,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
156141
public:
157142
static char ID;
158143
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
159-
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
144+
Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
160145
};
161146

162147
/// This pass prints the embeddings in the MIR2Vec vocabulary

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
4949
//===----------------------------------------------------------------------===//
5050

5151
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
52-
const TargetInstrInfo *TII)
53-
: TII(*TII) {
54-
// Fixme: Use static factory methods for creating vocabularies instead of
55-
// public constructors
56-
// Early return for invalid inputs - creates empty/invalid vocabulary
57-
if (!TII || OpcodeEntries.empty())
58-
return;
59-
52+
const TargetInstrInfo &TII)
53+
: TII(TII) {
6054
buildCanonicalOpcodeMapping();
6155

6256
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
@@ -67,6 +61,15 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
6761
Layout.TotalEntries = Storage.size();
6862
}
6963

64+
Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
65+
const TargetInstrInfo &TII) {
66+
if (Entries.empty())
67+
return createStringError(errc::invalid_argument,
68+
"Empty vocabulary entries provided");
69+
70+
return MIRVocabulary(std::move(Entries), TII);
71+
}
72+
7073
std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
7174
// Extract base instruction name using regex to capture letters and
7275
// underscores. Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
@@ -107,13 +110,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
107110
}
108111

109112
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
110-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
111113
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
112114
return getCanonicalIndexForBaseName(BaseOpcode);
113115
}
114116

115117
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
116-
assert(isValid() && "MIR2Vec Vocabulary is invalid");
117118
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
118119

119120
// For now, all entries are opcodes since we only have one section
@@ -232,16 +233,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
232233
return Error::success();
233234
}
234235

235-
void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
236-
Ctx.emitError(toString(std::move(Err)));
237-
}
238-
239-
mir2vec::MIRVocabulary
236+
Expected<mir2vec::MIRVocabulary>
240237
MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
241238
if (StrVocabMap.empty()) {
242239
if (Error Err = readVocabulary()) {
243-
emitError(std::move(Err), M.getContext());
244-
return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
240+
return std::move(Err);
245241
}
246242
}
247243

@@ -255,15 +251,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
255251

256252
if (auto *MF = MMI.getMachineFunction(F)) {
257253
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
258-
return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
254+
return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
259255
}
260256
}
261257

262-
// No machine functions available - return invalid vocabulary
263-
emitError(make_error<StringError>("No machine functions found in module",
264-
inconvertibleErrorCode()),
265-
M.getContext());
266-
return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
258+
// No machine functions available - return error
259+
return createStringError(errc::invalid_argument,
260+
"No machine functions found in module");
267261
}
268262

269263
//===----------------------------------------------------------------------===//
@@ -284,13 +278,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
284278

285279
bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
286280
auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
287-
auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
281+
auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
288282

289-
if (!MIR2VecVocab.isValid()) {
290-
OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
283+
if (!MIR2VecVocabOrErr) {
284+
OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
285+
<< toString(MIR2VecVocabOrErr.takeError()) << "\n";
291286
return false;
292287
}
293288

289+
auto &MIR2VecVocab = *MIR2VecVocabOrErr;
294290
unsigned Pos = 0;
295291
for (const auto &Entry : MIR2VecVocab) {
296292
OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
; REQUIRES: x86_64-linux
2-
; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
3-
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
4-
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
5-
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
2+
; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
3+
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
4+
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
5+
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
66

77
define dso_local void @test() {
88
entry:
99
ret void
1010
}
1111

12-
; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
13-
; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
14-
; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
15-
; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
12+
; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
13+
; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
14+
; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
15+
; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension

llvm/unittests/CodeGen/MIR2VecTest.cpp

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/IR/Module.h"
1818
#include "llvm/MC/TargetRegistry.h"
1919
#include "llvm/Support/TargetSelect.h"
20+
#include "llvm/Support/raw_ostream.h"
2021
#include "llvm/Target/TargetMachine.h"
2122
#include "llvm/Target/TargetOptions.h"
2223
#include "llvm/TargetParser/Triple.h"
@@ -52,7 +53,7 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
5253
std::unique_ptr<LLVMContext> Ctx;
5354
std::unique_ptr<Module> M;
5455
std::unique_ptr<TargetMachine> TM;
55-
const TargetInstrInfo *TII;
56+
const TargetInstrInfo *TII = nullptr;
5657

5758
static void SetUpTestCase() {
5859
InitializeAllTargets();
@@ -93,6 +94,8 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
9394
return;
9495
}
9596
}
97+
98+
void TearDown() override { TII = nullptr; }
9699
};
97100

98101
// Function to find an opcode by name
@@ -118,7 +121,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
118121
VocabMap VMap;
119122
Embedding Val = Embedding(64, 1.0f);
120123
VMap["ADD"] = Val;
121-
MIRVocabulary TestVocab(std::move(VMap), TII);
124+
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
125+
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
126+
<< "Failed to create vocabulary: "
127+
<< toString(TestVocabOrErr.takeError());
128+
auto &TestVocab = *TestVocabOrErr;
122129

123130
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
124131
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +180,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
173180
// Use a minimal MIRVocabulary to trigger canonical mapping construction
174181
VocabMap VMap;
175182
VMap["ADD"] = Embedding(64, 1.0f);
176-
MIRVocabulary TestVocab(std::move(VMap), TII);
183+
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
184+
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
185+
<< "Failed to create vocabulary: "
186+
<< toString(TestVocabOrErr.takeError());
187+
auto &TestVocab = *TestVocabOrErr;
177188

178189
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
179190
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +206,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
195206
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
196207
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
197208

198-
MIRVocabulary Vocab(std::move(VMap), TII);
199-
EXPECT_TRUE(Vocab.isValid());
209+
auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
210+
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
211+
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
212+
auto &Vocab = *VocabOrErr;
200213
EXPECT_EQ(Vocab.getDimension(), 128u);
201214

202215
// Test iterator - iterates over individual embeddings
@@ -214,4 +227,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
214227
EXPECT_GT(Count, 0u);
215228
}
216229

217-
} // namespace
230+
// Test factory method with empty vocabulary
231+
TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
232+
VocabMap EmptyVMap;
233+
234+
auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
235+
EXPECT_FALSE(static_cast<bool>(VocabOrErr))
236+
<< "Factory method should fail with empty vocabulary";
237+
238+
// Consume the error
239+
if (!VocabOrErr) {
240+
auto Err = VocabOrErr.takeError();
241+
std::string ErrorMsg = toString(std::move(Err));
242+
EXPECT_FALSE(ErrorMsg.empty());
243+
}
244+
}
245+
246+
} // namespace

0 commit comments

Comments
 (0)