-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[NFC][IR2Vec] Initialize Embedding vectors with zeros by default #155690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) ChangesInitialize Full diff: https://github.com/llvm/llvm-project/pull/155690.diff 2 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 44932a3385e16..6fb8f736da092 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -92,7 +92,7 @@ struct Embedding {
Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
Embedding(std::initializer_list<double> IL) : Data(IL) {}
- explicit Embedding(size_t Size) : Data(Size) {}
+ explicit Embedding(size_t Size) : Data(Size, 0.0) {}
Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}
size_t size() const { return Data.size(); }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 565ec2a6287b7..6b90f1aabacfa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -155,7 +155,7 @@ void Embedding::print(raw_ostream &OS) const {
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
- FuncVector(Embedding(Dimension, 0)) {}
+ FuncVector(Embedding(Dimension)) {}
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
@@ -472,7 +472,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
@@ -487,7 +487,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Types - only canonical types are present in vocabulary
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
@@ -503,7 +503,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
|
|
@llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) ChangesInitialize Full diff: https://github.com/llvm/llvm-project/pull/155690.diff 2 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 44932a3385e16..6fb8f736da092 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -92,7 +92,7 @@ struct Embedding {
Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
Embedding(std::initializer_list<double> IL) : Data(IL) {}
- explicit Embedding(size_t Size) : Data(Size) {}
+ explicit Embedding(size_t Size) : Data(Size, 0.0) {}
Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}
size_t size() const { return Data.size(); }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 565ec2a6287b7..6b90f1aabacfa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -155,7 +155,7 @@ void Embedding::print(raw_ostream &OS) const {
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
- FuncVector(Embedding(Dimension, 0)) {}
+ FuncVector(Embedding(Dimension)) {}
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
@@ -472,7 +472,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
@@ -487,7 +487,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Types - only canonical types are present in vocabulary
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
@@ -503,7 +503,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
|
9a18f1c to
7ddfeaa
Compare
51b1cd4 to
7ec3927
Compare
7ddfeaa to
c809d9d
Compare
f01119a to
f82d77f
Compare
c809d9d to
18675c6
Compare
18675c6 to
374bfa9
Compare
f0b3c0c to
a9edd27
Compare
97560b9 to
a20fb0e
Compare
a9edd27 to
493f471
Compare
a20fb0e to
b21b641
Compare
ec2e1e1 to
da83ad8
Compare
b21b641 to
5c658e1
Compare
5c658e1 to
0d74ab7
Compare
da83ad8 to
8c8500c
Compare
8c8500c to
fd4e1df
Compare
Merge activity
|

Initialize
Embeddingvectors with zeros by default when only size is provided.