Skip to content

Commit 048b61d

Browse files
committed
Adding IR2Vec as an analysis pass
1 parent 749535b commit 048b61d

File tree

9 files changed

+670
-0
lines changed

9 files changed

+670
-0
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
//===- IR2VecAnalysis.h - IR2Vec Analysis Implementation -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See the LICENSE file for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file contains the declaration of IR2VecAnalysis that computes
11+
/// IR2Vec Embeddings of the program.
12+
///
13+
/// Program embeddings are typically a learned representation of the program,
14+
/// or are derived from one. Such embeddings are used to represent the
15+
/// programs as input to machine learning algorithms. IR2Vec represents the
16+
/// LLVM IR as embeddings.
17+
///
18+
/// The IR2Vec algorithm is described in the following paper:
19+
///
20+
/// IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy,
21+
/// Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna
22+
/// Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and
23+
/// Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463.
24+
/// https://arxiv.org/abs/1909.06228
25+
///
26+
//===----------------------------------------------------------------------===//
27+
28+
#ifndef LLVM_ANALYSIS_IR2VECANALYSIS_H
29+
#define LLVM_ANALYSIS_IR2VECANALYSIS_H
30+
31+
#include "llvm/ADT/MapVector.h"
32+
#include "llvm/IR/PassManager.h"
33+
#include <map>
34+
35+
namespace llvm {
36+
37+
class Module;
38+
class BasicBlock;
39+
class Instruction;
40+
class Function;
41+
42+
namespace ir2vec {
43+
/// An embedding is stored as a vector of doubles.
using Embedding = std::vector<double>;
44+
// TODO: Currently the keys are strings. This can be changed to
45+
// use integers for cheaper lookups.
46+
/// The vocabulary is an ordered map from a string key to its embedding.
using Vocab = std::map<std::string, Embedding>;
47+
} // namespace ir2vec
48+
49+
class VocabResult;
50+
class IR2VecResult;
51+
52+
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
53+
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
54+
/// its corresponding embedding.
55+
class VocabAnalysis : public AnalysisInfoMixin<VocabAnalysis> {
56+
// NOTE(review): presumably the dimension of the embeddings in Vocabulary;
// the code that sets it is not in this header -- confirm against the .cpp.
unsigned DIM = 0;
57+
// Vocabulary held by the analysis object itself (string key -> embedding).
ir2vec::Vocab Vocabulary;
58+
// Returns an llvm::Error; defined out of line. NOTE(review): presumably
// populates Vocabulary and DIM -- confirm against the implementation.
Error readVocabulary();
59+
60+
public:
61+
// Key object required of every analysis deriving from AnalysisInfoMixin.
static AnalysisKey Key;
62+
VocabAnalysis() = default;
63+
// Type returned by run().
using Result = VocabResult;
64+
// Module-level entry point of the analysis; defined out of line.
Result run(Module &M, ModuleAnalysisManager &MAM);
65+
};
66+
67+
/// Result type of VocabAnalysis: a vocabulary, its dimension, and a validity
/// flag. A default-constructed result is not valid and has dimension 0.
class VocabResult {
68+
ir2vec::Vocab Vocabulary;
69+
// Value reported by isValid(); false unless a constructor sets it.
bool Valid = false;
70+
// Value reported by getDimension().
unsigned DIM = 0;
71+
72+
public:
73+
VocabResult() = default;
74+
// Defined out of line. NOTE(review): presumably stores both arguments and
// marks the result valid -- confirm against the implementation.
VocabResult(const ir2vec::Vocab &Vocabulary, unsigned Dim);
75+
76+
// Helper functions
77+
/// Returns the stored validity flag.
bool isValid() const { return Valid; }
78+
/// Returns a const reference to a vocabulary; defined out of line.
const ir2vec::Vocab &getVocabulary() const;
79+
/// Returns the stored dimension.
unsigned getDimension() const { return DIM; }
80+
/// Invalidation hook with the signature the analysis manager expects from a
/// result type; defined out of line.
bool invalidate(Module &M, const PreservedAnalyses &PA,
81+
ModuleAnalysisManager::Invalidator &Inv);
82+
};
83+
84+
/// Result type of IR2VecAnalysis: per-instruction and per-basic-block
/// embedding maps plus a single function-level embedding. A
/// default-constructed result is not valid.
class IR2VecResult {
85+
// Maps each instruction pointer to an embedding.
SmallMapVector<const Instruction *, ir2vec::Embedding, 128> InstVecMap;
86+
// Maps each basic block pointer to an embedding.
SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> BBVecMap;
87+
// Embedding for the function as a whole.
ir2vec::Embedding FuncVector;
88+
unsigned DIM = 0;
89+
// Value reported by isValid(); false unless a constructor sets it.
bool Valid = false;
90+
91+
public:
92+
IR2VecResult() = default;
93+
// Defined out of line. Both maps are taken by value. NOTE(review):
// presumably they are moved into the members and the result is marked
// valid -- confirm against the implementation.
IR2VecResult(
94+
SmallMapVector<const Instruction *, ir2vec::Embedding, 128> InstMap,
95+
SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> BBMap,
96+
const ir2vec::Embedding &FuncVector, unsigned Dim);
97+
/// Returns the stored validity flag.
bool isValid() const { return Valid; }
98+
99+
// The accessors below are defined out of line. NOTE(review): presumably each
// returns the correspondingly named member -- confirm against the .cpp.
const SmallMapVector<const Instruction *, ir2vec::Embedding, 128> &
100+
getInstVecMap() const;
101+
const SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> &
102+
getBBVecMap() const;
103+
const ir2vec::Embedding &getFunctionVector() const;
104+
unsigned getDimension() const;
105+
};
106+
107+
/// This analysis provides the IR2Vec embeddings for instructions, basic blocks,
108+
/// and functions.
109+
class IR2VecAnalysis : public AnalysisInfoMixin<IR2VecAnalysis> {
110+
// Initialized explicitly: the constructor is defaulted, so without an
// initializer this flag would hold an indeterminate value after default
// construction and any read of it would be undefined behavior.
// NOTE(review): `false` is a conservative default; if run() is meant to
// overwrite it (e.g. from a command-line option), this initializer is
// harmless -- confirm the intended default against the implementation.
bool Avg = false;
111+
// NOTE(review): presumably the weights applied to opcode, type, and argument
// embeddings respectively -- confirm against the implementation.
float WO = 1, WT = 0.5, WA = 0.2;
112+
113+
public:
114+
IR2VecAnalysis() = default;
115+
// Key object required of every analysis deriving from AnalysisInfoMixin.
static AnalysisKey Key;
116+
// Type returned by run().
using Result = IR2VecResult;
117+
// Function-level entry point of the analysis; defined out of line.
Result run(Function &F, FunctionAnalysisManager &FAM);
118+
};
119+
120+
/// This pass prints the IR2Vec embeddings for instructions, basic blocks, and
121+
/// functions.
122+
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
123+
// Output stream, held by reference; the caller must keep it alive for as
// long as this pass object is used.
raw_ostream &OS;
124+
// Defined out of line. NOTE(review): presumably writes the elements of Vec
// to OS -- confirm against the implementation.
void printVector(const ir2vec::Embedding &Vec) const;
125+
126+
public:
127+
/// Binds the pass to the stream it will use for output.
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
128+
/// Module-level entry point of the pass; defined out of line.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
129+
/// Always returns true. NOTE(review): in the new pass manager this marks the
/// pass as one that must not be skipped -- confirm against the LLVM docs.
static bool isRequired() { return true; }
130+
};
131+
132+
} // namespace llvm
133+
134+
#endif // LLVM_ANALYSIS_IR2VECANALYSIS_H

llvm/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ add_llvm_component_library(LLVMAnalysis
6767
GlobalsModRef.cpp
6868
GuardUtils.cpp
6969
HeatUtils.cpp
70+
IR2VecAnalysis.cpp
7071
IRSimilarityIdentifier.cpp
7172
IVDescriptors.cpp
7273
IVUsers.cpp

0 commit comments

Comments
 (0)