|
| 1 | +//===- IR2VecAnalysis.h - IR2Vec Analysis Implementation -------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM |
| 4 | +// Exceptions. See the LICENSE file for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +/// |
| 9 | +/// \file |
| 10 | +/// This file contains the declaration of IR2VecAnalysis that computes |
| 11 | +/// IR2Vec Embeddings of the program. |
| 12 | +/// |
| 13 | +/// Program embeddings are typically, or are derived from, a learned |
| 14 | +/// representation of the program. Such embeddings are used to represent the |
| 15 | +/// programs as input to machine learning algorithms. IR2Vec represents the |
| 16 | +/// LLVM IR as embeddings. |
| 17 | +/// |
| 18 | +/// The IR2Vec algorithm is described in the following paper: |
| 19 | +/// |
| 20 | +/// IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy, |
| 21 | +/// Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna |
| 22 | +/// Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and |
| 23 | +/// Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463. |
| 24 | +/// https://arxiv.org/abs/1909.06228 |
| 25 | +/// |
| 26 | +//===----------------------------------------------------------------------===// |
| 27 | + |
| 28 | +#ifndef LLVM_ANALYSIS_IR2VECANALYSIS_H |
| 29 | +#define LLVM_ANALYSIS_IR2VECANALYSIS_H |
| 30 | + |
#include "llvm/ADT/MapVector.h"
#include "llvm/IR/PassManager.h"
#include <map>
#include <string>
#include <vector>
| 34 | + |
| 35 | +namespace llvm { |
| 36 | + |
// Forward declarations. This header only uses these types by reference, by
// pointer, or as the return type of a function that is declared but not
// defined here, so their full definitions are not needed. Error and
// raw_ostream are declared explicitly rather than relying on another header
// to pull them in.
class BasicBlock;
class Error;
class Function;
class Instruction;
class Module;
class raw_ostream;
| 41 | + |
namespace ir2vec {
/// A single embedding: a sequence of double-precision values.
using Embedding = std::vector<double>;

/// Associates a string key with its embedding.
// TODO: The keys are currently strings. Switching to integer keys would make
// lookups cheaper.
using Vocab = std::map<std::string, Embedding>;
} // namespace ir2vec
| 48 | + |
// Result types of the analyses in this header. They are declared up front
// because each analysis class names its result type before that type is
// defined.
class IR2VecResult;
class VocabResult;
| 51 | + |
| 52 | +/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a |
| 53 | +/// mapping between an entity of the IR (like opcode, type, argument, etc.) and |
| 54 | +/// its corresponding embedding. |
| 55 | +class VocabAnalysis : public AnalysisInfoMixin<VocabAnalysis> { |
| 56 | + unsigned DIM = 0; |
| 57 | + ir2vec::Vocab Vocabulary; |
| 58 | + Error readVocabulary(); |
| 59 | + |
| 60 | +public: |
| 61 | + static AnalysisKey Key; |
| 62 | + VocabAnalysis() = default; |
| 63 | + using Result = VocabResult; |
| 64 | + Result run(Module &M, ModuleAnalysisManager &MAM); |
| 65 | +}; |
| 66 | + |
| 67 | +class VocabResult { |
| 68 | + ir2vec::Vocab Vocabulary; |
| 69 | + bool Valid = false; |
| 70 | + unsigned DIM = 0; |
| 71 | + |
| 72 | +public: |
| 73 | + VocabResult() = default; |
| 74 | + VocabResult(const ir2vec::Vocab &Vocabulary, unsigned Dim); |
| 75 | + |
| 76 | + // Helper functions |
| 77 | + bool isValid() const { return Valid; } |
| 78 | + const ir2vec::Vocab &getVocabulary() const; |
| 79 | + unsigned getDimension() const { return DIM; } |
| 80 | + bool invalidate(Module &M, const PreservedAnalyses &PA, |
| 81 | + ModuleAnalysisManager::Invalidator &Inv); |
| 82 | +}; |
| 83 | + |
| 84 | +class IR2VecResult { |
| 85 | + SmallMapVector<const Instruction *, ir2vec::Embedding, 128> InstVecMap; |
| 86 | + SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> BBVecMap; |
| 87 | + ir2vec::Embedding FuncVector; |
| 88 | + unsigned DIM = 0; |
| 89 | + bool Valid = false; |
| 90 | + |
| 91 | +public: |
| 92 | + IR2VecResult() = default; |
| 93 | + IR2VecResult( |
| 94 | + SmallMapVector<const Instruction *, ir2vec::Embedding, 128> InstMap, |
| 95 | + SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> BBMap, |
| 96 | + const ir2vec::Embedding &FuncVector, unsigned Dim); |
| 97 | + bool isValid() const { return Valid; } |
| 98 | + |
| 99 | + const SmallMapVector<const Instruction *, ir2vec::Embedding, 128> & |
| 100 | + getInstVecMap() const; |
| 101 | + const SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> & |
| 102 | + getBBVecMap() const; |
| 103 | + const ir2vec::Embedding &getFunctionVector() const; |
| 104 | + unsigned getDimension() const; |
| 105 | +}; |
| 106 | + |
| 107 | +/// This analysis provides the IR2Vec embeddings for instructions, basic blocks, |
| 108 | +/// and functions. |
| 109 | +class IR2VecAnalysis : public AnalysisInfoMixin<IR2VecAnalysis> { |
| 110 | + bool Avg; |
| 111 | + float WO = 1, WT = 0.5, WA = 0.2; |
| 112 | + |
| 113 | +public: |
| 114 | + IR2VecAnalysis() = default; |
| 115 | + static AnalysisKey Key; |
| 116 | + using Result = IR2VecResult; |
| 117 | + Result run(Function &F, FunctionAnalysisManager &FAM); |
| 118 | +}; |
| 119 | + |
| 120 | +/// This pass prints the IR2Vec embeddings for instructions, basic blocks, and |
| 121 | +/// functions. |
| 122 | +class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> { |
| 123 | + raw_ostream &OS; |
| 124 | + void printVector(const ir2vec::Embedding &Vec) const; |
| 125 | + |
| 126 | +public: |
| 127 | + explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {} |
| 128 | + PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); |
| 129 | + static bool isRequired() { return true; } |
| 130 | +}; |
| 131 | + |
| 132 | +} // namespace llvm |
| 133 | + |
| 134 | +#endif // LLVM_ANALYSIS_IR2VECANALYSIS_H |
0 commit comments