Skip to content

Commit 3c77b49

Browse files
authored
[MIR2Vec] Add embedder for machine instructions (#162161)
Implement MIR2Vec embedder for generating vector representations of Machine IR instructions, basic blocks, and functions. This patch introduces changes necessary to *embed* machine opcodes. Machine operands would be handled incrementally in the upcoming patches.
1 parent 7287816 commit 3c77b49

File tree

11 files changed

+808
-39
lines changed

11 files changed

+808
-39
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,21 @@ class LLVMContext;
5252
class MIR2VecVocabLegacyAnalysis;
5353
class TargetInstrInfo;
5454

55+
enum class MIR2VecKind { Symbolic };
56+
5557
namespace mir2vec {
58+
59+
// Forward declarations
60+
class MIREmbedder;
61+
class SymbolicMIREmbedder;
62+
5663
extern llvm::cl::OptionCategory MIR2VecCategory;
5764
extern cl::opt<float> OpcWeight;
5865

5966
using Embedding = ir2vec::Embedding;
67+
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
68+
using MachineBlockEmbeddingsMap =
69+
DenseMap<const MachineBasicBlock *, Embedding>;
6070

6171
/// Class for storing and accessing the MIR2Vec vocabulary.
6272
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
@@ -107,19 +117,91 @@ class MIRVocabulary {
107117

108118
const_iterator end() const { return Storage.end(); }
109119

110-
/// Total number of entries in the vocabulary
111-
size_t getCanonicalSize() const { return Storage.size(); }
112-
113120
MIRVocabulary() = delete;
114121

115122
/// Factory method to create MIRVocabulary from vocabulary map
116123
static Expected<MIRVocabulary> create(VocabMap &&Entries,
117124
const TargetInstrInfo &TII);
118125

126+
/// Create a dummy vocabulary for testing purposes.
127+
static Expected<MIRVocabulary>
128+
createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1);
129+
130+
/// Total number of entries in the vocabulary
131+
size_t getCanonicalSize() const { return Storage.size(); }
132+
119133
private:
120134
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
121135
};
122136

137+
/// Base class for MIR embedders
138+
class MIREmbedder {
139+
protected:
140+
const MachineFunction &MF;
141+
const MIRVocabulary &Vocab;
142+
143+
/// Dimension of the embeddings; Captured from the vocabulary
144+
const unsigned Dimension;
145+
146+
/// Weight for opcode embeddings
147+
const float OpcWeight;
148+
149+
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
150+
: MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
151+
OpcWeight(mir2vec::OpcWeight) {}
152+
153+
/// Function to compute embeddings.
154+
Embedding computeEmbeddings() const;
155+
156+
/// Function to compute the embedding for a given machine basic block.
157+
Embedding computeEmbeddings(const MachineBasicBlock &MBB) const;
158+
159+
/// Function to compute the embedding for a given machine instruction.
160+
/// Specific to the kind of embeddings being computed.
161+
virtual Embedding computeEmbeddings(const MachineInstr &MI) const = 0;
162+
163+
public:
164+
virtual ~MIREmbedder() = default;
165+
166+
/// Factory method to create an Embedder object of the specified kind
167+
/// Returns nullptr if the requested kind is not supported.
168+
static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
169+
const MachineFunction &MF,
170+
const MIRVocabulary &Vocab);
171+
172+
/// Computes and returns the embedding for a given machine instruction MI in
173+
/// the machine function MF.
174+
Embedding getMInstVector(const MachineInstr &MI) const {
175+
return computeEmbeddings(MI);
176+
}
177+
178+
/// Computes and returns the embedding for a given machine basic block in the
179+
/// machine function MF.
180+
Embedding getMBBVector(const MachineBasicBlock &MBB) const {
181+
return computeEmbeddings(MBB);
182+
}
183+
184+
/// Computes and returns the embedding for the current machine function.
185+
Embedding getMFunctionVector() const {
186+
// Currently, we always (re)compute the embeddings for the function. This is
187+
// cheaper than caching the vector.
188+
return computeEmbeddings();
189+
}
190+
};
191+
192+
/// Class for computing Symbolic embeddings
193+
/// Symbolic embeddings are constructed based on the entity-level
194+
/// representations obtained from the MIR Vocabulary.
195+
class SymbolicMIREmbedder : public MIREmbedder {
196+
private:
197+
Embedding computeEmbeddings(const MachineInstr &MI) const override;
198+
199+
public:
200+
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab);
201+
static std::unique_ptr<SymbolicMIREmbedder>
202+
create(const MachineFunction &MF, const MIRVocabulary &Vocab);
203+
};
204+
123205
} // namespace mir2vec
124206

125207
/// Pass to analyze and populate MIR2Vec vocabulary from a module
@@ -166,6 +248,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
166248
}
167249
};
168250

251+
/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
252+
/// and instructions
253+
class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
254+
raw_ostream &OS;
255+
256+
public:
257+
static char ID;
258+
explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
259+
: MachineFunctionPass(ID), OS(OS) {}
260+
261+
bool runOnMachineFunction(MachineFunction &MF) override;
262+
void getAnalysisUsage(AnalysisUsage &AU) const override {
263+
AU.addRequired<MIR2VecVocabLegacyAnalysis>();
264+
AU.setPreservesAll();
265+
MachineFunctionPass::getAnalysisUsage(AU);
266+
}
267+
268+
StringRef getPassName() const override {
269+
return "MIR2Vec Embedder Printer Pass";
270+
}
271+
};
272+
273+
/// Create a machine pass that prints MIR2Vec embeddings
274+
MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
275+
169276
} // namespace llvm
170277

171278
#endif // LLVM_CODEGEN_MIR2VEC_H

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
9393
LLVM_ABI MachineFunctionPass *
9494
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
9595

96+
/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
97+
/// machine functions, basic blocks and instructions.
98+
LLVM_ABI MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
99+
96100
/// StackFramePrinter pass - This pass prints out the machine function's
97101
/// stack frame to the given stream as a debugging tool.
98102
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ LLVM_ABI void
222222
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
223223
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
224224
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
225+
LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
225226
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
226227
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
227228
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);

llvm/lib/CodeGen/CodeGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
9898
initializeMachineUniformityAnalysisPassPass(Registry);
9999
initializeMIR2VecVocabLegacyAnalysisPass(Registry);
100100
initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
101+
initializeMIR2VecPrinterLegacyPassPass(Registry);
101102
initializeMachineUniformityInfoPrinterPassPass(Registry);
102103
initializeMachineVerifierLegacyPassPass(Registry);
103104
initializeObjCARCContractLegacyPassPass(Registry);

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 159 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "llvm/CodeGen/MIR2Vec.h"
15+
#include "llvm/ADT/DepthFirstIterator.h"
1516
#include "llvm/ADT/Statistic.h"
1617
#include "llvm/CodeGen/TargetInstrInfo.h"
1718
#include "llvm/IR/Module.h"
@@ -29,20 +30,30 @@ using namespace mir2vec;
2930
STATISTIC(MIRVocabMissCounter,
3031
"Number of lookups to MIR entities not present in the vocabulary");
3132

32-
cl::OptionCategory llvm::mir2vec::MIR2VecCategory("MIR2Vec Options");
33+
namespace llvm {
34+
namespace mir2vec {
35+
cl::OptionCategory MIR2VecCategory("MIR2Vec Options");
3336

3437
// FIXME: Use a default vocab when not specified
3538
static cl::opt<std::string>
3639
VocabFile("mir2vec-vocab-path", cl::Optional,
3740
cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""),
3841
cl::cat(MIR2VecCategory));
39-
cl::opt<float>
40-
llvm::mir2vec::OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
41-
cl::desc("Weight for machine opcode embeddings"),
42-
cl::cat(MIR2VecCategory));
42+
cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
43+
cl::desc("Weight for machine opcode embeddings"),
44+
cl::cat(MIR2VecCategory));
45+
cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
46+
"mir2vec-kind", cl::Optional,
47+
cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
48+
"Generate symbolic embeddings for MIR")),
49+
cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
50+
cl::cat(MIR2VecCategory));
51+
52+
} // namespace mir2vec
53+
} // namespace llvm
4354

4455
//===----------------------------------------------------------------------===//
45-
// Vocabulary Implementation
56+
// Vocabulary
4657
//===----------------------------------------------------------------------===//
4758

4859
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -188,6 +199,28 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
188199
<< " unique base opcodes\n");
189200
}
190201

202+
Expected<MIRVocabulary>
203+
MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
204+
unsigned Dim) {
205+
assert(Dim > 0 && "Dimension must be greater than zero");
206+
207+
float DummyVal = 0.1f;
208+
209+
// Create dummy embeddings for all canonical opcode names
210+
VocabMap DummyVocabMap;
211+
for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
212+
std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
213+
if (DummyVocabMap.count(BaseOpcode) == 0) {
214+
// Only add if not already present
215+
DummyVocabMap[BaseOpcode] = Embedding(Dim, DummyVal);
216+
DummyVal += 0.1f;
217+
}
218+
}
219+
220+
// Create and return vocabulary with dummy embeddings
221+
return MIRVocabulary::create(std::move(DummyVocabMap), TII);
222+
}
223+
191224
//===----------------------------------------------------------------------===//
192225
// MIR2VecVocabLegacyAnalysis Implementation
193226
//===----------------------------------------------------------------------===//
@@ -258,7 +291,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
258291
}
259292

260293
//===----------------------------------------------------------------------===//
261-
// Printer Passes Implementation
294+
// MIREmbedder and its subclasses
295+
//===----------------------------------------------------------------------===//
296+
297+
std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
298+
const MachineFunction &MF,
299+
const MIRVocabulary &Vocab) {
300+
switch (Mode) {
301+
case MIR2VecKind::Symbolic:
302+
return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
303+
}
304+
return nullptr;
305+
}
306+
307+
Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const {
308+
Embedding MBBVector(Dimension, 0);
309+
310+
// Get instruction info for opcode name resolution
311+
const auto &Subtarget = MF.getSubtarget();
312+
const auto *TII = Subtarget.getInstrInfo();
313+
if (!TII) {
314+
MF.getFunction().getContext().emitError(
315+
"MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
316+
return MBBVector;
317+
}
318+
319+
// Process each machine instruction in the basic block
320+
for (const auto &MI : MBB) {
321+
// Skip debug instructions and other metadata
322+
if (MI.isDebugInstr())
323+
continue;
324+
MBBVector += computeEmbeddings(MI);
325+
}
326+
327+
return MBBVector;
328+
}
329+
330+
Embedding MIREmbedder::computeEmbeddings() const {
331+
Embedding MFuncVector(Dimension, 0);
332+
333+
// Consider all reachable machine basic blocks in the function
334+
for (const auto *MBB : depth_first(&MF))
335+
MFuncVector += computeEmbeddings(*MBB);
336+
return MFuncVector;
337+
}
338+
339+
SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
340+
const MIRVocabulary &Vocab)
341+
: MIREmbedder(MF, Vocab) {}
342+
343+
std::unique_ptr<SymbolicMIREmbedder>
344+
SymbolicMIREmbedder::create(const MachineFunction &MF,
345+
const MIRVocabulary &Vocab) {
346+
return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
347+
}
348+
349+
Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const {
350+
// Skip debug instructions and other metadata
351+
if (MI.isDebugInstr())
352+
return Embedding(Dimension, 0);
353+
354+
// Todo: Add operand/argument contributions
355+
356+
return Vocab[MI.getOpcode()];
357+
}
358+
359+
//===----------------------------------------------------------------------===//
360+
// Printer Passes
262361
//===----------------------------------------------------------------------===//
263362

264363
char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -297,3 +396,56 @@ MachineFunctionPass *
297396
llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
298397
return new MIR2VecVocabPrinterLegacyPass(OS);
299398
}
399+
400+
char MIR2VecPrinterLegacyPass::ID = 0;
401+
INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
402+
"MIR2Vec Embedder Printer Pass", false, true)
403+
INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
404+
INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
405+
INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
406+
"MIR2Vec Embedder Printer Pass", false, true)
407+
408+
bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
409+
auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
410+
auto VocabOrErr =
411+
Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
412+
assert(VocabOrErr && "Failed to get MIR2Vec vocabulary");
413+
auto &MIRVocab = *VocabOrErr;
414+
415+
auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
416+
if (!Emb) {
417+
OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
418+
<< "\n";
419+
return false;
420+
}
421+
422+
OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
423+
OS << "Machine Function vector: ";
424+
Emb->getMFunctionVector().print(OS);
425+
426+
OS << "Machine basic block vectors:\n";
427+
for (const MachineBasicBlock &MBB : MF) {
428+
OS << "Machine basic block: " << MBB.getFullName() << ":\n";
429+
Emb->getMBBVector(MBB).print(OS);
430+
}
431+
432+
OS << "Machine instruction vectors:\n";
433+
for (const MachineBasicBlock &MBB : MF) {
434+
for (const MachineInstr &MI : MBB) {
435+
// Skip debug instructions as they are not
436+
// embedded
437+
if (MI.isDebugInstr())
438+
continue;
439+
440+
OS << "Machine instruction: ";
441+
MI.print(OS);
442+
Emb->getMInstVector(MI).print(OS);
443+
}
444+
}
445+
446+
return false;
447+
}
448+
449+
MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
450+
return new MIR2VecPrinterLegacyPass(OS);
451+
}

0 commit comments

Comments
 (0)