Skip to content

[IR2Vec] Add support for flow-aware embeddings #152613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,19 @@ class IR2VecVocabAnalysis;
/// of the IR entities. Flow-aware embeddings build on top of symbolic
/// embeddings and additionally capture the flow information in the IR.
/// IR2VecKind is used to specify the type of embeddings to generate.
/// Currently, only Symbolic embeddings are supported.
enum class IR2VecKind { Symbolic };
/// Note: Implementation of FlowAware embeddings is not same as the one
/// described in the paper. The current implementation is a simplified version
/// that captures the flow information (SSA-based use-defs) without tracing
/// through memory level use-defs in the embedding computation described in the
/// paper.
enum class IR2VecKind { Symbolic, FlowAware };

namespace ir2vec {

LLVM_ABI extern cl::opt<float> OpcWeight;
LLVM_ABI extern cl::opt<float> TypeWeight;
LLVM_ABI extern cl::opt<float> ArgWeight;
LLVM_ABI extern cl::opt<IR2VecKind> IR2VecEmbeddingKind;

/// Embedding is a datatype that wraps std::vector<double>. It provides
/// additional functionality for arithmetic and comparison operations.
Expand Down Expand Up @@ -257,9 +262,8 @@ class Embedder {
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);

/// Helper function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F. Logic of computing
/// the embeddings is specific to the kind of embeddings being computed.
virtual void computeEmbeddings() const = 0;
/// the instructions and basic blocks in the function F.
void computeEmbeddings() const;

/// Helper function to compute the embedding for a given basic block.
/// Specific to the kind of embeddings being computed.
Expand Down Expand Up @@ -296,7 +300,6 @@ class Embedder {
/// representations obtained from the Vocabulary.
class LLVM_ABI SymbolicEmbedder : public Embedder {
private:
void computeEmbeddings() const override;
void computeEmbeddings(const BasicBlock &BB) const override;

public:
Expand All @@ -306,6 +309,20 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
}
};

/// Class for computing the Flow-aware embeddings of IR2Vec.
/// Flow-aware embeddings build on the vocabulary, just like Symbolic
/// embeddings, and additionally capture the flow information in the IR.
class LLVM_ABI FlowAwareEmbedder : public Embedder {
private:
void computeEmbeddings(const BasicBlock &BB) const override;

public:
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
: Embedder(F, Vocab) {
FuncVector = Embedding(Dimension, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initialize in the initializer list? FuncVector(Dimension, 0)

same in the Symbolic, now that I see it

another option would be to let the base class do it, seems both cases (flow and symbolic) would do the same thing?

}
};

} // namespace ir2vec

/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
Expand Down
72 changes: 62 additions & 10 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
cl::desc("Weight for argument embeddings"),
cl::cat(IR2VecCategory));
cl::opt<IR2VecKind> IR2VecEmbeddingKind(
"ir2vec-kind", cl::Optional,
cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
"Generate symbolic embeddings"),
clEnumValN(IR2VecKind::FlowAware, "flow-aware",
"Generate flow-aware embeddings")),
cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
cl::cat(IR2VecCategory));

} // namespace ir2vec
} // namespace llvm

Expand Down Expand Up @@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
double Tolerance) const {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
<< (*this)[Itr] << " vs " << RHS[Itr]
<< "; Tolerance: " << Tolerance << "\n");
return false;
}
return true;
}

Expand All @@ -149,6 +162,8 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocab);
case IR2VecKind::FlowAware:
return std::make_unique<FlowAwareEmbedder>(F, Vocab);
}
return nullptr;
}
Expand Down Expand Up @@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const {
return FuncVector;
}

void Embedder::computeEmbeddings() const {
if (F.isDeclaration())
return;

// Consider only the basic blocks that are reachable from entry
for (const BasicBlock *BB : depth_first(&F)) {
computeEmbeddings(*BB);
FuncVector += BBVecMap[BB];
}
}

void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);

Expand All @@ -196,15 +222,38 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
BBVecMap[&BB] = BBVector;
}

void SymbolicEmbedder::computeEmbeddings() const {
if (F.isDeclaration())
return;
void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);

// Consider only the basic blocks that are reachable from entry
for (const BasicBlock *BB : depth_first(&F)) {
computeEmbeddings(*BB);
FuncVector += BBVecMap[BB];
// We consider only the non-debug and non-pseudo instructions
for (const auto &I : BB.instructionsWithoutDebug()) {
// TODO: Handle call instructions differently.
// For now, we treat them like other instructions
Embedding ArgEmb(Dimension, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should Embedding have the ctor have the 0 initial value as a default value for the ctor argument? (can be some follow-up nfc)

for (const auto &Op : I.operands()) {
// If the operand is defined elsewhere, we use its embedding
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
auto DefIt = InstVecMap.find(DefInst);
assert(DefIt != InstVecMap.end() &&
"Instruction should have been processed before its operands");
ArgEmb += DefIt->second;
continue;
}
// If the operand is not defined by an instruction, we use the vocabulary
else {
LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
<< *Op << "=" << Vocab[Op][0] << "\n");
ArgEmb += Vocab[Op];
}
}
// Create the instruction vector by combining opcode, type, and arguments
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
BBVecMap[&BB] = BBVector;
}

// ==----------------------------------------------------------------------===//
Expand Down Expand Up @@ -552,8 +601,11 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");

for (Function &F : M) {
std::unique_ptr<Embedder> Emb =
Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
std::unique_ptr<Embedder> Emb;
if (IR2VecEmbeddingKind == IR2VecKind::Symbolic)
Emb = Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
else
Emb = Embedder::create(IR2VecKind::FlowAware, F, Vocabulary);
if (!Emb) {
OS << "Error creating IR2Vec embeddings \n";
continue;
Expand Down
72 changes: 72 additions & 0 deletions llvm/test/Analysis/IR2Vec/basic-flowaware.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG

define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
entry:
%a.addr = alloca i32, align 4
%b.addr = alloca float, align 4
store i32 %a, ptr %a.addr, align 4
store float %b, ptr %b.addr, align 4
%0 = load i32, ptr %a.addr, align 4
%1 = load i32, ptr %a.addr, align 4
%mul = mul nsw i32 %0, %1
%conv = sitofp i32 %mul to float
%2 = load float, ptr %b.addr, align 4
%add = fadd float %conv, %2
ret float %add
}

; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif:
; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ]
; 3D-CHECK-OPC-NEXT: Basic block vectors:
; 3D-CHECK-OPC-NEXT: Basic block: entry:
; 3D-CHECK-OPC-NEXT: [ 3630.00 3672.00 3714.00 ]
; 3D-CHECK-OPC-NEXT: Instruction vectors:
; 3D-CHECK-OPC-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 419.00 424.00 429.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 549.00 555.00 561.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 185.00 187.00 189.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: %add = fadd float %conv, %2 [ 774.00 783.00 792.00 ]
; 3D-CHECK-OPC-NEXT: Instruction: ret float %add [ 775.00 785.00 795.00 ]

; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif:
; 3D-CHECK-TYPE-NEXT: Function vector: [ 355.50 376.50 397.50 ]
; 3D-CHECK-TYPE-NEXT: Basic block vectors:
; 3D-CHECK-TYPE-NEXT: Basic block: entry:
; 3D-CHECK-TYPE-NEXT: [ 355.50 376.50 397.50 ]
; 3D-CHECK-TYPE-NEXT: Instruction vectors:
; 3D-CHECK-TYPE-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 12.50 13.00 13.50 ]
; 3D-CHECK-TYPE-NEXT: Instruction: %b.addr = alloca float, align 4 [ 12.50 13.00 13.50 ]
; 3D-CHECK-TYPE-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 14.50 15.50 16.50 ]
; 3D-CHECK-TYPE-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 14.50 15.50 16.50 ]
; 3D-CHECK-TYPE-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
; 3D-CHECK-TYPE-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
; 3D-CHECK-TYPE-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 53.50 56.00 58.50 ]
; 3D-CHECK-TYPE-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 54.00 57.00 60.00 ]
; 3D-CHECK-TYPE-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 13.00 14.00 15.00 ]
; 3D-CHECK-TYPE-NEXT: Instruction: %add = fadd float %conv, %2 [ 67.50 72.00 76.50 ]
; 3D-CHECK-TYPE-NEXT: Instruction: ret float %add [ 69.50 74.50 79.50 ]

; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif:
; 3D-CHECK-ARG-NEXT: Function vector: [ 27.80 31.60 35.40 ]
; 3D-CHECK-ARG-NEXT: Basic block vectors:
; 3D-CHECK-ARG-NEXT: Basic block: entry:
; 3D-CHECK-ARG-NEXT: [ 27.80 31.60 35.40 ]
; 3D-CHECK-ARG-NEXT: Instruction vectors:
; 3D-CHECK-ARG-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 1.40 1.60 1.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: %b.addr = alloca float, align 4 [ 1.40 1.60 1.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 3.40 3.80 4.20 ]
; 3D-CHECK-ARG-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 3.40 3.80 4.20 ]
; 3D-CHECK-ARG-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 2.80 3.20 3.60 ]
; 3D-CHECK-ARG-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 2.80 3.20 3.60 ]
; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 1.40 1.60 1.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.20 4.80 5.40 ]
; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 4.20 4.80 5.40 ]
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK


define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
entry:
%a.addr = alloca i32, align 4
Expand Down Expand Up @@ -74,11 +70,3 @@ entry:
; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 0.80 1.00 1.20 ]
; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.00 4.40 4.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 2.00 2.20 2.40 ]

; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file

; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file

; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file

; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
27 changes: 27 additions & 0 deletions llvm/test/Analysis/IR2Vec/basic-vocab.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK

define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
entry:
%a.addr = alloca i32, align 4
%b.addr = alloca float, align 4
store i32 %a, ptr %a.addr, align 4
store float %b, ptr %b.addr, align 4
%0 = load i32, ptr %a.addr, align 4
%1 = load i32, ptr %a.addr, align 4
%mul = mul nsw i32 %0, %1
%conv = sitofp i32 %mul to float
%2 = load float, ptr %b.addr, align 4
%add = fadd float %conv, %2
ret float %add
}

; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file

; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file

; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file

; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
Loading