-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MIR2Vec] Handle Operands #163281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/svkeerthy/10-06-mir2vec_embedding
Are you sure you want to change the base?
[MIR2Vec] Handle Operands #163281
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,8 @@ | |
#include "llvm/CodeGen/MachineFunctionPass.h" | ||
#include "llvm/CodeGen/MachineInstr.h" | ||
#include "llvm/CodeGen/MachineModuleInfo.h" | ||
#include "llvm/CodeGen/MachineOperand.h" | ||
#include "llvm/CodeGen/MachineRegisterInfo.h" | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/Pass.h" | ||
#include "llvm/Support/CommandLine.h" | ||
|
@@ -61,7 +63,7 @@ class MIREmbedder; | |
class SymbolicMIREmbedder; | ||
|
||
extern llvm::cl::OptionCategory MIR2VecCategory; | ||
extern cl::opt<float> OpcWeight; | ||
extern cl::opt<float> OpcWeight, CommonOperandWeight, RegOperandWeight; | ||
|
||
using Embedding = ir2vec::Embedding; | ||
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>; | ||
|
@@ -74,31 +76,114 @@ class MIRVocabulary { | |
friend class llvm::MIR2VecVocabLegacyAnalysis; | ||
using VocabMap = std::map<std::string, ir2vec::Embedding>; | ||
|
||
private: | ||
// Define vocabulary layout - adapted for MIR | ||
// MIRVocabulary Layout: | ||
// +-------------------+-----------------------------------------------------+ | ||
// | Entity Type | Description | | ||
// +-------------------+-----------------------------------------------------+ | ||
// | 1. Opcodes | Target specific opcodes derived from TII, grouped | | ||
// | | by instruction semantics. | | ||
// | 2. Common Operands| All common operand types, except register operands, | | ||
// | | defined by MachineOperand::MachineOperandType enum. | | ||
// | 3. Physical | Register classes defined by the target, specialized | | ||
// | Reg classes | by physical registers. | | ||
// | 4. Virtual | Register classes defined by the target, specialized | | ||
// | Reg classes | by virtual and physical registers. | | ||
// +-------------------+-----------------------------------------------------+ | ||
|
||
/// Layout information for the MIR vocabulary. Defines the starting index | ||
/// and size of each section in the vocabulary. | ||
struct { | ||
size_t OpcodeBase = 0; | ||
size_t OperandBase = 0; | ||
size_t CommonOperandBase = 0; | ||
size_t PhyRegBase = 0; | ||
size_t VirtRegBase = 0; | ||
size_t TotalEntries = 0; | ||
} Layout; | ||
|
||
enum class Section : unsigned { Opcodes = 0, MaxSections }; | ||
enum class Section : unsigned { | ||
Opcodes = 0, | ||
CommonOperands = 1, | ||
PhyRegisters = 2, | ||
VirtRegisters = 3, | ||
MaxSections | ||
}; | ||
|
||
ir2vec::VocabStorage Storage; | ||
mutable std::set<std::string> UniqueBaseOpcodeNames; | ||
mutable SmallVector<std::string, 24> RegisterOperandNames; | ||
|
||
// Some instructions have optional register operands that may be NoRegister. | ||
// We return a zero vector in such cases. | ||
mutable Embedding ZeroEmbedding; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sounds like a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Problem with making it as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so why does it need to be mutable then? |
||
|
||
// We have specialized MO_Register handling in the Register operand section, | ||
// so we don't include it here. Also, no MO_DbgInstrRef for now. | ||
static constexpr StringLiteral CommonOperandNames[] = { | ||
"Immediate", "CImmediate", "FPImmediate", "MBB", | ||
"FrameIndex", "ConstantPoolIndex", "TargetIndex", "JumpTableIndex", | ||
"ExternalSymbol", "GlobalAddress", "BlockAddress", "RegisterMask", | ||
"RegisterLiveOut", "Metadata", "MCSymbol", "CFIIndex", | ||
"IntrinsicID", "Predicate", "ShuffleMask"}; | ||
static_assert(std::size(CommonOperandNames) == MachineOperand::MO_Last - 1 && | ||
"Common operand names size changed, update accordingly"); | ||
|
||
const TargetInstrInfo &TII; | ||
void generateStorage(const VocabMap &OpcodeMap); | ||
const TargetRegisterInfo &TRI; | ||
const MachineRegisterInfo &MRI; | ||
|
||
void generateStorage(const VocabMap &OpcodeMap, | ||
const VocabMap &CommonOperandMap, | ||
const VocabMap &PhyRegMap, const VocabMap &VirtRegMap); | ||
void buildCanonicalOpcodeMapping(); | ||
void buildRegisterOperandMapping(); | ||
|
||
/// Get canonical index for a machine opcode | ||
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const; | ||
|
||
/// Get index for a common (non-register) machine operand | ||
unsigned | ||
getCommonOperandIndex(MachineOperand::MachineOperandType OperandType) const; | ||
|
||
/// Get index for a register machine operand | ||
unsigned getRegisterOperandIndex(Register Reg) const; | ||
|
||
// Accessors for operand types | ||
const Embedding & | ||
operator[](MachineOperand::MachineOperandType OperandType) const { | ||
unsigned LocalIndex = getCommonOperandIndex(OperandType); | ||
return Storage[static_cast<unsigned>(Section::CommonOperands)][LocalIndex]; | ||
} | ||
|
||
const Embedding &operator[](Register Reg) const { | ||
// Reg is sometimes NoRegister (0) for optional operands. We return a zero | ||
// vector in this case. | ||
if (!Reg.isValid()) | ||
return ZeroEmbedding; | ||
// TODO: Implement proper stack slot handling for MIR2Vec embeddings. | ||
// Stack slots represent frame indices and should have their own | ||
// embedding strategy rather than defaulting to register class 0. | ||
// Consider: 1) Separate vocabulary section for stack slots | ||
// 2) Stack slot size/alignment based embeddings | ||
// 3) Frame index based categorization | ||
if (Reg.isStack()) | ||
return ZeroEmbedding; | ||
|
||
unsigned LocalIndex = getRegisterOperandIndex(Reg); | ||
auto SectionID = | ||
Reg.isPhysical() ? Section::PhyRegisters : Section::VirtRegisters; | ||
return Storage[static_cast<unsigned>(SectionID)][LocalIndex]; | ||
} | ||
|
||
public: | ||
/// Static method for extracting base opcode names (public for testing) | ||
static std::string extractBaseOpcodeName(StringRef InstrName); | ||
|
||
/// Get canonical index for base name (public for testing) | ||
/// Get indices from opcode or operand names. These are public for testing. | ||
/// String based lookups are inefficient and should be avoided in general. | ||
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const; | ||
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const; | ||
unsigned getCanonicalIndexForRegisterClass(StringRef RegName, | ||
bool IsPhysical = true) const; | ||
|
||
/// Get the string key for a vocabulary entry at the given position | ||
std::string getStringKey(unsigned Pos) const; | ||
|
@@ -111,6 +196,14 @@ class MIRVocabulary { | |
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex]; | ||
} | ||
|
||
const Embedding &operator[](MachineOperand Operand) const { | ||
auto OperandType = Operand.getType(); | ||
if (OperandType == MachineOperand::MO_Register) | ||
return operator[](Operand.getReg()); | ||
else | ||
return operator[](OperandType); | ||
} | ||
|
||
// Iterator access | ||
using const_iterator = ir2vec::VocabStorage::const_iterator; | ||
const_iterator begin() const { return Storage.begin(); } | ||
|
@@ -120,18 +213,25 @@ class MIRVocabulary { | |
MIRVocabulary() = delete; | ||
|
||
/// Factory method to create MIRVocabulary from vocabulary map | ||
static Expected<MIRVocabulary> create(VocabMap &&Entries, | ||
const TargetInstrInfo &TII); | ||
static Expected<MIRVocabulary> | ||
create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, | ||
VocabMap &&VirtRegMap, const TargetInstrInfo &TII, | ||
const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI); | ||
|
||
/// Create a dummy vocabulary for testing purposes. | ||
static Expected<MIRVocabulary> | ||
createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1); | ||
createDummyVocabForTest(const TargetInstrInfo &TII, | ||
const TargetRegisterInfo &TRI, | ||
const MachineRegisterInfo &MRI, unsigned Dim = 1); | ||
|
||
/// Total number of entries in the vocabulary | ||
size_t getCanonicalSize() const { return Storage.size(); } | ||
|
||
private: | ||
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII); | ||
MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, | ||
VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, | ||
const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, | ||
const MachineRegisterInfo &MRI); | ||
}; | ||
|
||
/// Base class for MIR embedders | ||
|
@@ -144,11 +244,13 @@ class MIREmbedder { | |
const unsigned Dimension; | ||
|
||
/// Weight for opcode embeddings | ||
const float OpcWeight; | ||
const float OpcWeight, CommonOperandWeight, RegOperandWeight; | ||
|
||
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab) | ||
: MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()), | ||
OpcWeight(mir2vec::OpcWeight) {} | ||
OpcWeight(mir2vec::OpcWeight), | ||
CommonOperandWeight(mir2vec::CommonOperandWeight), | ||
RegOperandWeight(mir2vec::RegOperandWeight) {} | ||
|
||
/// Function to compute embeddings. | ||
Embedding computeEmbeddings() const; | ||
|
@@ -208,11 +310,11 @@ class SymbolicMIREmbedder : public MIREmbedder { | |
class MIR2VecVocabLegacyAnalysis : public ImmutablePass { | ||
using VocabVector = std::vector<mir2vec::Embedding>; | ||
using VocabMap = std::map<std::string, mir2vec::Embedding>; | ||
VocabMap StrVocabMap; | ||
VocabVector Vocab; | ||
std::optional<mir2vec::MIRVocabulary> Vocab; | ||
|
||
StringRef getPassName() const override; | ||
Error readVocabulary(); | ||
Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab, | ||
VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap); | ||
|
||
protected: | ||
void getAnalysisUsage(AnalysisUsage &AU) const override { | ||
|
@@ -275,4 +377,4 @@ MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS); | |
|
||
} // namespace llvm | ||
|
||
#endif // LLVM_CODEGEN_MIR2VEC_H | ||
#endif // LLVM_CODEGEN_MIR2VEC_H | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Re: line +381] spurious change, or fixing an existing "no end of file newline" case? See this comment inline on Graphite. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixing existing case |
Uh oh!
There was an error while loading. Please reload this page.