Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,11 @@ class VocabStorage {
/// Section-based storage
std::vector<std::vector<Embedding>> Sections;

const size_t TotalSize;
const unsigned Dimension;
// Fixme: Check if these members can be made const (and delete move
// assignment) after changing Vocabulary creation by using static factory
// methods.
size_t TotalSize = 0;
unsigned Dimension = 0;

public:
/// Default constructor creates empty storage (invalid state)
Expand All @@ -164,7 +167,7 @@ class VocabStorage {
VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);

VocabStorage(VocabStorage &&) = default;
VocabStorage &operator=(VocabStorage &&) = delete;
VocabStorage &operator=(VocabStorage &&) = default;

VocabStorage(const VocabStorage &) = delete;
VocabStorage &operator=(const VocabStorage &) = delete;
Expand Down
181 changes: 181 additions & 0 deletions llvm/include/llvm/CodeGen/MIR2Vec.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
//===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the MIR2Vec vocabulary
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
/// for generating Machine IR embeddings, and related utilities.
///
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
/// LLVM Machine IR as embeddings which can be used as input to machine learning
/// algorithms.
///
/// The original idea of MIR2Vec is described in the following paper:
///
/// RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
/// Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
/// Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
/// Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
/// Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
/// https://arxiv.org/abs/2204.02013
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_CODEGEN_MIR2VEC_H
#define LLVM_CODEGEN_MIR2VEC_H

#include "llvm/Analysis/IR2Vec.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include <map>
#include <set>
#include <string>

namespace llvm {

class Module;
class raw_ostream;
class LLVMContext;
class MIR2VecVocabLegacyAnalysis;
class TargetInstrInfo;

namespace mir2vec {
extern llvm::cl::OptionCategory MIR2VecCategory;
extern cl::opt<float> OpcWeight;

using Embedding = ir2vec::Embedding;

/// Class for storing and accessing the MIR2Vec vocabulary.
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
class MIRVocabulary {
friend class llvm::MIR2VecVocabLegacyAnalysis;
using VocabMap = std::map<std::string, ir2vec::Embedding>;

private:
// Define vocabulary layout - adapted for MIR
struct {
size_t OpcodeBase = 0;
size_t OperandBase = 0;
size_t TotalEntries = 0;
} Layout;

ir2vec::VocabStorage Storage;
mutable std::set<std::string> UniqueBaseOpcodeNames;
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);

public:
/// Static helper method for extracting base opcode names (public for testing)
static std::string extractBaseOpcodeName(StringRef InstrName);

/// Helper method for getting canonical index for base name (public for
/// testing)
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;

/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;

MIRVocabulary() = default;
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}

bool isValid() const {
return UniqueBaseOpcodeNames.size() > 0 &&
Layout.TotalEntries == Storage.size() && Storage.isValid();
}

unsigned getDimension() const {
if (!isValid())
return 0;
return Storage.getDimension();
}

// Accessor methods
const Embedding &operator[](unsigned Index) const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
assert(Index < Layout.TotalEntries && "Index out of bounds");
// Fixme: For now, use section 0 for all entries
return Storage[0][Index];
}

// Iterator access
using const_iterator = ir2vec::VocabStorage::const_iterator;
const_iterator begin() const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.begin();
}

const_iterator end() const {
assert(isValid() && "MIR2Vec Vocabulary is invalid");
return Storage.end();
}

/// Total number of entries in the vocabulary
size_t getCanonicalSize() const {
assert(isValid() && "Invalid vocabulary");
return Storage.size();
}
};

} // namespace mir2vec

/// Pass to analyze and populate MIR2Vec vocabulary from a module
class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
using VocabVector = std::vector<mir2vec::Embedding>;
using VocabMap = std::map<std::string, mir2vec::Embedding>;
VocabMap StrVocabMap;
VocabVector Vocab;

StringRef getPassName() const override;
Error readVocabulary();
void emitError(Error Err, LLVMContext &Ctx);

protected:
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MachineModuleInfoWrapperPass>();
AU.setPreservesAll();
}

public:
static char ID;
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
};

/// This pass prints the embeddings in the MIR2Vec vocabulary
class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
raw_ostream &OS;

public:
static char ID;
explicit MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
: MachineFunctionPass(ID), OS(OS) {}

bool runOnMachineFunction(MachineFunction &MF) override;
bool doFinalization(Module &M) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<MIR2VecVocabLegacyAnalysis>();
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}

StringRef getPassName() const override {
return "MIR2Vec Vocabulary Printer Pass";
}
};

} // namespace llvm

#endif // LLVM_CODEGEN_MIR2VEC_H
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ LLVM_ABI MachineFunctionPass *
createMachineFunctionPrinterPass(raw_ostream &OS,
const std::string &Banner = "");

/// MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary
/// contents to the given stream as a debugging tool.
LLVM_ABI MachineFunctionPass *
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);

/// StackFramePrinter pass - This pass prints out the machine function's
/// stack frame to the given stream as a debugging tool.
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ LLVM_ABI void initializeMachinePostDominatorTreeWrapperPassPass(PassRegistry &);
LLVM_ABI void initializeMachineRegionInfoPassPass(PassRegistry &);
LLVM_ABI void
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
Expand Down
Loading