99// / \file
1010// / This file implements the IR2Vec embedding generation tool.
1111// /
12- // / Currently supports triplet generation for vocabulary training.
13- // / Future updates will support embedding generation using trained vocabulary.
12+ // / This tool provides two main functionalities:
1413// /
15- // / Usage: llvm-ir2vec input.bc -o triplets.txt
14+ // / 1. Triplet Generation Mode (--mode=triplets):
15+ // / Generates triplets (opcode, type, operands) for vocabulary training.
16+ // / Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
1617// /
17- // / TODO: Add embedding generation mode with vocabulary support
18+ // / 2. Embedding Generation Mode (--mode=embeddings):
19+ // / Generates IR2Vec embeddings using a trained vocabulary.
20+ // / Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
21+ // / --level=func input.bc -o embeddings.txt
22+ // / Levels: --level=inst (instructions), --level=bb (basic blocks), --level=func (functions)
23+ // / (See IR2Vec.cpp for more embedding generation options)
1824// /
1925// ===----------------------------------------------------------------------===//
2026
2430#include " llvm/IR/Instructions.h"
2531#include " llvm/IR/LLVMContext.h"
2632#include " llvm/IR/Module.h"
33+ #include " llvm/IR/PassInstrumentation.h"
34+ #include " llvm/IR/PassManager.h"
2735#include " llvm/IR/Type.h"
2836#include " llvm/IRReader/IRReader.h"
2937#include " llvm/Support/CommandLine.h"
3341#include " llvm/Support/SourceMgr.h"
3442#include " llvm/Support/raw_ostream.h"
3543
36- using namespace llvm ;
37- using namespace ir2vec ;
38-
3944#define DEBUG_TYPE " ir2vec"
4045
46+ namespace llvm {
47+ namespace ir2vec {
48+
4149static cl::OptionCategory IR2VecToolCategory (" IR2Vec Tool Options" );
4250
4351static cl::opt<std::string> InputFilename (cl::Positional,
@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
5058 cl::init(" -" ),
5159 cl::cat(IR2VecToolCategory));
5260
/// Operating modes selectable on the command line.
enum ToolMode {
  /// Emit triplets that are used for vocabulary training.
  TripletMode,
  /// Emit embeddings computed with an already trained vocabulary.
  EmbeddingMode
};
65+
66+ static cl::opt<ToolMode>
67+ Mode (" mode" , cl::desc(" Tool operation mode:" ),
68+ cl::values(clEnumValN(TripletMode, " triplets" ,
69+ " Generate triplets for vocabulary training" ),
70+ clEnumValN(EmbeddingMode, " embeddings" ,
71+ " Generate embeddings using trained vocabulary" )),
72+ cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
73+
74+ static cl::opt<std::string>
75+ FunctionName (" function" , cl::desc(" Process specific function only" ),
76+ cl::value_desc(" name" ), cl::Optional, cl::init(" " ),
77+ cl::cat(IR2VecToolCategory));
78+
/// Granularity at which embeddings are emitted in embedding mode.
enum EmbeddingLevel {
  /// One embedding per instruction.
  InstructionLevel,
  /// One embedding per basic block.
  BasicBlockLevel,
  /// One embedding per function.
  FunctionLevel
};
84+
85+ static cl::opt<EmbeddingLevel>
86+ Level (" level" , cl::desc(" Embedding generation level (for embedding mode):" ),
87+ cl::values(clEnumValN(InstructionLevel, " inst" ,
88+ " Generate instruction-level embeddings" ),
89+ clEnumValN(BasicBlockLevel, " bb" ,
90+ " Generate basic block-level embeddings" ),
91+ clEnumValN(FunctionLevel, " func" ,
92+ " Generate function-level embeddings" )),
93+ cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
94+
5395namespace {
5496
55- // / Helper class for collecting IR information and generating triplets
97+ // / Helper class for collecting IR triplets and generating embeddings
5698class IR2VecTool {
5799private:
58100 Module &M;
101+ ModuleAnalysisManager MAM;
102+ const Vocabulary *Vocab = nullptr ;
59103
60104public:
61105 explicit IR2VecTool (Module &M) : M(M) {}
62106
107+ // / Initialize the IR2Vec vocabulary analysis
108+ bool initializeVocabulary () {
109+ // Register and run the IR2Vec vocabulary analysis
110+ // The vocabulary file path is specified via --ir2vec-vocab-path global
111+ // option
112+ MAM.registerPass ([&] { return PassInstrumentationAnalysis (); });
113+ MAM.registerPass ([&] { return IR2VecVocabAnalysis (); });
114+ Vocab = &MAM.getResult <IR2VecVocabAnalysis>(M);
115+ return Vocab->isValid ();
116+ }
117+
63118 // / Generate triplets for the entire module
64119 void generateTriplets (raw_ostream &OS) const {
65120 for (const Function &F : M)
@@ -81,6 +136,68 @@ class IR2VecTool {
81136 OS << LocalOutput;
82137 }
83138
139+ // / Generate embeddings for the entire module
140+ void generateEmbeddings (raw_ostream &OS) const {
141+ if (!Vocab->isValid ()) {
142+ OS << " Error: Vocabulary is not valid. IR2VecTool not initialized.\n " ;
143+ return ;
144+ }
145+
146+ for (const Function &F : M)
147+ generateEmbeddings (F, OS);
148+ }
149+
150+ // / Generate embeddings for a single function
151+ void generateEmbeddings (const Function &F, raw_ostream &OS) const {
152+ if (F.isDeclaration ()) {
153+ OS << " Function " << F.getName () << " is a declaration, skipping.\n " ;
154+ return ;
155+ }
156+
157+ // Create embedder for this function
158+ assert (Vocab->isValid () && " Vocabulary is not valid" );
159+ auto Emb = Embedder::create (IR2VecKind::Symbolic, F, *Vocab);
160+ if (!Emb) {
161+ OS << " Error: Failed to create embedder for function " << F.getName ()
162+ << " \n " ;
163+ return ;
164+ }
165+
166+ OS << " Function: " << F.getName () << " \n " ;
167+
168+ // Generate embeddings based on the specified level
169+ switch (Level) {
170+ case FunctionLevel: {
171+ Emb->getFunctionVector ().print (OS);
172+ break ;
173+ }
174+ case BasicBlockLevel: {
175+ const auto &BBVecMap = Emb->getBBVecMap ();
176+ for (const BasicBlock &BB : F) {
177+ auto It = BBVecMap.find (&BB);
178+ if (It != BBVecMap.end ()) {
179+ OS << BB.getName () << " :" ;
180+ It->second .print (OS);
181+ }
182+ }
183+ break ;
184+ }
185+ case InstructionLevel: {
186+ const auto &InstMap = Emb->getInstVecMap ();
187+ for (const BasicBlock &BB : F) {
188+ for (const Instruction &I : BB) {
189+ auto It = InstMap.find (&I);
190+ if (It != InstMap.end ()) {
191+ I.print (OS);
192+ It->second .print (OS);
193+ }
194+ }
195+ }
196+ break ;
197+ }
198+ }
199+ }
200+
84201private:
85202 // / Process a single basic block for triplet generation
86203 void traverseBasicBlock (const BasicBlock &BB, raw_string_ostream &OS) const {
@@ -105,23 +222,70 @@ class IR2VecTool {
105222
106223Error processModule (Module &M, raw_ostream &OS) {
107224 IR2VecTool Tool (M);
108- Tool.generateTriplets (OS);
109225
226+ if (Mode == EmbeddingMode) {
227+ // Initialize vocabulary for embedding generation
228+ // Note: Requires --ir2vec-vocab-path option to be set
229+ if (!Tool.initializeVocabulary ())
230+ return createStringError (
231+ errc::invalid_argument,
232+ " Failed to initialize IR2Vec vocabulary. "
233+ " Make sure to specify --ir2vec-vocab-path for embedding mode." );
234+
235+ if (!FunctionName.empty ()) {
236+ // Process single function
237+ if (const Function *F = M.getFunction (FunctionName))
238+ Tool.generateEmbeddings (*F, OS);
239+ else
240+ return createStringError (errc::invalid_argument,
241+ " Function '%s' not found" ,
242+ FunctionName.c_str ());
243+ } else {
244+ // Process all functions
245+ Tool.generateEmbeddings (OS);
246+ }
247+ } else {
248+ // Triplet generation mode - no vocabulary needed
249+ if (!FunctionName.empty ())
250+ // Process single function
251+ if (const Function *F = M.getFunction (FunctionName))
252+ Tool.generateTriplets (*F, OS);
253+ else
254+ return createStringError (errc::invalid_argument,
255+ " Function '%s' not found" ,
256+ FunctionName.c_str ());
257+ else
258+ // Process all functions
259+ Tool.generateTriplets (OS);
260+ }
110261 return Error::success ();
111262}
112-
113- } // anonymous namespace
263+ } // namespace
264+ } // namespace ir2vec
265+ } // namespace llvm
114266
115267int main (int argc, char **argv) {
268+ using namespace llvm ;
269+ using namespace llvm ::ir2vec;
270+
116271 InitLLVM X (argc, argv);
117272 cl::HideUnrelatedOptions (IR2VecToolCategory);
118273 cl::ParseCommandLineOptions (
119274 argc, argv,
120- " IR2Vec - Triplet Generation Tool\n "
121- " Generates triplets for vocabulary training from LLVM IR.\n "
122- " Future updates will support embedding generation.\n\n "
275+ " IR2Vec - Embedding Generation Tool\n "
276+ " Generates embeddings for a given LLVM IR and "
277+ " supports triplet generation for vocabulary "
278+ " training and embedding generation.\n\n "
123279 " Usage:\n "
124- " llvm-ir2vec input.bc -o triplets.txt\n " );
280+ " Triplet mode: llvm-ir2vec --mode=triplets input.bc\n "
281+ " Embedding mode: llvm-ir2vec --mode=embeddings "
282+ " --ir2vec-vocab-path=vocab.json --level=func input.bc\n "
283+ " Levels: --level=inst (instructions), --level=bb (basic blocks), "
284+ " --level=func (functions)\n " );
285+
286+ // Validate command line options
287+ if (Mode == TripletMode && Level.getNumOccurrences () > 0 )
288+ errs () << " Warning: --level option is ignored in triplet mode\n " ;
125289
126290 // Parse the input LLVM IR file
127291 SMDiagnostic Err;
0 commit comments