Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 119 additions & 17 deletions llvm/include/llvm/CodeGen/MIR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
Expand All @@ -61,7 +63,7 @@ class MIREmbedder;
class SymbolicMIREmbedder;

extern llvm::cl::OptionCategory MIR2VecCategory;
extern cl::opt<float> OpcWeight;
extern cl::opt<float> OpcWeight, CommonOperandWeight, RegOperandWeight;

using Embedding = ir2vec::Embedding;
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
Expand All @@ -74,31 +76,114 @@ class MIRVocabulary {
friend class llvm::MIR2VecVocabLegacyAnalysis;
using VocabMap = std::map<std::string, ir2vec::Embedding>;

private:
// Define vocabulary layout - adapted for MIR
// MIRVocabulary Layout:
// +-------------------+-----------------------------------------------------+
// | Entity Type | Description |
// +-------------------+-----------------------------------------------------+
// | 1. Opcodes | Target specific opcodes derived from TII, grouped |
// | | by instruction semantics. |
// | 2. Common Operands| All common operand types, except register operands, |
// | | defined by MachineOperand::MachineOperandType enum. |
// | 3. Physical | Register classes defined by the target, specialized |
// | Reg classes | by physical registers. |
// | 4. Virtual | Register classes defined by the target, specialized |
// | Reg classes | by virtual and physical registers. |
// +-------------------+-----------------------------------------------------+

/// Layout information for the MIR vocabulary. Defines the starting index
/// and size of each section in the vocabulary.
struct {
size_t OpcodeBase = 0;
size_t OperandBase = 0;
size_t CommonOperandBase = 0;
size_t PhyRegBase = 0;
size_t VirtRegBase = 0;
size_t TotalEntries = 0;
} Layout;

enum class Section : unsigned { Opcodes = 0, MaxSections };
/// Identifies one section of the vocabulary. The enumerator values are cast to
/// unsigned and used as the section index into Storage, so their order follows
/// the MIRVocabulary layout.
enum class Section : unsigned {
// Target specific opcodes.
Opcodes = 0,
// Common operand types, excluding register operands.
CommonOperands = 1,
// Register classes, specialized for physical registers.
PhyRegisters = 2,
// Register classes, specialized for virtual registers.
VirtRegisters = 3,
// NOTE(review): presumably the number of sections rather than a real
// section - confirm against the users of this enumerator.
MaxSections
};

ir2vec::VocabStorage Storage;
mutable std::set<std::string> UniqueBaseOpcodeNames;
mutable SmallVector<std::string, 24> RegisterOperandNames;

// Some instructions have optional register operands that may be NoRegister.
// We return a zero vector in such cases.
mutable Embedding ZeroEmbedding;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds like a const, not a mutable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with making it const is that Dimension is not known until Storage is created at the end of the constructor.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so why does it need to be mutable then?


// Names of the common (non-register) operand kinds.
// We have specialized MO_Register handling in the Register operand section,
// so we don't include it here. Also, no MO_DbgInstrRef for now.
static constexpr StringLiteral CommonOperandNames[] = {
    "Immediate",      "CImmediate",        "FPImmediate", "MBB",
    "FrameIndex",     "ConstantPoolIndex", "TargetIndex", "JumpTableIndex",
    "ExternalSymbol", "GlobalAddress",     "BlockAddress", "RegisterMask",
    "RegisterLiveOut", "Metadata",         "MCSymbol",    "CFIIndex",
    "IntrinsicID",    "Predicate",         "ShuffleMask"};
// Keep the table in sync with MachineOperand::MachineOperandType. The text is
// passed as the static_assert message (not &&-ed into the condition, which is
// the idiom for runtime assert) so the compiler reports it on failure.
static_assert(std::size(CommonOperandNames) == MachineOperand::MO_Last - 1,
              "Common operand names size changed, update accordingly");

const TargetInstrInfo &TII;
void generateStorage(const VocabMap &OpcodeMap);
const TargetRegisterInfo &TRI;
const MachineRegisterInfo &MRI;

void generateStorage(const VocabMap &OpcodeMap,
const VocabMap &CommonOperandMap,
const VocabMap &PhyRegMap, const VocabMap &VirtRegMap);
void buildCanonicalOpcodeMapping();
void buildRegisterOperandMapping();

/// Get canonical index for a machine opcode
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;

/// Get index for a common (non-register) machine operand
unsigned
getCommonOperandIndex(MachineOperand::MachineOperandType OperandType) const;

/// Get index for a register machine operand
unsigned getRegisterOperandIndex(Register Reg) const;

// Accessors for operand types
/// Return the embedding stored for a common (non-register) operand type.
const Embedding &
operator[](MachineOperand::MachineOperandType OperandType) const {
  // Translate the operand type into its position within the section, then
  // read that slot of the common operands section.
  const unsigned IndexInSection = getCommonOperandIndex(OperandType);
  const unsigned SectionIdx = static_cast<unsigned>(Section::CommonOperands);
  return Storage[SectionIdx][IndexInSection];
}

/// Return the embedding stored for a register operand. Invalid registers and
/// stack slots yield the zero embedding.
const Embedding &operator[](Register Reg) const {
  // Reg is sometimes NoRegister (0) for optional operands; such operands map
  // to the zero vector.
  //
  // TODO: Implement proper stack slot handling for MIR2Vec embeddings.
  // Stack slots represent frame indices and should have their own
  // embedding strategy rather than defaulting to register class 0.
  // Consider: 1) Separate vocabulary section for stack slots
  //           2) Stack slot size/alignment based embeddings
  //           3) Frame index based categorization
  const bool UsesZeroEmbedding = !Reg.isValid() || Reg.isStack();
  if (UsesZeroEmbedding)
    return ZeroEmbedding;

  const unsigned IndexInSection = getRegisterOperandIndex(Reg);
  // Physical and virtual registers live in separate vocabulary sections.
  const Section RegSection =
      Reg.isPhysical() ? Section::PhyRegisters : Section::VirtRegisters;
  return Storage[static_cast<unsigned>(RegSection)][IndexInSection];
}

public:
/// Static method for extracting base opcode names (public for testing)
static std::string extractBaseOpcodeName(StringRef InstrName);

/// Get canonical index for base name (public for testing)
/// Get indices from opcode or operand names. These are public for testing.
/// String based lookups are inefficient and should be avoided in general.
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const;
unsigned getCanonicalIndexForRegisterClass(StringRef RegName,
bool IsPhysical = true) const;

/// Get the string key for a vocabulary entry at the given position
std::string getStringKey(unsigned Pos) const;
Expand All @@ -111,6 +196,14 @@ class MIRVocabulary {
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
}

/// Return the embedding for a machine operand. Register operands are looked
/// up through the Register accessor; every other operand is looked up by its
/// operand type.
///
/// The operand is taken by const reference so the lookup does not copy it.
const Embedding &operator[](const MachineOperand &Operand) const {
  const auto OperandType = Operand.getType();
  if (OperandType == MachineOperand::MO_Register)
    return operator[](Operand.getReg());
  return operator[](OperandType);
}

// Iterator access
using const_iterator = ir2vec::VocabStorage::const_iterator;
const_iterator begin() const { return Storage.begin(); }
Expand All @@ -120,18 +213,25 @@ class MIRVocabulary {
MIRVocabulary() = delete;

/// Factory method to create MIRVocabulary from vocabulary map
static Expected<MIRVocabulary> create(VocabMap &&Entries,
const TargetInstrInfo &TII);
static Expected<MIRVocabulary>
create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap,
VocabMap &&VirtRegMap, const TargetInstrInfo &TII,
const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI);

/// Create a dummy vocabulary for testing purposes.
static Expected<MIRVocabulary>
createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1);
createDummyVocabForTest(const TargetInstrInfo &TII,
const TargetRegisterInfo &TRI,
const MachineRegisterInfo &MRI, unsigned Dim = 1);

/// Total number of entries in the vocabulary, as reported by Storage.size().
size_t getCanonicalSize() const { return Storage.size(); }

private:
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap,
VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,
const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
const MachineRegisterInfo &MRI);
};

/// Base class for MIR embedders
Expand All @@ -144,11 +244,13 @@ class MIREmbedder {
const unsigned Dimension;

/// Weight for opcode embeddings
const float OpcWeight;
const float OpcWeight, CommonOperandWeight, RegOperandWeight;

MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
: MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(mir2vec::OpcWeight) {}
OpcWeight(mir2vec::OpcWeight),
CommonOperandWeight(mir2vec::CommonOperandWeight),
RegOperandWeight(mir2vec::RegOperandWeight) {}

/// Function to compute embeddings.
Embedding computeEmbeddings() const;
Expand Down Expand Up @@ -208,11 +310,11 @@ class SymbolicMIREmbedder : public MIREmbedder {
class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
using VocabVector = std::vector<mir2vec::Embedding>;
using VocabMap = std::map<std::string, mir2vec::Embedding>;
VocabMap StrVocabMap;
VocabVector Vocab;
std::optional<mir2vec::MIRVocabulary> Vocab;

StringRef getPassName() const override;
Error readVocabulary();
Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);

protected:
void getAnalysisUsage(AnalysisUsage &AU) const override {
Expand Down Expand Up @@ -275,4 +377,4 @@ MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);

} // namespace llvm

#endif // LLVM_CODEGEN_MIR2VEC_H
#endif // LLVM_CODEGEN_MIR2VEC_H
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Re: line +381]

Is this a spurious change, or is it fixing an existing "no newline at end of file" case?

See this comment inline on Graphite.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is fixing an existing case.

Loading