-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MIR2Vec] Added create factory methods for Vocabulary #162569
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

Added factory methods for vocabulary creation. This should also fix the UB issue introduced by #161713.

Full diff: https://github.com/llvm/llvm-project/pull/162569.diff

4 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
index ea68b4594a2ad..dbffede50df81 100644
--- a/llvm/include/llvm/CodeGen/MIR2Vec.h
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -38,6 +38,7 @@
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorOr.h"
#include <map>
#include <set>
@@ -92,25 +93,12 @@ class MIRVocabulary {
/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;
- MIRVocabulary() = delete;
- MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
- MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
- : Storage(std::move(Storage)), TII(TII) {}
-
- bool isValid() const {
- return UniqueBaseOpcodeNames.size() > 0 &&
- Layout.TotalEntries == Storage.size() && Storage.isValid();
- }
-
unsigned getDimension() const {
- if (!isValid())
- return 0;
return Storage.getDimension();
}
// Accessor methods
const Embedding &operator[](unsigned Opcode) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
}
@@ -118,20 +106,30 @@ class MIRVocabulary {
// Iterator access
using const_iterator = ir2vec::VocabStorage::const_iterator;
const_iterator begin() const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.begin();
}
const_iterator end() const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.end();
}
/// Total number of entries in the vocabulary
size_t getCanonicalSize() const {
- assert(isValid() && "Invalid vocabulary");
return Storage.size();
}
+
+ MIRVocabulary() = delete;
+
+ /// Factory method to create MIRVocabulary from vocabulary map
+ static Expected<MIRVocabulary> create(VocabMap &&Entries, const TargetInstrInfo &TII);
+
+ /// Factory method to create MIRVocabulary from existing storage
+ static Expected<MIRVocabulary> create(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII);
+
+private:
+ MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
+ MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
+ : Storage(std::move(Storage)), TII(TII) {}
};
} // namespace mir2vec
@@ -145,7 +143,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
StringRef getPassName() const override;
Error readVocabulary();
- void emitError(Error Err, LLVMContext &Ctx);
protected:
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -156,7 +153,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
public:
static char ID;
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
- mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
+ Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
};
/// This pass prints the embeddings in the MIR2Vec vocabulary
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 87565c0c77115..669c11d5f739c 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
//===----------------------------------------------------------------------===//
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
- const TargetInstrInfo *TII)
- : TII(*TII) {
- // Fixme: Use static factory methods for creating vocabularies instead of
- // public constructors
- // Early return for invalid inputs - creates empty/invalid vocabulary
- if (!TII || OpcodeEntries.empty())
- return;
-
+ const TargetInstrInfo &TII)
+ : TII(TII) {
buildCanonicalOpcodeMapping();
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
@@ -67,6 +61,24 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
Layout.TotalEntries = Storage.size();
}
+Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
+ const TargetInstrInfo &TII) {
+ if (Entries.empty())
+ return createStringError(errc::invalid_argument,
+ "Empty vocabulary entries provided");
+
+ return MIRVocabulary(std::move(Entries), TII);
+}
+
+Expected<MIRVocabulary> MIRVocabulary::create(ir2vec::VocabStorage &&Storage,
+ const TargetInstrInfo &TII) {
+ if (!Storage.isValid())
+ return createStringError(errc::invalid_argument,
+ "Invalid vocabulary storage provided");
+
+ return MIRVocabulary(std::move(Storage), TII);
+}
+
std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
// Extract base instruction name using regex to capture letters and
// underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
@@ -107,13 +119,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
}
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
return getCanonicalIndexForBaseName(BaseOpcode);
}
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
- assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
// For now, all entries are opcodes since we only have one section
@@ -232,16 +242,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
return Error::success();
}
-void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
- Ctx.emitError(toString(std::move(Err)));
-}
-
-mir2vec::MIRVocabulary
+Expected<mir2vec::MIRVocabulary>
MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
if (StrVocabMap.empty()) {
if (Error Err = readVocabulary()) {
- emitError(std::move(Err), M.getContext());
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+ return std::move(Err);
}
}
@@ -255,15 +260,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
if (auto *MF = MMI.getMachineFunction(F)) {
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
+ return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
}
}
- // No machine functions available - return invalid vocabulary
- emitError(make_error<StringError>("No machine functions found in module",
- inconvertibleErrorCode()),
- M.getContext());
- return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
+ // No machine functions available - return error
+ return createStringError(errc::invalid_argument,
+ "No machine functions found in module");
}
//===----------------------------------------------------------------------===//
@@ -284,13 +287,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
- auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
+ auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
- if (!MIR2VecVocab.isValid()) {
- OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
+ if (!MIR2VecVocabOrErr) {
+ OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
+ << toString(MIR2VecVocabOrErr.takeError()) << "\n";
return false;
}
+ auto &MIR2VecVocab = *MIR2VecVocabOrErr;
unsigned Pos = 0;
for (const auto &Entry : MIR2VecVocab) {
OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
diff --git a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
index 1da516a6cd3b9..80b4048cea0c3 100644
--- a/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
+++ b/llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
@@ -1,15 +1,15 @@
; REQUIRES: x86_64-linux
-; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
-; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
+; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
+; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
define dso_local void @test() {
entry:
ret void
}
-; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
-; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
-; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
-; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
+; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
+; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
+; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
+; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index d243d82c73fc7..269e3b515c6fc 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
@@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
VMap["ADD"] = Val;
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VMap;
VMap["ADD"] = Embedding(64, 1.0f);
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
- MIRVocabulary Vocab(std::move(VMap), TII);
- EXPECT_TRUE(Vocab.isValid());
+ auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
EXPECT_EQ(Vocab.getDimension(), 128u);
// Test iterator - iterates over individual embeddings
@@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
EXPECT_GT(Count, 0u);
}
+// Test factory method with empty vocabulary
+TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
+ VocabMap EmptyVMap;
+
+ auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
+ EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+ << "Factory method should fail with empty vocabulary";
+
+ // Consume the error
+ if (!VocabOrErr) {
+ auto Err = VocabOrErr.takeError();
+ std::string ErrorMsg = toString(std::move(Err));
+ EXPECT_FALSE(ErrorMsg.empty());
+ }
+}
+
} // namespace
\ No newline at end of file
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds factory methods for MIRVocabulary creation and improves error handling by replacing direct constructor usage with factory methods that return Expected. It also fixes undefined behavior issues introduced by a previous PR.
- Factory methods (`MIRVocabulary::create()`) replace direct constructor calls.
- Error handling is improved: proper `Expected<>` return types are used instead of creating invalid objects.
- Test output format is updated to reflect new error handling patterns
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
llvm/include/llvm/CodeGen/MIR2Vec.h | Added factory method declarations and made constructors private |
llvm/lib/CodeGen/MIR2Vec.cpp | Implemented factory methods and updated error handling logic |
llvm/unittests/CodeGen/MIR2VecTest.cpp | Updated tests to use factory methods and added error handling validation |
llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll | Updated test expectations to match new error message format |
✅ With the latest revision this PR passed the C/C++ code formatter. |
std::unique_ptr<LLVMContext> Ctx; | ||
std::unique_ptr<Module> M; | ||
std::unique_ptr<TargetMachine> TM; | ||
const TargetInstrInfo *TII; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Initialize `TII` to `nullptr`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also add a `TearDown()` that resets it to `nullptr`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `SetUp()` method either assigns a valid value to `TII` or skips the test entirely. Also, `TII` is destroyed automatically along with `TM` after each test, right? I'm trying to understand whether explicitly setting it to `nullptr` makes a difference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It just keeps the state simpler to track and maintain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after addressing the comments.
Added factory methods for vocabulary creation. This should also fix the UB issue introduced by #161713.
Added factory methods for vocabulary creation. This should also fix the UB issue introduced by llvm#161713.
Added factory methods for vocabulary creation. This should also fix the UB issue introduced by llvm#161713.
Added factory methods for vocabulary creation. This should also fix the UB issue introduced by llvm#161713.
Added factory methods for vocabulary creation. This should also fix the UB issue introduced by #161713.