-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[IR2Vec] Add support for flow-aware embeddings #152613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlgo @llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) ChangesThis patch introduces support for Flow-Aware embeddings in IR2Vec, which capture data flow information in addition to symbolic representations. Patch is 20.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152613.diff 6 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 17f41129fd4fa..3cfc206c94788 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -54,14 +54,19 @@ class IR2VecVocabAnalysis;
/// of the IR entities. Flow-aware embeddings build on top of symbolic
/// embeddings and additionally capture the flow information in the IR.
/// IR2VecKind is used to specify the type of embeddings to generate.
-/// Currently, only Symbolic embeddings are supported.
-enum class IR2VecKind { Symbolic };
+/// Note: Implementation of FlowAware embeddings is not same as the one
+/// described in the paper. The current implementation is a simplified version
+/// that captures the flow information (SSA-based use-defs) without tracing
+/// through memory level use-defs in the embedding computation described in the
+/// paper.
+enum class IR2VecKind { Symbolic, FlowAware };
namespace ir2vec {
LLVM_ABI extern cl::opt<float> OpcWeight;
LLVM_ABI extern cl::opt<float> TypeWeight;
LLVM_ABI extern cl::opt<float> ArgWeight;
+LLVM_ABI extern cl::opt<IR2VecKind> IR2VecEmbeddingKind;
/// Embedding is a datatype that wraps std::vector<double>. It provides
/// additional functionality for arithmetic and comparison operations.
@@ -257,9 +262,8 @@ class Embedder {
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
/// Helper function to compute embeddings. It generates embeddings for all
- /// the instructions and basic blocks in the function F. Logic of computing
- /// the embeddings is specific to the kind of embeddings being computed.
- virtual void computeEmbeddings() const = 0;
+ /// the instructions and basic blocks in the function F.
+ void computeEmbeddings() const;
/// Helper function to compute the embedding for a given basic block.
/// Specific to the kind of embeddings being computed.
@@ -296,7 +300,6 @@ class Embedder {
/// representations obtained from the Vocabulary.
class LLVM_ABI SymbolicEmbedder : public Embedder {
private:
- void computeEmbeddings() const override;
void computeEmbeddings(const BasicBlock &BB) const override;
public:
@@ -306,6 +309,20 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
}
};
+/// Class for computing the Flow-aware embeddings of IR2Vec.
+/// Flow-aware embeddings build on the vocabulary, just like Symbolic
+/// embeddings, and additionally capture the flow information in the IR.
+class LLVM_ABI FlowAwareEmbedder : public Embedder {
+private:
+ void computeEmbeddings(const BasicBlock &BB) const override;
+
+public:
+ FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
+ : Embedder(F, Vocab) {
+ FuncVector = Embedding(Dimension, 0);
+ }
+};
+
} // namespace ir2vec
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 95f30fd3f4275..0bea25ec26b8e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
cl::desc("Weight for argument embeddings"),
cl::cat(IR2VecCategory));
+cl::opt<IR2VecKind> IR2VecEmbeddingKind(
+ "ir2vec-kind", cl::Optional,
+ cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
+ "Generate symbolic embeddings"),
+ clEnumValN(IR2VecKind::FlowAware, "flow-aware",
+ "Generate flow-aware embeddings")),
+ cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
+ cl::cat(IR2VecCategory));
+
} // namespace ir2vec
} // namespace llvm
@@ -149,6 +158,8 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocab);
+ case IR2VecKind::FlowAware:
+ return std::make_unique<FlowAwareEmbedder>(F, Vocab);
}
return nullptr;
}
@@ -180,6 +191,17 @@ const Embedding &Embedder::getFunctionVector() const {
return FuncVector;
}
+void Embedder::computeEmbeddings() const {
+ if (F.isDeclaration())
+ return;
+
+ // Consider only the basic blocks that are reachable from entry
+ for (const BasicBlock *BB : depth_first(&F)) {
+ computeEmbeddings(*BB);
+ FuncVector += BBVecMap[BB];
+ }
+}
+
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
@@ -196,15 +218,46 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
BBVecMap[&BB] = BBVector;
}
-void SymbolicEmbedder::computeEmbeddings() const {
- if (F.isDeclaration())
- return;
+void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+ Embedding BBVector(Dimension, 0);
- // Consider only the basic blocks that are reachable from entry
- for (const BasicBlock *BB : depth_first(&F)) {
- computeEmbeddings(*BB);
- FuncVector += BBVecMap[BB];
+ // We consider only the non-debug and non-pseudo instructions
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ // TODO: Handle call instructions differently.
+ // For now, we treat them like other instructions
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands()) {
+ // If the operand is defined elsewhere, we use its embedding
+ if (const Instruction *DefInst = dyn_cast<Instruction>(Op)) {
+ auto DefIt = InstVecMap.find(DefInst);
+ assert(DefIt != InstVecMap.end() &&
+ "Instruction should have been processed before its operands");
+ if (DefIt != InstVecMap.end()) {
+ ArgEmb += DefIt->second;
+ continue;
+ }
+ // If the definition is not in the map, we use the vocabulary
+ // Not expected, but handle it gracefully
+ LLVM_DEBUG(dbgs() << "Warning: Operand defined by instruction not "
+ "found in InstVecMap: "
+ << *DefInst << "\n");
+ ArgEmb += Vocab[Op];
+ }
+ // If the operand is not defined by an instruction, we use the vocabulary
+ else {
+ LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+ << *Op << "=" << Vocab[Op][0] << "\n");
+ ArgEmb += Vocab[Op];
+ }
+ }
+ // Create the instruction vector by combining opcode, type, and arguments
+ // embeddings
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ InstVecMap[&I] = InstVector;
+ BBVector += InstVector;
}
+ BBVecMap[&BB] = BBVector;
}
// ==----------------------------------------------------------------------===//
@@ -552,8 +605,11 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
for (Function &F : M) {
- std::unique_ptr<Embedder> Emb =
- Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ std::unique_ptr<Embedder> Emb;
+ if (IR2VecEmbeddingKind == IR2VecKind::Symbolic)
+ Emb = Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ else
+ Emb = Embedder::create(IR2VecKind::FlowAware, F, Vocabulary);
if (!Emb) {
OS << "Error creating IR2Vec embeddings \n";
continue;
diff --git a/llvm/test/Analysis/IR2Vec/basic-flowaware.ll b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
new file mode 100644
index 0000000000000..4a7f970a9cf91
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
@@ -0,0 +1,72 @@
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Basic block vectors:
+; 3D-CHECK-OPC-NEXT: Basic block: entry:
+; 3D-CHECK-OPC-NEXT: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction vectors:
+; 3D-CHECK-OPC-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 419.00 424.00 429.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 549.00 555.00 561.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %add = fadd float %conv, %2 [ 774.00 783.00 792.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: ret float %add [ 775.00 785.00 795.00 ]
+
+; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-TYPE-NEXT: Function vector: [ 355.50 376.50 397.50 ]
+; 3D-CHECK-TYPE-NEXT: Basic block vectors:
+; 3D-CHECK-TYPE-NEXT: Basic block: entry:
+; 3D-CHECK-TYPE-NEXT: [ 355.50 376.50 397.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction vectors:
+; 3D-CHECK-TYPE-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 12.50 13.00 13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %b.addr = alloca float, align 4 [ 12.50 13.00 13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 14.50 15.50 16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 14.50 15.50 16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 53.50 56.00 58.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 54.00 57.00 60.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 13.00 14.00 15.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %add = fadd float %conv, %2 [ 67.50 72.00 76.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: ret float %add [ 69.50 74.50 79.50 ]
+
+; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-ARG-NEXT: Function vector: [ 27.80 31.60 35.40 ]
+; 3D-CHECK-ARG-NEXT: Basic block vectors:
+; 3D-CHECK-ARG-NEXT: Basic block: entry:
+; 3D-CHECK-ARG-NEXT: [ 27.80 31.60 35.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction vectors:
+; 3D-CHECK-ARG-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %b.addr = alloca float, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 3.40 3.80 4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 3.40 3.80 4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 2.80 3.20 3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 2.80 3.20 3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.20 4.80 5.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 4.20 4.80 5.40 ]
diff --git a/llvm/test/Analysis/IR2Vec/basic.ll b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
similarity index 81%
rename from llvm/test/Analysis/IR2Vec/basic.ll
rename to llvm/test/Analysis/IR2Vec/basic-symbolic.ll
index cb0544fb19860..35abd3c7fa269 100644
--- a/llvm/test/Analysis/IR2Vec/basic.ll
+++ b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
@@ -1,11 +1,7 @@
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
-
+
define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
entry:
%a.addr = alloca i32, align 4
@@ -74,11 +70,3 @@ entry:
; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 0.80 1.00 1.20 ]
; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.00 4.40 4.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 2.00 2.20 2.40 ]
-
-; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
-
-; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
-
-; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
-
-; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/test/Analysis/IR2Vec/basic-vocab.ll b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
new file mode 100644
index 0000000000000..eeeee831814a8
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
@@ -0,0 +1,27 @@
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
+
+; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
+
+; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
+
+; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index e288585033c53..f6846963b3e2f 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,6 @@ namespace {
class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
- void computeEmbeddings() const override {}
void computeEmbeddings(const BasicBlock &BB) const override {}
};
@@ -258,6 +257,18 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
EXPECT_NE(Emb, nullptr);
}
+TEST(IR2VecTest, CreateFlowAwareEmbedder) {
+ Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
+
+ LLVMContext Ctx;
+ Module M("M", Ctx);
+ FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ EXPECT_NE(Emb, nullptr);
+}
+
TEST(IR2VecTest, CreateInvalidMode) {
Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
@@ -307,10 +318,12 @@ class IR2VecTestFixture : public ::testing::Test {
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
+ F->print(llvm::errs());
+ F->dump();
}
};
-TEST_F(IR2VecTestFixture, GetInstVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -327,7 +340,24 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 16.8)));
}
-TEST_F(IR2VecTestFixture, GetBBVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &InstMap = Emb->getInstVecMap();
+
+ EXPECT_EQ(InstMap.size(), 2u);
+ EXPECT_TRUE(InstMap.count(AddInst));
+ EXPECT_TRUE(InstMap.count(RetInst));
+
+ EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+ EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
+
+ EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.6)));
+ EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.2)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -342,7 +372,22 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.4)));
}
-TEST_F(IR2VecTestFixture, GetBBVector) {
+TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &BBMap = Emb->getBBVecMap();
+
+ EXPECT_EQ(BBMap.size(), 1u);
+ EXPECT_TRUE(BBMap.count(BB));
+ EXPECT_EQ(BBMap.at(BB).size(), 2u);
+
+ // BB vector should be sum of add and ret: {27.6, 27.6} + {35.2, 35.2} =
+ // {62.8, 62.8}
+ EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 62.8)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -352,7 +397,17 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.4)));
}
-TEST_F(IR2VecTestFixture, GetFunctionVector) {
+TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &BBVec ...
[truncated]
|
/// the embeddings is specific to the kind of embeddings being computed. | ||
virtual void computeEmbeddings() const = 0; | ||
/// the instructions and basic blocks in the function F. | ||
void computeEmbeddings() const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did it lose its virtual
?
Embedding ArgEmb(Dimension, 0); | ||
for (const auto &Op : I.operands()) { | ||
// If the operand is defined elsewhere, we use its embedding | ||
if (const Instruction *DefInst = dyn_cast<Instruction>(Op)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you can const auto *DefInst
, it's clear what the type is from the rhs
if (const Instruction *DefInst = dyn_cast<Instruction>(Op)) { | ||
auto DefIt = InstVecMap.find(DefInst); | ||
assert(DefIt != InstVecMap.end() && | ||
"Instruction should have been processed before its operands"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the assert and the if...?
This patch introduces support for Flow-Aware embeddings in IR2Vec, which capture data flow information in addition to symbolic representations.