| 
 | 1 | +//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===//  | 
 | 2 | +//  | 
 | 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.  | 
 | 4 | +// See https://llvm.org/LICENSE.txt for license information.  | 
 | 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception  | 
 | 6 | +//  | 
 | 7 | +//===----------------------------------------------------------------------===//  | 
 | 8 | +///  | 
 | 9 | +/// \file  | 
 | 10 | +/// This file implements the IR2Vec embedding generation tool.  | 
 | 11 | +///  | 
 | 12 | +/// Currently supports triplet generation for vocabulary training.  | 
 | 13 | +/// Future updates will support embedding generation using trained vocabulary.  | 
 | 14 | +///  | 
 | 15 | +/// Usage: llvm-ir2vec input.bc -o triplets.txt  | 
 | 16 | +///  | 
 | 17 | +/// TODO: Add embedding generation mode with vocabulary support  | 
 | 18 | +///  | 
 | 19 | +//===----------------------------------------------------------------------===//  | 
 | 20 | + | 
 | 21 | +#include "llvm/Analysis/IR2Vec.h"  | 
 | 22 | +#include "llvm/IR/BasicBlock.h"  | 
 | 23 | +#include "llvm/IR/Function.h"  | 
 | 24 | +#include "llvm/IR/Instructions.h"  | 
 | 25 | +#include "llvm/IR/LLVMContext.h"  | 
 | 26 | +#include "llvm/IR/Module.h"  | 
 | 27 | +#include "llvm/IR/Type.h"  | 
 | 28 | +#include "llvm/IRReader/IRReader.h"  | 
 | 29 | +#include "llvm/Support/CommandLine.h"  | 
 | 30 | +#include "llvm/Support/Debug.h"  | 
 | 31 | +#include "llvm/Support/Errc.h"  | 
 | 32 | +#include "llvm/Support/InitLLVM.h"  | 
 | 33 | +#include "llvm/Support/SourceMgr.h"  | 
 | 34 | +#include "llvm/Support/raw_ostream.h"  | 
 | 35 | + | 
 | 36 | +using namespace llvm;  | 
 | 37 | +using namespace ir2vec;  | 
 | 38 | + | 
 | 39 | +#define DEBUG_TYPE "ir2vec"  | 
 | 40 | + | 
 | 41 | +static cl::OptionCategory IR2VecToolCategory("IR2Vec Tool Options");  | 
 | 42 | + | 
 | 43 | +static cl::opt<std::string> InputFilename(cl::Positional,  | 
 | 44 | +                                          cl::desc("<input bitcode file>"),  | 
 | 45 | +                                          cl::Required,  | 
 | 46 | +                                          cl::cat(IR2VecToolCategory));  | 
 | 47 | + | 
 | 48 | +static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),  | 
 | 49 | +                                           cl::value_desc("filename"),  | 
 | 50 | +                                           cl::init("-"),  | 
 | 51 | +                                           cl::cat(IR2VecToolCategory));  | 
 | 52 | + | 
 | 53 | +namespace {  | 
 | 54 | + | 
 | 55 | +/// Helper class for collecting IR information and generating triplets  | 
 | 56 | +class IR2VecTool {  | 
 | 57 | +private:  | 
 | 58 | +  Module &M;  | 
 | 59 | + | 
 | 60 | +public:  | 
 | 61 | +  explicit IR2VecTool(Module &M) : M(M) {}  | 
 | 62 | + | 
 | 63 | +  /// Generate triplets for the entire module  | 
 | 64 | +  void generateTriplets(raw_ostream &OS) const {  | 
 | 65 | +    for (const Function &F : M)  | 
 | 66 | +      generateTriplets(F, OS);  | 
 | 67 | +  }  | 
 | 68 | + | 
 | 69 | +  /// Generate triplets for a single function  | 
 | 70 | +  void generateTriplets(const Function &F, raw_ostream &OS) const {  | 
 | 71 | +    if (F.isDeclaration())  | 
 | 72 | +      return;  | 
 | 73 | + | 
 | 74 | +    std::string LocalOutput;  | 
 | 75 | +    raw_string_ostream LocalOS(LocalOutput);  | 
 | 76 | + | 
 | 77 | +    for (const BasicBlock &BB : F)  | 
 | 78 | +      traverseBasicBlock(BB, LocalOS);  | 
 | 79 | + | 
 | 80 | +    LocalOS.flush();  | 
 | 81 | +    OS << LocalOutput;  | 
 | 82 | +  }  | 
 | 83 | + | 
 | 84 | +private:  | 
 | 85 | +  /// Process a single basic block for triplet generation  | 
 | 86 | +  void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {  | 
 | 87 | +    // Consider only non-debug and non-pseudo instructions  | 
 | 88 | +    for (const auto &I : BB.instructionsWithoutDebug()) {  | 
 | 89 | +      StringRef OpcStr = Vocabulary::getVocabKeyForOpcode(I.getOpcode());  | 
 | 90 | +      StringRef TypeStr =  | 
 | 91 | +          Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID());  | 
 | 92 | + | 
 | 93 | +      OS << '\n' << OpcStr << ' ' << TypeStr << ' ';  | 
 | 94 | + | 
 | 95 | +      LLVM_DEBUG(I.print(dbgs()); dbgs() << "\n");  | 
 | 96 | +      LLVM_DEBUG(I.getType()->print(dbgs()); dbgs() << " Type\n");  | 
 | 97 | + | 
 | 98 | +      for (const Use &U : I.operands())  | 
 | 99 | +        OS << Vocabulary::getVocabKeyForOperandKind(  | 
 | 100 | +                  Vocabulary::getOperandKind(U.get()))  | 
 | 101 | +           << ' ';  | 
 | 102 | +    }  | 
 | 103 | +  }  | 
 | 104 | +};  | 
 | 105 | + | 
 | 106 | +Error processModule(Module &M, raw_ostream &OS) {  | 
 | 107 | +  IR2VecTool Tool(M);  | 
 | 108 | +  Tool.generateTriplets(OS);  | 
 | 109 | + | 
 | 110 | +  return Error::success();  | 
 | 111 | +}  | 
 | 112 | + | 
 | 113 | +} // anonymous namespace  | 
 | 114 | + | 
 | 115 | +int main(int argc, char **argv) {  | 
 | 116 | +  InitLLVM X(argc, argv);  | 
 | 117 | +  cl::HideUnrelatedOptions(IR2VecToolCategory);  | 
 | 118 | +  cl::ParseCommandLineOptions(  | 
 | 119 | +      argc, argv,  | 
 | 120 | +      "IR2Vec - Triplet Generation Tool\n"  | 
 | 121 | +      "Generates triplets for vocabulary training from LLVM IR.\n"  | 
 | 122 | +      "Future updates will support embedding generation.\n\n"  | 
 | 123 | +      "Usage:\n"  | 
 | 124 | +      "  llvm-ir2vec input.bc -o triplets.txt\n");  | 
 | 125 | + | 
 | 126 | +  // Parse the input LLVM IR file  | 
 | 127 | +  SMDiagnostic Err;  | 
 | 128 | +  LLVMContext Context;  | 
 | 129 | +  std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);  | 
 | 130 | +  if (!M) {  | 
 | 131 | +    Err.print(argv[0], errs());  | 
 | 132 | +    return 1;  | 
 | 133 | +  }  | 
 | 134 | + | 
 | 135 | +  std::error_code EC;  | 
 | 136 | +  raw_fd_ostream OS(OutputFilename, EC);  | 
 | 137 | +  if (EC) {  | 
 | 138 | +    errs() << "Error opening output file: " << EC.message() << "\n";  | 
 | 139 | +    return 1;  | 
 | 140 | +  }  | 
 | 141 | + | 
 | 142 | +  if (Error Err = processModule(*M, OS)) {  | 
 | 143 | +    handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {  | 
 | 144 | +      errs() << "Error: " << EIB.message() << "\n";  | 
 | 145 | +    });  | 
 | 146 | +    return 1;  | 
 | 147 | +  }  | 
 | 148 | + | 
 | 149 | +  return 0;  | 
 | 150 | +}  | 
0 commit comments