Skip to content

Commit 879ed94

Browse files
committed
Flow-Aware Embeddings
1 parent 47f54e4 commit 879ed94

File tree

6 files changed

+259
-33
lines changed

6 files changed

+259
-33
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,19 @@ class IR2VecVocabAnalysis;
5454
/// of the IR entities. Flow-aware embeddings build on top of symbolic
5555
/// embeddings and additionally capture the flow information in the IR.
5656
/// IR2VecKind is used to specify the type of embeddings to generate.
57-
/// Currently, only Symbolic embeddings are supported.
58-
enum class IR2VecKind { Symbolic };
57+
/// Note: Implementation of FlowAware embeddings is not same as the one
58+
/// described in the paper. The current implementation is a simplified version
59+
/// that captures the flow information (SSA-based use-defs) without tracing
60+
/// through memory level use-defs in the embedding computation described in the
61+
/// paper.
62+
enum class IR2VecKind { Symbolic, FlowAware };
5963

6064
namespace ir2vec {
6165

6266
LLVM_ABI extern cl::opt<float> OpcWeight;
6367
LLVM_ABI extern cl::opt<float> TypeWeight;
6468
LLVM_ABI extern cl::opt<float> ArgWeight;
69+
LLVM_ABI extern cl::opt<IR2VecKind> IR2VecEmbeddingKind;
6570

6671
/// Embedding is a datatype that wraps std::vector<double>. It provides
6772
/// additional functionality for arithmetic and comparison operations.
@@ -257,9 +262,8 @@ class Embedder {
257262
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
258263

259264
/// Helper function to compute embeddings. It generates embeddings for all
260-
/// the instructions and basic blocks in the function F. Logic of computing
261-
/// the embeddings is specific to the kind of embeddings being computed.
262-
virtual void computeEmbeddings() const = 0;
265+
/// the instructions and basic blocks in the function F.
266+
void computeEmbeddings() const;
263267

264268
/// Helper function to compute the embedding for a given basic block.
265269
/// Specific to the kind of embeddings being computed.
@@ -296,7 +300,6 @@ class Embedder {
296300
/// representations obtained from the Vocabulary.
297301
class LLVM_ABI SymbolicEmbedder : public Embedder {
298302
private:
299-
void computeEmbeddings() const override;
300303
void computeEmbeddings(const BasicBlock &BB) const override;
301304

302305
public:
@@ -306,6 +309,20 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
306309
}
307310
};
308311

312+
/// Class for computing the Flow-aware embeddings of IR2Vec.
313+
/// Flow-aware embeddings build on the vocabulary, just like Symbolic
314+
/// embeddings, and additionally capture the flow information in the IR.
315+
class LLVM_ABI FlowAwareEmbedder : public Embedder {
316+
private:
317+
void computeEmbeddings(const BasicBlock &BB) const override;
318+
319+
public:
320+
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
321+
: Embedder(F, Vocab) {
322+
FuncVector = Embedding(Dimension, 0);
323+
}
324+
};
325+
309326
} // namespace ir2vec
310327

311328
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
5252
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
5353
cl::desc("Weight for argument embeddings"),
5454
cl::cat(IR2VecCategory));
55+
cl::opt<IR2VecKind> IR2VecEmbeddingKind(
56+
"ir2vec-kind", cl::Optional,
57+
cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
58+
"Generate symbolic embeddings"),
59+
clEnumValN(IR2VecKind::FlowAware, "flow-aware",
60+
"Generate flow-aware embeddings")),
61+
cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
62+
cl::cat(IR2VecCategory));
63+
5564
} // namespace ir2vec
5665
} // namespace llvm
5766

@@ -149,6 +158,8 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
149158
switch (Mode) {
150159
case IR2VecKind::Symbolic:
151160
return std::make_unique<SymbolicEmbedder>(F, Vocab);
161+
case IR2VecKind::FlowAware:
162+
return std::make_unique<FlowAwareEmbedder>(F, Vocab);
152163
}
153164
return nullptr;
154165
}
@@ -180,6 +191,17 @@ const Embedding &Embedder::getFunctionVector() const {
180191
return FuncVector;
181192
}
182193

194+
void Embedder::computeEmbeddings() const {
195+
if (F.isDeclaration())
196+
return;
197+
198+
// Consider only the basic blocks that are reachable from entry
199+
for (const BasicBlock *BB : depth_first(&F)) {
200+
computeEmbeddings(*BB);
201+
FuncVector += BBVecMap[BB];
202+
}
203+
}
204+
183205
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
184206
Embedding BBVector(Dimension, 0);
185207

@@ -196,15 +218,46 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196218
BBVecMap[&BB] = BBVector;
197219
}
198220

199-
void SymbolicEmbedder::computeEmbeddings() const {
200-
if (F.isDeclaration())
201-
return;
221+
void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
222+
Embedding BBVector(Dimension, 0);
202223

203-
// Consider only the basic blocks that are reachable from entry
204-
for (const BasicBlock *BB : depth_first(&F)) {
205-
computeEmbeddings(*BB);
206-
FuncVector += BBVecMap[BB];
224+
// We consider only the non-debug and non-pseudo instructions
225+
for (const auto &I : BB.instructionsWithoutDebug()) {
226+
// TODO: Handle call instructions differently.
227+
// For now, we treat them like other instructions
228+
Embedding ArgEmb(Dimension, 0);
229+
for (const auto &Op : I.operands()) {
230+
// If the operand is defined elsewhere, we use its embedding
231+
if (const Instruction *DefInst = dyn_cast<Instruction>(Op)) {
232+
auto DefIt = InstVecMap.find(DefInst);
233+
assert(DefIt != InstVecMap.end() &&
234+
"Instruction should have been processed before its operands");
235+
if (DefIt != InstVecMap.end()) {
236+
ArgEmb += DefIt->second;
237+
continue;
238+
}
239+
// If the definition is not in the map, we use the vocabulary
240+
// Not expected, but handle it gracefully
241+
LLVM_DEBUG(dbgs() << "Warning: Operand defined by instruction not "
242+
"found in InstVecMap: "
243+
<< *DefInst << "\n");
244+
ArgEmb += Vocab[Op];
245+
}
246+
// If the operand is not defined by an instruction, we use the vocabulary
247+
else {
248+
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
249+
<< *Op << "=" << Vocab[Op][0] << "\n");
250+
ArgEmb += Vocab[Op];
251+
}
252+
}
253+
// Create the instruction vector by combining opcode, type, and arguments
254+
// embeddings
255+
auto InstVector =
256+
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
257+
InstVecMap[&I] = InstVector;
258+
BBVector += InstVector;
207259
}
260+
BBVecMap[&BB] = BBVector;
208261
}
209262

210263
// ==----------------------------------------------------------------------===//
@@ -552,8 +605,11 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
552605
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
553606

554607
for (Function &F : M) {
555-
std::unique_ptr<Embedder> Emb =
556-
Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
608+
std::unique_ptr<Embedder> Emb;
609+
if (IR2VecEmbeddingKind == IR2VecKind::Symbolic)
610+
Emb = Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
611+
else
612+
Emb = Embedder::create(IR2VecKind::FlowAware, F, Vocabulary);
557613
if (!Emb) {
558614
OS << "Error creating IR2Vec embeddings \n";
559615
continue;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
2+
; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
3+
; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
4+
5+
define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
6+
entry:
7+
%a.addr = alloca i32, align 4
8+
%b.addr = alloca float, align 4
9+
store i32 %a, ptr %a.addr, align 4
10+
store float %b, ptr %b.addr, align 4
11+
%0 = load i32, ptr %a.addr, align 4
12+
%1 = load i32, ptr %a.addr, align 4
13+
%mul = mul nsw i32 %0, %1
14+
%conv = sitofp i32 %mul to float
15+
%2 = load float, ptr %b.addr, align 4
16+
%add = fadd float %conv, %2
17+
ret float %add
18+
}
19+
20+
; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif:
21+
; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ]
22+
; 3D-CHECK-OPC-NEXT: Basic block vectors:
23+
; 3D-CHECK-OPC-NEXT: Basic block: entry:
24+
; 3D-CHECK-OPC-NEXT: [ 3630.00 3672.00 3714.00 ]
25+
; 3D-CHECK-OPC-NEXT: Instruction vectors:
26+
; 3D-CHECK-OPC-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
27+
; 3D-CHECK-OPC-NEXT: Instruction: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
28+
; 3D-CHECK-OPC-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ]
29+
; 3D-CHECK-OPC-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ]
30+
; 3D-CHECK-OPC-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
31+
; 3D-CHECK-OPC-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
32+
; 3D-CHECK-OPC-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 419.00 424.00 429.00 ]
33+
; 3D-CHECK-OPC-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 549.00 555.00 561.00 ]
34+
; 3D-CHECK-OPC-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 185.00 187.00 189.00 ]
35+
; 3D-CHECK-OPC-NEXT: Instruction: %add = fadd float %conv, %2 [ 774.00 783.00 792.00 ]
36+
; 3D-CHECK-OPC-NEXT: Instruction: ret float %add [ 775.00 785.00 795.00 ]
37+
38+
; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif:
39+
; 3D-CHECK-TYPE-NEXT: Function vector: [ 355.50 376.50 397.50 ]
40+
; 3D-CHECK-TYPE-NEXT: Basic block vectors:
41+
; 3D-CHECK-TYPE-NEXT: Basic block: entry:
42+
; 3D-CHECK-TYPE-NEXT: [ 355.50 376.50 397.50 ]
43+
; 3D-CHECK-TYPE-NEXT: Instruction vectors:
44+
; 3D-CHECK-TYPE-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 12.50 13.00 13.50 ]
45+
; 3D-CHECK-TYPE-NEXT: Instruction: %b.addr = alloca float, align 4 [ 12.50 13.00 13.50 ]
46+
; 3D-CHECK-TYPE-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 14.50 15.50 16.50 ]
47+
; 3D-CHECK-TYPE-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 14.50 15.50 16.50 ]
48+
; 3D-CHECK-TYPE-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
49+
; 3D-CHECK-TYPE-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
50+
; 3D-CHECK-TYPE-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 53.50 56.00 58.50 ]
51+
; 3D-CHECK-TYPE-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 54.00 57.00 60.00 ]
52+
; 3D-CHECK-TYPE-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 13.00 14.00 15.00 ]
53+
; 3D-CHECK-TYPE-NEXT: Instruction: %add = fadd float %conv, %2 [ 67.50 72.00 76.50 ]
54+
; 3D-CHECK-TYPE-NEXT: Instruction: ret float %add [ 69.50 74.50 79.50 ]
55+
56+
; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif:
57+
; 3D-CHECK-ARG-NEXT: Function vector: [ 27.80 31.60 35.40 ]
58+
; 3D-CHECK-ARG-NEXT: Basic block vectors:
59+
; 3D-CHECK-ARG-NEXT: Basic block: entry:
60+
; 3D-CHECK-ARG-NEXT: [ 27.80 31.60 35.40 ]
61+
; 3D-CHECK-ARG-NEXT: Instruction vectors:
62+
; 3D-CHECK-ARG-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 1.40 1.60 1.80 ]
63+
; 3D-CHECK-ARG-NEXT: Instruction: %b.addr = alloca float, align 4 [ 1.40 1.60 1.80 ]
64+
; 3D-CHECK-ARG-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 3.40 3.80 4.20 ]
65+
; 3D-CHECK-ARG-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 3.40 3.80 4.20 ]
66+
; 3D-CHECK-ARG-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
67+
; 3D-CHECK-ARG-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
68+
; 3D-CHECK-ARG-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 2.80 3.20 3.60 ]
69+
; 3D-CHECK-ARG-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 2.80 3.20 3.60 ]
70+
; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 1.40 1.60 1.80 ]
71+
; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.20 4.80 5.40 ]
72+
; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 4.20 4.80 5.40 ]

llvm/test/Analysis/IR2Vec/basic.ll renamed to llvm/test/Analysis/IR2Vec/basic-symbolic.ll

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
22
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
33
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
4-
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
5-
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
6-
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
7-
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
8-
4+
95
define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
106
entry:
117
%a.addr = alloca i32, align 4
@@ -74,11 +70,3 @@ entry:
7470
; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 0.80 1.00 1.20 ]
7571
; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.00 4.40 4.80 ]
7672
; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 2.00 2.20 2.40 ]
77-
78-
; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
79-
80-
; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
81-
82-
; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
83-
84-
; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
2+
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
3+
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
4+
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
5+
6+
define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
7+
entry:
8+
%a.addr = alloca i32, align 4
9+
%b.addr = alloca float, align 4
10+
store i32 %a, ptr %a.addr, align 4
11+
store float %b, ptr %b.addr, align 4
12+
%0 = load i32, ptr %a.addr, align 4
13+
%1 = load i32, ptr %a.addr, align 4
14+
%mul = mul nsw i32 %0, %1
15+
%conv = sitofp i32 %mul to float
16+
%2 = load float, ptr %b.addr, align 4
17+
%add = fadd float %conv, %2
18+
ret float %add
19+
}
20+
21+
; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
22+
23+
; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
24+
25+
; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
26+
27+
; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions

0 commit comments

Comments
 (0)