1212// ===----------------------------------------------------------------------===//
1313
1414#include " llvm/CodeGen/MIR2Vec.h"
15+ #include " llvm/ADT/DepthFirstIterator.h"
1516#include " llvm/ADT/Statistic.h"
1617#include " llvm/CodeGen/TargetInstrInfo.h"
1718#include " llvm/IR/Module.h"
@@ -29,20 +30,30 @@ using namespace mir2vec;
2930STATISTIC (MIRVocabMissCounter,
3031 " Number of lookups to MIR entities not present in the vocabulary" );
3132
32- cl::OptionCategory llvm::mir2vec::MIR2VecCategory (" MIR2Vec Options" );
33+ namespace llvm {
34+ namespace mir2vec {
35+ cl::OptionCategory MIR2VecCategory (" MIR2Vec Options" );
3336
3437// FIXME: Use a default vocab when not specified
3538static cl::opt<std::string>
3639 VocabFile (" mir2vec-vocab-path" , cl::Optional,
3740 cl::desc (" Path to the vocabulary file for MIR2Vec" ), cl::init(" " ),
3841 cl::cat(MIR2VecCategory));
39- cl::opt<float >
40- llvm::mir2vec::OpcWeight (" mir2vec-opc-weight" , cl::Optional, cl::init(1.0 ),
41- cl::desc(" Weight for machine opcode embeddings" ),
42- cl::cat(MIR2VecCategory));
42+ cl::opt<float > OpcWeight (" mir2vec-opc-weight" , cl::Optional, cl::init(1.0 ),
43+ cl::desc(" Weight for machine opcode embeddings" ),
44+ cl::cat(MIR2VecCategory));
45+ cl::opt<MIR2VecKind> MIR2VecEmbeddingKind (
46+ " mir2vec-kind" , cl::Optional,
47+ cl::values (clEnumValN(MIR2VecKind::Symbolic, " symbolic" ,
48+ " Generate symbolic embeddings for MIR" )),
49+ cl::init(MIR2VecKind::Symbolic), cl::desc(" MIR2Vec embedding kind" ),
50+ cl::cat(MIR2VecCategory));
51+
52+ } // namespace mir2vec
53+ } // namespace llvm
4354
4455// ===----------------------------------------------------------------------===//
45- // Vocabulary Implementation
56+ // Vocabulary
4657// ===----------------------------------------------------------------------===//
4758
4859MIRVocabulary::MIRVocabulary (VocabMap &&OpcodeEntries,
@@ -188,6 +199,28 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
188199 << " unique base opcodes\n " );
189200}
190201
202+ Expected<MIRVocabulary>
203+ MIRVocabulary::createDummyVocabForTest (const TargetInstrInfo &TII,
204+ unsigned Dim) {
205+ assert (Dim > 0 && " Dimension must be greater than zero" );
206+
207+ float DummyVal = 0 .1f ;
208+
209+ // Create dummy embeddings for all canonical opcode names
210+ VocabMap DummyVocabMap;
211+ for (unsigned Opcode = 0 ; Opcode < TII.getNumOpcodes (); ++Opcode) {
212+ std::string BaseOpcode = extractBaseOpcodeName (TII.getName (Opcode));
213+ if (DummyVocabMap.count (BaseOpcode) == 0 ) {
214+ // Only add if not already present
215+ DummyVocabMap[BaseOpcode] = Embedding (Dim, DummyVal);
216+ DummyVal += 0 .1f ;
217+ }
218+ }
219+
220+ // Create and return vocabulary with dummy embeddings
221+ return MIRVocabulary::create (std::move (DummyVocabMap), TII);
222+ }
223+
191224// ===----------------------------------------------------------------------===//
192225// MIR2VecVocabLegacyAnalysis Implementation
193226// ===----------------------------------------------------------------------===//
@@ -258,7 +291,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
258291}
259292
260293// ===----------------------------------------------------------------------===//
261- // Printer Passes Implementation
294+ // MIREmbedder and its subclasses
295+ // ===----------------------------------------------------------------------===//
296+
297+ std::unique_ptr<MIREmbedder> MIREmbedder::create (MIR2VecKind Mode,
298+ const MachineFunction &MF,
299+ const MIRVocabulary &Vocab) {
300+ switch (Mode) {
301+ case MIR2VecKind::Symbolic:
302+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
303+ }
304+ return nullptr ;
305+ }
306+
307+ Embedding MIREmbedder::computeEmbeddings (const MachineBasicBlock &MBB) const {
308+ Embedding MBBVector (Dimension, 0 );
309+
310+ // Get instruction info for opcode name resolution
311+ const auto &Subtarget = MF.getSubtarget ();
312+ const auto *TII = Subtarget.getInstrInfo ();
313+ if (!TII) {
314+ MF.getFunction ().getContext ().emitError (
315+ " MIR2Vec: No TargetInstrInfo available; cannot compute embeddings" );
316+ return MBBVector;
317+ }
318+
319+ // Process each machine instruction in the basic block
320+ for (const auto &MI : MBB) {
321+ // Skip debug instructions and other metadata
322+ if (MI.isDebugInstr ())
323+ continue ;
324+ MBBVector += computeEmbeddings (MI);
325+ }
326+
327+ return MBBVector;
328+ }
329+
330+ Embedding MIREmbedder::computeEmbeddings () const {
331+ Embedding MFuncVector (Dimension, 0 );
332+
333+ // Consider all reachable machine basic blocks in the function
334+ for (const auto *MBB : depth_first (&MF))
335+ MFuncVector += computeEmbeddings (*MBB);
336+ return MFuncVector;
337+ }
338+
339+ SymbolicMIREmbedder::SymbolicMIREmbedder (const MachineFunction &MF,
340+ const MIRVocabulary &Vocab)
341+ : MIREmbedder(MF, Vocab) {}
342+
343+ std::unique_ptr<SymbolicMIREmbedder>
344+ SymbolicMIREmbedder::create (const MachineFunction &MF,
345+ const MIRVocabulary &Vocab) {
346+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
347+ }
348+
349+ Embedding SymbolicMIREmbedder::computeEmbeddings (const MachineInstr &MI) const {
350+ // Skip debug instructions and other metadata
351+ if (MI.isDebugInstr ())
352+ return Embedding (Dimension, 0 );
353+
354+ // Todo: Add operand/argument contributions
355+
356+ return Vocab[MI.getOpcode ()];
357+ }
358+
359+ // ===----------------------------------------------------------------------===//
360+ // Printer Passes
262361// ===----------------------------------------------------------------------===//
263362
264363char MIR2VecVocabPrinterLegacyPass::ID = 0 ;
@@ -297,3 +396,56 @@ MachineFunctionPass *
297396llvm::createMIR2VecVocabPrinterLegacyPass (raw_ostream &OS) {
298397 return new MIR2VecVocabPrinterLegacyPass (OS);
299398}
399+
400+ char MIR2VecPrinterLegacyPass::ID = 0 ;
401+ INITIALIZE_PASS_BEGIN (MIR2VecPrinterLegacyPass, " print-mir2vec" ,
402+ " MIR2Vec Embedder Printer Pass" , false , true )
403+ INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
404+ INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
405+ INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, " print-mir2vec" ,
406+ " MIR2Vec Embedder Printer Pass" , false , true )
407+
408+ bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
409+ auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
410+ auto VocabOrErr =
411+ Analysis.getMIR2VecVocabulary (*MF.getFunction ().getParent ());
412+ assert (VocabOrErr && " Failed to get MIR2Vec vocabulary" );
413+ auto &MIRVocab = *VocabOrErr;
414+
415+ auto Emb = mir2vec::MIREmbedder::create (MIR2VecEmbeddingKind, MF, MIRVocab);
416+ if (!Emb) {
417+ OS << " Error creating MIR2Vec embeddings for function " << MF.getName ()
418+ << " \n " ;
419+ return false ;
420+ }
421+
422+ OS << " MIR2Vec embeddings for machine function " << MF.getName () << " :\n " ;
423+ OS << " Machine Function vector: " ;
424+ Emb->getMFunctionVector ().print (OS);
425+
426+ OS << " Machine basic block vectors:\n " ;
427+ for (const MachineBasicBlock &MBB : MF) {
428+ OS << " Machine basic block: " << MBB.getFullName () << " :\n " ;
429+ Emb->getMBBVector (MBB).print (OS);
430+ }
431+
432+ OS << " Machine instruction vectors:\n " ;
433+ for (const MachineBasicBlock &MBB : MF) {
434+ for (const MachineInstr &MI : MBB) {
435+ // Skip debug instructions as they are not
436+ // embedded
437+ if (MI.isDebugInstr ())
438+ continue ;
439+
440+ OS << " Machine instruction: " ;
441+ MI.print (OS);
442+ Emb->getMInstVector (MI).print (OS);
443+ }
444+ }
445+
446+ return false ;
447+ }
448+
449+ MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass (raw_ostream &OS) {
450+ return new MIR2VecPrinterLegacyPass (OS);
451+ }
0 commit comments