Skip to content

Commit 6a8b450

Browse files
committed
Exposing weights of Opc, Types, Args
1 parent be7f312 commit 6a8b450

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ class Embedder {
7373

7474
/// Weights for different entities (like opcode, arguments, types)
7575
/// in the IR instructions to generate the vector representation.
76-
// FIXME: Defaults to the values used in the original algorithm. Can be
77-
// parameterized later.
78-
const float OpcWeight = 1.0, TypeWeight = 0.5, ArgWeight = 0.2;
76+
const float OpcWeight, TypeWeight, ArgWeight;
7977

8078
/// Dimension of the vector representation; captured from the input vocabulary
8179
const unsigned Dimension;

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@ static cl::opt<std::string>
4141
VocabFile("ir2vec-vocab-path", cl::Optional,
4242
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
4343
cl::cat(IR2VecCategory));
44+
static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
45+
cl::init(1.0),
46+
cl::desc("Weight for opcode embeddings"),
47+
cl::cat(IR2VecCategory));
48+
static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
49+
cl::init(0.5),
50+
cl::desc("Weight for type embeddings"),
51+
cl::cat(IR2VecCategory));
52+
static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
53+
cl::init(0.2),
54+
cl::desc("Weight for argument embeddings"),
55+
cl::cat(IR2VecCategory));
4456

4557
AnalysisKey IR2VecVocabAnalysis::Key;
4658

@@ -54,7 +66,8 @@ AnalysisKey IR2VecVocabAnalysis::Key;
5466

5567
Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
5668
unsigned Dimension)
57-
: F(F), Vocabulary(Vocabulary), Dimension(Dimension) {}
69+
: F(F), Vocabulary(Vocabulary), Dimension(Dimension), OpcWeight(OpcWeight),
70+
TypeWeight(TypeWeight), ArgWeight(ArgWeight) {}
5871

5972
ErrorOr<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
6073
const Function &F,

0 commit comments

Comments
 (0)