Commit be7f312

Addressing review comments
1 parent 9b4541e commit be7f312

8 files changed

+541
-664
lines changed

llvm/docs/MLGO.rst

Lines changed: 47 additions & 98 deletions
@@ -191,130 +191,79 @@ of loops and regions can be derived from these representations, which can be
191191
useful in different scenarios. The representations can be useful for various
192192
downstream tasks, including ML-guided compiler optimizations.
193193

194-
Currently, to use IR2Vec embeddings, the JSON vocabulary first needs to be read
195-
and used to obtain the vocabulary mapping. Then, use this mapping to
196-
derive the representations. In LLVM, this process is implemented using two
197-
independent passes: ``IR2VecVocabAnalysis`` and ``IR2VecAnalysis``. The former
198-
reads the JSON vocabulary and populates ``IR2VecVocabResult``, which is then used
199-
by ``IR2VecAnalysis``.
194+
The core components are:
195+
- **Vocabulary**: A mapping from IR entities (opcodes, types, etc.) to their
196+
vector representations. This is managed by ``IR2VecVocabAnalysis``.
197+
- **Embedder**: A class (``ir2vec::Embedder``) that uses the vocabulary to
198+
compute embeddings for instructions, basic blocks, and functions.
200199

201-
``IR2VecVocabAnalysis`` is immutable and is intended to
202-
be run once before ``IR2VecAnalysis`` is run. In the future, we plan
203-
to improve this requirement by automatically generating default the vocabulary mappings
204-
during build time, eliminating the need for a separate file read.
200+
Using IR2Vec
201+
------------
205202

206-
IR2VecAnalysis Usage
207-
--------------------
203+
To generate embeddings, first obtain the vocabulary. Then, the
204+
embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
208205

209-
To use IR2Vec in an LLVM-based tool or pass, interaction with the analysis
210-
results can be done through the following APIs:
211-
212-
1. **Accessing the Analysis Results:**
213-
214-
To access the IR2Vec embeddings, obtain the ``IR2VecAnalysis``
215-
result from the Function Analysis Manager (FAM).
206+
1. **Get the Vocabulary**:
207+
In a ModulePass, get the vocabulary analysis result:
216208

217209
.. code-block:: c++
218210

219-
#include "llvm/Analysis/IR2VecAnalysis.h"
220-
221-
// ... other includes and code ...
222-
223-
llvm::FunctionAnalysisManager &FAM = ...; // The FAM instance
224-
llvm::Function &F = ...; // The function to analyze
225-
auto &IR2VecResult = FAM.getResult<llvm::IR2VecAnalysis>(F);
226-
227-
2. **Checking for Valid Results:**
228-
229-
Ensure that the analysis result is valid before accessing the embeddings:
230-
231-
.. code-block:: c++
232-
233-
if (IR2VecResult.isValid()) {
234-
// Proceed to access embeddings
211+
auto &VocabRes = MAM.getResult<IR2VecVocabAnalysis>(M);
212+
if (!VocabRes.isValid()) {
213+
// Handle error: vocabulary is not available or invalid
214+
return;
235215
}
216+
const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
217+
unsigned Dimension = VocabRes.getDimension();
236218

237-
3. **Retrieving Embeddings:**
219+
Note that the ``IR2VecVocabAnalysis`` pass is immutable.
238220

239-
The ``IR2VecResult`` provides access to embeddings (currently) at three levels:
221+
2. **Create Embedder instance**:
222+
With the vocabulary, create an embedder for a specific function:
240223

241-
- **Instruction Embeddings:**
224+
.. code-block:: c++
242225

243-
.. code-block:: c++
226+
// Assuming F is an llvm::Function&
227+
// For example, using IR2VecKind::Symbolic:
228+
ErrorOr<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
229+
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
244230

245-
const auto &instVecMap = IR2VecResult.getInstVecMap();
246-
// instVecMap is a SmallMapVector<const Instruction*, ir2vec::Embedding, 128>
247-
for (const auto &it : instVecMap) {
248-
const Instruction *I = it.first;
249-
const ir2vec::Embedding &embedding = it.second;
250-
// Use the instruction embedding
251-
}
252-
- **Basic Block Embeddings:**
231+
if (auto EC = EmbOrErr.getError()) {
232+
// Handle error in embedder creation
233+
return;
234+
}
235+
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
253236
254-
.. code-block:: c++
237+
3. **Compute and Access Embeddings**:
238+
Call ``computeEmbeddings()`` on the embedder instance to compute the
239+
embeddings. Then the embeddings can be accessed using different getter
240+
methods. Currently, ``Embedder`` can generate embeddings at three levels:
241+
Instructions, Basic Blocks, and Functions.
255242

256-
const auto &bbVecMap = IR2VecResult.getBBVecMap();
257-
// bbVecMap is a SmallMapVector<const BasicBlock*, ir2vec::Embedding, 16>
258-
for (const auto &it : bbVecMap) {
259-
const BasicBlock *BB = it.first;
260-
const ir2vec::Embedding &embedding = it.second;
261-
// Use the basic block embedding
262-
}
263-
- **Function Embedding:**
243+
.. code-block:: c++
264244

265-
.. code-block:: c++
245+
Emb->computeEmbeddings();
246+
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
247+
const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
248+
const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
266249

267-
const ir2vec::Embedding &funcEmbedding = IR2VecResult.getFunctionVector();
268-
// Use the function embedding
250+
// Example: Iterate over instruction embeddings
251+
for (const auto &Entry : InstVecMap) {
252+
const Instruction *Inst = Entry.getFirst();
253+
const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
254+
// Use Inst and InstEmbedding
255+
}
269256
270257
4. **Working with Embeddings:**
271-
272258
Embeddings are represented as ``std::vector<double>``. These
273259
vectors can serve as features for machine learning models, or be used to compute similarity scores
274260
between different code snippets, or perform other analyses as needed.
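
As a rough illustration of the similarity use case mentioned above, the following standalone sketch computes the cosine similarity of two embeddings. It is not part of the IR2Vec API: ``cosineSimilarity`` is a hypothetical helper, and the local ``Embedding`` alias only mirrors the ``std::vector<double>`` representation described here. Returning 0 for a zero vector is a convention chosen for this sketch.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Stand-in for ir2vec::Embedding, which is described as std::vector<double>.
using Embedding = std::vector<double>;

// Hypothetical helper (not an LLVM API): cosine similarity of two embeddings.
// Both embeddings must have the same dimension. Returns 0.0 when either
// vector has zero magnitude, since the angle is undefined in that case.
double cosineSimilarity(const Embedding &A, const Embedding &B) {
  assert(A.size() == B.size() && "embeddings must share a dimension");
  double Dot = 0.0, NormA = 0.0, NormB = 0.0;
  for (std::size_t I = 0; I < A.size(); ++I) {
    Dot += A[I] * B[I];
    NormA += A[I] * A[I];
    NormB += B[I] * B[I];
  }
  if (NormA == 0.0 || NormB == 0.0)
    return 0.0;
  return Dot / (std::sqrt(NormA) * std::sqrt(NormB));
}
```

With the embedder from step 3, such a helper could be applied to the results of ``getFunctionVector()`` for two different functions, provided both were built from the same vocabulary.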
275261

276-
Example Usage
277-
^^^^^^^^^^^^^
278-
279-
.. code-block:: c++
280-
281-
#include "llvm/Analysis/IR2VecAnalysis.h"
282-
#include "llvm/IR/Function.h"
283-
#include "llvm/IR/Instructions.h"
284-
#include "llvm/Passes/PassBuilder.h"
285-
286-
// ... other includes and code ...
287-
288-
void processFunction(llvm::Function &F, llvm::FunctionAnalysisManager &FAM) {
289-
auto &IR2VecResult = FAM.getResult<llvm::IR2VecAnalysis>(F);
290-
291-
if (IR2VecResult.isValid()) {
292-
const auto &instVecMap = IR2VecResult.getInstVecMap();
293-
for (const auto &it : instVecMap) {
294-
const Instruction *I = it.first;
295-
const auto &embedding = it.second;
296-
llvm::errs() << "Instruction: " << *I << "\n";
297-
llvm::errs() << "Embedding: ";
298-
for (double val : embedding) {
299-
llvm::errs() << val << " ";
300-
}
301-
llvm::errs() << "\n";
302-
}
303-
} else {
304-
llvm::errs() << "IR2Vec analysis failed for function " << F.getName() << "\n";
305-
}
306-
}
307-
308-
// ... rest of the pass ...
309-
310-
// In the pass's run method:
311-
// processFunction(F, FAM);
312-
313262
Further Details
314263
---------------
315264

316265
For more detailed information about the IR2Vec algorithm, its parameters, and
317266
advanced usage, please refer to the original paper:
318267
`IR2Vec: LLVM IR Based Scalable Program Embeddings <https://doi.org/10.1145/3418463>`_.
319-
The LLVM source code for ``IR2VecAnalysis`` can also be explored to understand the
268+
The LLVM source code for ``IR2Vec`` can also be explored to understand the
320269
implementation details.
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
1+
//===- IR2Vec.h - Implementation of IR2Vec ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See the LICENSE file for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file defines the IR2Vec vocabulary analysis (IR2VecVocabAnalysis),
11+
/// the core ir2vec::Embedder interface for generating IR embeddings,
12+
/// and related utilities like the IR2VecPrinterPass.
13+
///
14+
/// Program embeddings are typically (or are derived from) a learned
15+
/// representation of the program. Such embeddings are used to represent the
16+
/// programs as input to machine learning algorithms. IR2Vec represents the
17+
/// LLVM IR as embeddings.
18+
///
19+
/// The IR2Vec algorithm is described in the following paper:
20+
///
21+
/// IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy,
22+
/// Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna
23+
/// Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and
24+
/// Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463.
25+
/// https://arxiv.org/abs/1909.06228
26+
///
27+
//===----------------------------------------------------------------------===//
28+
29+
#ifndef LLVM_ANALYSIS_IR2VEC_H
30+
#define LLVM_ANALYSIS_IR2VEC_H
31+
32+
#include "llvm/ADT/DenseMap.h"
33+
#include "llvm/IR/PassManager.h"
34+
#include "llvm/Support/ErrorOr.h"
35+
#include <map>
36+
37+
namespace llvm {
38+
39+
class Module;
40+
class BasicBlock;
41+
class Instruction;
42+
class Function;
43+
class Type;
44+
class Value;
45+
class raw_ostream;
46+
47+
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
48+
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
49+
/// of the IR entities. Flow-aware embeddings build on top of symbolic
50+
/// embeddings and additionally capture the flow information in the IR.
51+
/// IR2VecKind is used to specify the type of embeddings to generate.
52+
/// Currently, only Symbolic embeddings are supported.
53+
enum class IR2VecKind { Symbolic };
54+
55+
namespace ir2vec {
56+
using Embedding = std::vector<double>;
57+
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
58+
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
59+
// FIXME: Currently the keys are strings. This can be changed to
60+
// use integers for cheaper lookups.
61+
using Vocab = std::map<std::string, Embedding>;
62+
63+
/// Embedder provides the interface to generate embeddings (vector
64+
/// representations) for instructions, basic blocks, and functions. The vector
65+
/// representations are generated using IR2Vec algorithms.
66+
///
67+
/// The Embedder class is an abstract class and it is intended to be
68+
/// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware.
69+
class Embedder {
70+
protected:
71+
const Function &F;
72+
const Vocab &Vocabulary;
73+
74+
/// Weights for different entities (like opcode, arguments, types)
75+
/// in the IR instructions to generate the vector representation.
76+
// FIXME: Defaults to the values used in the original algorithm. Can be
77+
// parameterized later.
78+
const float OpcWeight = 1.0, TypeWeight = 0.5, ArgWeight = 0.2;
79+
80+
/// Dimension of the vector representation; captured from the input vocabulary
81+
const unsigned Dimension;
82+
83+
// Utility maps - these are used to store the vector representations of
84+
// instructions, basic blocks and functions.
85+
Embedding FuncVector;
86+
BBEmbeddingsMap BBVecMap;
87+
InstEmbeddingsMap InstVecMap;
88+
89+
Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
90+
91+
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
92+
/// zero vector.
93+
Embedding lookupVocab(const std::string &Key);
94+
95+
/// Adds two vectors: Dst += Src
96+
void addVectors(Embedding &Dst, const Embedding &Src);
97+
98+
/// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
99+
void addScaledVector(Embedding &Dst, const Embedding &Src, float Factor);
100+
101+
public:
102+
virtual ~Embedder() = default;
103+
104+
/// Top level function to compute embeddings. Given a function, it
105+
/// generates embeddings for all the instructions and basic blocks in that
106+
/// function. Logic of computing the embeddings is specific to the kind of
107+
/// embeddings being computed.
108+
virtual void computeEmbeddings() = 0;
109+
110+
/// Factory method to create an Embedder object.
111+
static ErrorOr<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
112+
const Function &F,
113+
const Vocab &Vocabulary,
114+
unsigned Dimension);
115+
116+
/// Returns a map containing instructions and the corresponding vector
117+
/// representations for the given function corresponding to the IR2Vec
118+
/// algorithm.
119+
const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
120+
121+
/// Returns a map containing basic blocks and the corresponding vector
122+
/// representations for the given function corresponding to the IR2Vec
123+
/// algorithm.
124+
const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
125+
126+
/// Returns the vector representation for a given function corresponding to
127+
/// the IR2Vec algorithm.
128+
const Embedding &getFunctionVector() const { return FuncVector; }
129+
};
130+
131+
/// Class for computing the Symbolic embeddings of IR2Vec
132+
class SymbolicEmbedder : public Embedder {
133+
private:
134+
/// Utility function to compute the vector representation for a given basic
135+
/// block.
136+
Embedding computeBB2Vec(const BasicBlock &BB);
137+
138+
/// Utility function to compute the vector representation for a given
139+
/// function.
140+
Embedding computeFunc2Vec();
141+
142+
/// Utility function to compute the vector representation for a given type.
143+
Embedding getTypeEmbedding(const Type *Ty);
144+
145+
/// Utility function to compute the vector representation for a given
146+
/// operand.
147+
Embedding getOperandEmbedding(const Value *Op);
148+
149+
public:
150+
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
151+
unsigned Dimension)
152+
: Embedder(F, Vocabulary, Dimension) {
153+
FuncVector = Embedding(Dimension, 0);
154+
}
155+
void computeEmbeddings() override;
156+
};
157+
158+
} // namespace ir2vec
159+
160+
class IR2VecVocabResult;
161+
162+
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
163+
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
164+
/// its corresponding embedding.
165+
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
166+
ir2vec::Vocab Vocabulary;
167+
Error readVocabulary();
168+
169+
public:
170+
static AnalysisKey Key;
171+
IR2VecVocabAnalysis() = default;
172+
using Result = IR2VecVocabResult;
173+
Result run(Module &M, ModuleAnalysisManager &MAM);
174+
};
175+
176+
class IR2VecVocabResult {
177+
ir2vec::Vocab Vocabulary;
178+
bool Valid = false;
179+
180+
public:
181+
IR2VecVocabResult() = default;
182+
IR2VecVocabResult(ir2vec::Vocab &&Vocabulary);
183+
184+
bool isValid() const { return Valid; }
185+
const ir2vec::Vocab &getVocabulary() const;
186+
unsigned getDimension() const;
187+
bool invalidate(Module &M, const PreservedAnalyses &PA,
188+
ModuleAnalysisManager::Invalidator &Inv);
189+
};
190+
191+
/// This pass prints the IR2Vec embeddings for instructions, basic blocks, and
192+
/// functions.
193+
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
194+
raw_ostream &OS;
195+
void printVector(const ir2vec::Embedding &Vec) const;
196+
197+
public:
198+
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
199+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
200+
static bool isRequired() { return true; }
201+
};
202+
203+
} // namespace llvm
204+
205+
#endif // LLVM_ANALYSIS_IR2VEC_H
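
The header declares ``lookupVocab``, ``addScaledVector``, and the three weights, but the diff shown here does not include their definitions. The standalone sketch below shows one way these pieces could fit together. It assumes that an instruction vector is the weighted sum of its opcode, type, and operand entries. ``embedInstruction`` and the string keys are hypothetical, and the free functions only imitate the documented behaviour of the members with the same names (zero vector on a missing key, ``Dst += Src * Factor``).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Local stand-ins mirroring the aliases declared in the header.
using Embedding = std::vector<double>;
using Vocab = std::map<std::string, Embedding>;

// Returns the entry for Key, or a zero vector of size Dimension if the key
// is absent (the behaviour documented for Embedder::lookupVocab).
Embedding lookupVocab(const Vocab &V, const std::string &Key,
                      unsigned Dimension) {
  auto It = V.find(Key);
  return It != V.end() ? It->second : Embedding(Dimension, 0.0);
}

// Dst += Src * Factor
void addScaledVector(Embedding &Dst, const Embedding &Src, float Factor) {
  assert(Dst.size() == Src.size() && "vectors must share a dimension");
  for (std::size_t I = 0; I < Dst.size(); ++I)
    Dst[I] += Src[I] * Factor;
}

// Hypothetical combination step (an assumption, not the code in this commit):
// weight the opcode, type, and each operand entry, then sum them.
Embedding embedInstruction(const Vocab &V, unsigned Dimension,
                           const std::string &Opcode, const std::string &Type,
                           const std::vector<std::string> &Operands) {
  // Default weights as declared in the Embedder class.
  const float OpcWeight = 1.0f, TypeWeight = 0.5f, ArgWeight = 0.2f;
  Embedding Result(Dimension, 0.0);
  addScaledVector(Result, lookupVocab(V, Opcode, Dimension), OpcWeight);
  addScaledVector(Result, lookupVocab(V, Type, Dimension), TypeWeight);
  for (const std::string &Op : Operands)
    addScaledVector(Result, lookupVocab(V, Op, Dimension), ArgWeight);
  return Result;
}
```

Because a missing key yields a zero vector, an unknown entity simply contributes nothing to the sum in this sketch instead of causing an error.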
