Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 15 additions & 30 deletions llvm/include/llvm/CodeGen/MIR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorOr.h"
#include <map>
#include <set>
Expand Down Expand Up @@ -92,46 +93,31 @@ class MIRVocabulary {
/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;

MIRVocabulary() = delete;
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
: Storage(std::move(Storage)), TII(TII) {}

bool isValid() const {
return UniqueBaseOpcodeNames.size() > 0 &&
Layout.TotalEntries == Storage.size() && Storage.isValid();
}

unsigned getDimension() const {
if (!isValid())
return 0;
return Storage.getDimension();
}
unsigned getDimension() const { return Storage.getDimension(); }

// Accessor methods
const Embedding &operator[](unsigned Opcode) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
}

// Iterator access
using const_iterator = ir2vec::VocabStorage::const_iterator;
const_iterator begin() const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.begin();
}
const_iterator begin() const { return Storage.begin(); }

const_iterator end() const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.end();
}
const_iterator end() const { return Storage.end(); }

/// Total number of entries in the vocabulary
size_t getCanonicalSize() const {
assert(isValid() && "Invalid vocabulary");
return Storage.size();
}
size_t getCanonicalSize() const { return Storage.size(); }

MIRVocabulary() = delete;

/// Factory method to create MIRVocabulary from vocabulary map
static Expected<MIRVocabulary> create(VocabMap &&Entries,
const TargetInstrInfo &TII);

private:
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
};

} // namespace mir2vec
Expand All @@ -145,7 +131,6 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {

StringRef getPassName() const override;
Error readVocabulary();
void emitError(Error Err, LLVMContext &Ctx);

protected:
void getAnalysisUsage(AnalysisUsage &AU) const override {
Expand All @@ -156,7 +141,7 @@ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
public:
static char ID;
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
Expected<mir2vec::MIRVocabulary> getMIR2VecVocabulary(const Module &M);
};

/// This pass prints the embeddings in the MIR2Vec vocabulary
Expand Down
48 changes: 22 additions & 26 deletions llvm/lib/CodeGen/MIR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,8 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
//===----------------------------------------------------------------------===//

MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
const TargetInstrInfo *TII)
: TII(*TII) {
// Fixme: Use static factory methods for creating vocabularies instead of
// public constructors
// Early return for invalid inputs - creates empty/invalid vocabulary
if (!TII || OpcodeEntries.empty())
return;

const TargetInstrInfo &TII)
: TII(TII) {
buildCanonicalOpcodeMapping();

unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
Expand All @@ -67,6 +61,15 @@ MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
Layout.TotalEntries = Storage.size();
}

/// Factory entry point for MIRVocabulary.
///
/// \param Entries vocabulary map that is moved into the new vocabulary.
/// \param TII     target instruction info forwarded to the constructor.
/// \returns the constructed vocabulary, or an invalid_argument string error
///          when \p Entries holds no elements.
Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
                                              const TargetInstrInfo &TII) {
  const bool HasEntries = !Entries.empty();
  if (HasEntries)
    return MIRVocabulary(std::move(Entries), TII);

  // Nothing to build a vocabulary from: report it instead of constructing.
  return createStringError(errc::invalid_argument,
                           "Empty vocabulary entries provided");
}

std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
// Extract base instruction name using regex to capture letters and
// underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
Expand Down Expand Up @@ -107,13 +110,11 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
}

/// Translate a target opcode into its canonical vocabulary index.
///
/// The opcode's name is obtained from TII, reduced to its base opcode name
/// via extractBaseOpcodeName(), and that base name is then looked up with
/// getCanonicalIndexForBaseName().
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
return getCanonicalIndexForBaseName(BaseOpcode);
}

std::string MIRVocabulary::getStringKey(unsigned Pos) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");

// For now, all entries are opcodes since we only have one section
Expand Down Expand Up @@ -232,16 +233,11 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
return Error::success();
}

/// Report \p Err through \p Ctx: the error is consumed by converting it to
/// its string form, which is then passed to Ctx.emitError().
void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
Ctx.emitError(toString(std::move(Err)));
}

mir2vec::MIRVocabulary
Expected<mir2vec::MIRVocabulary>
MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
if (StrVocabMap.empty()) {
if (Error Err = readVocabulary()) {
emitError(std::move(Err), M.getContext());
return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
return std::move(Err);
}
}

Expand All @@ -255,15 +251,13 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {

if (auto *MF = MMI.getMachineFunction(F)) {
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII);
}
}

// No machine functions available - return invalid vocabulary
emitError(make_error<StringError>("No machine functions found in module",
inconvertibleErrorCode()),
M.getContext());
return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
// No machine functions available - return error
return createStringError(errc::invalid_argument,
"No machine functions found in module");
}

//===----------------------------------------------------------------------===//
Expand All @@ -284,13 +278,15 @@ bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {

bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);

if (!MIR2VecVocab.isValid()) {
OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
if (!MIR2VecVocabOrErr) {
OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
<< toString(MIR2VecVocabOrErr.takeError()) << "\n";
return false;
}

auto &MIR2VecVocab = *MIR2VecVocabOrErr;
unsigned Pos = 0;
for (const auto &Entry : MIR2VecVocab) {
OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
Expand Down
16 changes: 8 additions & 8 deletions llvm/test/CodeGen/MIR2Vec/vocab-error-handling.ll
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
; REQUIRES: x86_64-linux
; RUN: not llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
; RUN: not llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS
; RUN: llc -o /dev/null -print-mir2vec-vocab %s 2>&1 | FileCheck %s --check-prefix=CHECK-INVALID
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_zero_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-ZERO-DIM
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_invalid_vocab.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ENTITIES
; RUN: llc -o /dev/null -print-mir2vec-vocab -mir2vec-vocab-path=%S/Inputs/mir2vec_inconsistent_dims.json %s 2>&1 | FileCheck %s --check-prefix=CHECK-INCONSISTENT-DIMS

define dso_local void @test() {
entry:
ret void
}

; CHECK-INVALID: error: MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
; CHECK-ZERO-DIM: error: Dimension of 'entities' section of the vocabulary is zero
; CHECK-NO-ENTITIES: error: Missing 'entities' section in vocabulary file
; CHECK-INCONSISTENT-DIMS: error: All vectors in the 'entities' section of the vocabulary are not of the same dimension
; CHECK-INVALID: MIR2Vec Vocabulary Printer: Failed to get vocabulary - MIR2Vec vocabulary file path not specified; set it using --mir2vec-vocab-path
; CHECK-ZERO-DIM: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Dimension of 'entities' section of the vocabulary is zero
; CHECK-NO-ENTITIES: MIR2Vec Vocabulary Printer: Failed to get vocabulary - Missing 'entities' section in vocabulary file
; CHECK-INCONSISTENT-DIMS: MIR2Vec Vocabulary Printer: Failed to get vocabulary - All vectors in the 'entities' section of the vocabulary are not of the same dimension
37 changes: 32 additions & 5 deletions llvm/unittests/CodeGen/MIR2VecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
Expand Down Expand Up @@ -118,7 +119,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
VMap["ADD"] = Val;
MIRVocabulary TestVocab(std::move(VMap), TII);
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
auto &TestVocab = *TestVocabOrErr;

unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
Expand Down Expand Up @@ -173,7 +178,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VMap;
VMap["ADD"] = Embedding(64, 1.0f);
MIRVocabulary TestVocab(std::move(VMap), TII);
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
auto &TestVocab = *TestVocabOrErr;

unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
Expand All @@ -195,8 +204,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0

MIRVocabulary Vocab(std::move(VMap), TII);
EXPECT_TRUE(Vocab.isValid());
auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
EXPECT_EQ(Vocab.getDimension(), 128u);

// Test iterator - iterates over individual embeddings
Expand All @@ -214,4 +225,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
EXPECT_GT(Count, 0u);
}

} // namespace
// The factory must reject an empty vocabulary map and report a non-empty
// error message.
TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
  VocabMap NoEntries;
  auto Result = MIRVocabulary::create(std::move(NoEntries), *TII);

  const bool Created = static_cast<bool>(Result);
  EXPECT_FALSE(Created) << "Factory method should fail with empty vocabulary";
  if (Created)
    return;

  // Take the error out of the Expected so it is consumed, then inspect its
  // text.
  const std::string Message = toString(Result.takeError());
  EXPECT_FALSE(Message.empty());
}

} // namespace