Skip to content

Commit 9f24293

Browse files
svkeerthyaokblast
authored andcommitted
[MIR2Vec] Handle Operands (llvm#163281)
Handling opcodes in embedding computation. - Revamped MIR Vocabulary with four sections - `Opcodes`, `Common Operands`, `Physical Registers`, and `Virtual Registers` - Operands broadly fall into 3 categories -- the generic MO types that are common across architectures, physical and virtual register classes. We handle these categories separately in MIR2Vec. (Though we have same classes for both physical and virtual registers, their embeddings vary).
1 parent 3dc6fc2 commit 9f24293

File tree

13 files changed

+1422
-166
lines changed

13 files changed

+1422
-166
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 120 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "llvm/CodeGen/MachineFunctionPass.h"
3636
#include "llvm/CodeGen/MachineInstr.h"
3737
#include "llvm/CodeGen/MachineModuleInfo.h"
38+
#include "llvm/CodeGen/MachineOperand.h"
39+
#include "llvm/CodeGen/MachineRegisterInfo.h"
3840
#include "llvm/IR/PassManager.h"
3941
#include "llvm/Pass.h"
4042
#include "llvm/Support/CommandLine.h"
@@ -61,7 +63,7 @@ class MIREmbedder;
6163
class SymbolicMIREmbedder;
6264

6365
extern llvm::cl::OptionCategory MIR2VecCategory;
64-
extern cl::opt<float> OpcWeight;
66+
extern cl::opt<float> OpcWeight, CommonOperandWeight, RegOperandWeight;
6567

6668
using Embedding = ir2vec::Embedding;
6769
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
@@ -74,31 +76,114 @@ class MIRVocabulary {
7476
friend class llvm::MIR2VecVocabLegacyAnalysis;
7577
using VocabMap = std::map<std::string, ir2vec::Embedding>;
7678

77-
private:
78-
// Define vocabulary layout - adapted for MIR
79+
// MIRVocabulary Layout:
80+
// +-------------------+-----------------------------------------------------+
81+
// | Entity Type | Description |
82+
// +-------------------+-----------------------------------------------------+
83+
// | 1. Opcodes | Target specific opcodes derived from TII, grouped |
84+
// | | by instruction semantics. |
85+
// | 2. Common Operands| All common operand types, except register operands, |
86+
// | | defined by MachineOperand::MachineOperandType enum. |
87+
// | 3. Physical | Register classes defined by the target, specialized |
88+
// | Reg classes | by physical registers. |
89+
// | 4. Virtual | Register classes defined by the target, specialized |
90+
// | Reg classes | by virtual and physical registers. |
91+
// +-------------------+-----------------------------------------------------+
92+
93+
/// Layout information for the MIR vocabulary. Defines the starting index
94+
/// and size of each section in the vocabulary.
7995
struct {
8096
size_t OpcodeBase = 0;
81-
size_t OperandBase = 0;
97+
size_t CommonOperandBase = 0;
98+
size_t PhyRegBase = 0;
99+
size_t VirtRegBase = 0;
82100
size_t TotalEntries = 0;
83101
} Layout;
84102

85-
enum class Section : unsigned { Opcodes = 0, MaxSections };
103+
enum class Section : unsigned {
104+
Opcodes = 0,
105+
CommonOperands = 1,
106+
PhyRegisters = 2,
107+
VirtRegisters = 3,
108+
MaxSections
109+
};
86110

87111
ir2vec::VocabStorage Storage;
88-
mutable std::set<std::string> UniqueBaseOpcodeNames;
112+
std::set<std::string> UniqueBaseOpcodeNames;
113+
SmallVector<std::string, 24> RegisterOperandNames;
114+
115+
// Some instructions have optional register operands that may be NoRegister.
116+
// We return a zero vector in such cases.
117+
Embedding ZeroEmbedding;
118+
119+
// We have specialized MO_Register handling in the Register operand section,
120+
// so we don't include it here. Also, no MO_DbgInstrRef for now.
121+
static constexpr StringLiteral CommonOperandNames[] = {
122+
"Immediate", "CImmediate", "FPImmediate", "MBB",
123+
"FrameIndex", "ConstantPoolIndex", "TargetIndex", "JumpTableIndex",
124+
"ExternalSymbol", "GlobalAddress", "BlockAddress", "RegisterMask",
125+
"RegisterLiveOut", "Metadata", "MCSymbol", "CFIIndex",
126+
"IntrinsicID", "Predicate", "ShuffleMask"};
127+
static_assert(std::size(CommonOperandNames) == MachineOperand::MO_Last - 1 &&
128+
"Common operand names size changed, update accordingly");
129+
89130
const TargetInstrInfo &TII;
90-
void generateStorage(const VocabMap &OpcodeMap);
131+
const TargetRegisterInfo &TRI;
132+
const MachineRegisterInfo &MRI;
133+
134+
void generateStorage(const VocabMap &OpcodeMap,
135+
const VocabMap &CommonOperandMap,
136+
const VocabMap &PhyRegMap, const VocabMap &VirtRegMap);
91137
void buildCanonicalOpcodeMapping();
138+
void buildRegisterOperandMapping();
92139

93140
/// Get canonical index for a machine opcode
94141
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
95142

143+
/// Get index for a common (non-register) machine operand
144+
unsigned
145+
getCommonOperandIndex(MachineOperand::MachineOperandType OperandType) const;
146+
147+
/// Get index for a register machine operand
148+
unsigned getRegisterOperandIndex(Register Reg) const;
149+
150+
// Accessors for operand types
151+
const Embedding &
152+
operator[](MachineOperand::MachineOperandType OperandType) const {
153+
unsigned LocalIndex = getCommonOperandIndex(OperandType);
154+
return Storage[static_cast<unsigned>(Section::CommonOperands)][LocalIndex];
155+
}
156+
157+
const Embedding &operator[](Register Reg) const {
158+
// Reg is sometimes NoRegister (0) for optional operands. We return a zero
159+
// vector in this case.
160+
if (!Reg.isValid())
161+
return ZeroEmbedding;
162+
// TODO: Implement proper stack slot handling for MIR2Vec embeddings.
163+
// Stack slots represent frame indices and should have their own
164+
// embedding strategy rather than defaulting to register class 0.
165+
// Consider: 1) Separate vocabulary section for stack slots
166+
// 2) Stack slot size/alignment based embeddings
167+
// 3) Frame index based categorization
168+
if (Reg.isStack())
169+
return ZeroEmbedding;
170+
171+
unsigned LocalIndex = getRegisterOperandIndex(Reg);
172+
auto SectionID =
173+
Reg.isPhysical() ? Section::PhyRegisters : Section::VirtRegisters;
174+
return Storage[static_cast<unsigned>(SectionID)][LocalIndex];
175+
}
176+
96177
public:
97178
/// Static method for extracting base opcode names (public for testing)
98179
static std::string extractBaseOpcodeName(StringRef InstrName);
99180

100-
/// Get canonical index for base name (public for testing)
181+
/// Get indices from opcode or operand names. These are public for testing.
182+
/// String based lookups are inefficient and should be avoided in general.
101183
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
184+
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const;
185+
unsigned getCanonicalIndexForRegisterClass(StringRef RegName,
186+
bool IsPhysical = true) const;
102187

103188
/// Get the string key for a vocabulary entry at the given position
104189
std::string getStringKey(unsigned Pos) const;
@@ -111,6 +196,14 @@ class MIRVocabulary {
111196
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
112197
}
113198

199+
const Embedding &operator[](MachineOperand Operand) const {
200+
auto OperandType = Operand.getType();
201+
if (OperandType == MachineOperand::MO_Register)
202+
return operator[](Operand.getReg());
203+
else
204+
return operator[](OperandType);
205+
}
206+
114207
// Iterator access
115208
using const_iterator = ir2vec::VocabStorage::const_iterator;
116209
const_iterator begin() const { return Storage.begin(); }
@@ -120,18 +213,25 @@ class MIRVocabulary {
120213
MIRVocabulary() = delete;
121214

122215
/// Factory method to create MIRVocabulary from vocabulary map
123-
static Expected<MIRVocabulary> create(VocabMap &&Entries,
124-
const TargetInstrInfo &TII);
216+
static Expected<MIRVocabulary>
217+
create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap,
218+
VocabMap &&VirtRegMap, const TargetInstrInfo &TII,
219+
const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI);
125220

126221
/// Create a dummy vocabulary for testing purposes.
127222
static Expected<MIRVocabulary>
128-
createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1);
223+
createDummyVocabForTest(const TargetInstrInfo &TII,
224+
const TargetRegisterInfo &TRI,
225+
const MachineRegisterInfo &MRI, unsigned Dim = 1);
129226

130227
/// Total number of entries in the vocabulary
131228
size_t getCanonicalSize() const { return Storage.size(); }
132229

133230
private:
134-
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
231+
MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap,
232+
VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,
233+
const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
234+
const MachineRegisterInfo &MRI);
135235
};
136236

137237
/// Base class for MIR embedders
@@ -144,11 +244,13 @@ class MIREmbedder {
144244
const unsigned Dimension;
145245

146246
/// Weight for opcode embeddings
147-
const float OpcWeight;
247+
const float OpcWeight, CommonOperandWeight, RegOperandWeight;
148248

149249
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
150250
: MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
151-
OpcWeight(mir2vec::OpcWeight) {}
251+
OpcWeight(mir2vec::OpcWeight),
252+
CommonOperandWeight(mir2vec::CommonOperandWeight),
253+
RegOperandWeight(mir2vec::RegOperandWeight) {}
152254

153255
/// Function to compute embeddings.
154256
Embedding computeEmbeddings() const;
@@ -208,11 +310,11 @@ class SymbolicMIREmbedder : public MIREmbedder {
208310
class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
209311
using VocabVector = std::vector<mir2vec::Embedding>;
210312
using VocabMap = std::map<std::string, mir2vec::Embedding>;
211-
VocabMap StrVocabMap;
212-
VocabVector Vocab;
313+
std::optional<mir2vec::MIRVocabulary> Vocab;
213314

214315
StringRef getPassName() const override;
215-
Error readVocabulary();
316+
Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
317+
VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);
216318

217319
protected:
218320
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -275,4 +377,4 @@ MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
275377

276378
} // namespace llvm
277379

278-
#endif // LLVM_CODEGEN_MIR2VEC_H
380+
#endif // LLVM_CODEGEN_MIR2VEC_H

0 commit comments

Comments
 (0)