@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
153
153
// Embedder and its subclasses
154
154
// ===----------------------------------------------------------------------===//
155
155
156
- Embedder::Embedder (const Function &F, const Vocabulary &Vocab)
157
- : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
158
- OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
159
- FuncVector(Embedding(Dimension)) {}
160
-
161
156
std::unique_ptr<Embedder> Embedder::create (IR2VecKind Mode, const Function &F,
162
157
const Vocabulary &Vocab) {
163
158
switch (Mode) {
@@ -169,110 +164,85 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
169
164
return nullptr ;
170
165
}
171
166
172
- const InstEmbeddingsMap &Embedder::getInstVecMap () const {
173
- if (InstVecMap.empty ())
174
- computeEmbeddings ();
175
- return InstVecMap;
176
- }
177
-
178
- const BBEmbeddingsMap &Embedder::getBBVecMap () const {
179
- if (BBVecMap.empty ())
180
- computeEmbeddings ();
181
- return BBVecMap;
182
- }
183
-
184
- const Embedding &Embedder::getBBVector (const BasicBlock &BB) const {
185
- auto It = BBVecMap.find (&BB);
186
- if (It != BBVecMap.end ())
187
- return It->second ;
188
- computeEmbeddings (BB);
189
- return BBVecMap[&BB];
190
- }
167
+ Embedding Embedder::computeEmbeddings () const {
168
+ Embedding FuncVector (Dimension, 0.0 );
191
169
192
- const Embedding &Embedder::getFunctionVector () const {
193
- // Currently, we always (re)compute the embeddings for the function.
194
- // This is cheaper than caching the vector.
195
- computeEmbeddings ();
196
- return FuncVector;
197
- }
198
-
199
- void Embedder::computeEmbeddings () const {
200
170
if (F.isDeclaration ())
201
- return ;
202
-
203
- FuncVector = Embedding (Dimension, 0.0 );
171
+ return FuncVector;
204
172
205
173
// Consider only the basic blocks that are reachable from entry
206
- for (const BasicBlock *BB : depth_first (&F)) {
207
- computeEmbeddings (*BB);
208
- FuncVector += BBVecMap[BB];
209
- }
174
+ for (const BasicBlock *BB : depth_first (&F))
175
+ FuncVector += computeEmbeddings (*BB);
176
+ return FuncVector;
210
177
}
211
178
212
- void SymbolicEmbedder ::computeEmbeddings (const BasicBlock &BB) const {
179
+ Embedding Embedder ::computeEmbeddings (const BasicBlock &BB) const {
213
180
Embedding BBVector (Dimension, 0 );
214
181
215
182
// We consider only the non-debug and non-pseudo instructions
216
- for (const auto &I : BB.instructionsWithoutDebug ()) {
217
- Embedding ArgEmb (Dimension, 0 );
218
- for (const auto &Op : I.operands ())
219
- ArgEmb += Vocab[*Op];
220
- auto InstVector =
221
- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
222
- if (const auto *IC = dyn_cast<CmpInst>(&I))
223
- InstVector += Vocab[IC->getPredicate ()];
224
- InstVecMap[&I] = InstVector;
225
- BBVector += InstVector;
226
- }
227
- BBVecMap[&BB] = BBVector;
228
- }
229
-
230
- void FlowAwareEmbedder::computeEmbeddings (const BasicBlock &BB) const {
231
- Embedding BBVector (Dimension, 0 );
183
+ for (const auto &I : BB.instructionsWithoutDebug ())
184
+ BBVector += computeEmbeddings (I);
185
+ return BBVector;
186
+ }
187
+
188
+ Embedding SymbolicEmbedder::computeEmbeddings (const Instruction &I) const {
189
+ // Currently, we always (re)compute the embeddings for symbolic embedder.
190
+ // This is cheaper than caching the vectors.
191
+ Embedding ArgEmb (Dimension, 0 );
192
+ for (const auto &Op : I.operands ())
193
+ ArgEmb += Vocab[*Op];
194
+ auto InstVector =
195
+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
196
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
197
+ InstVector += Vocab[IC->getPredicate ()];
198
+ return InstVector;
199
+ }
200
+
201
+ Embedding FlowAwareEmbedder::computeEmbeddings (const Instruction &I) const {
202
+ // If we have already computed the embedding for this instruction, return it
203
+ auto It = InstVecMap.find (&I);
204
+ if (It != InstVecMap.end ())
205
+ return It->second ;
232
206
233
- // We consider only the non-debug and non-pseudo instructions
234
- for (const auto &I : BB.instructionsWithoutDebug ()) {
235
- // TODO: Handle call instructions differently.
236
- // For now, we treat them like other instructions
237
- Embedding ArgEmb (Dimension, 0 );
238
- for (const auto &Op : I.operands ()) {
239
- // If the operand is defined elsewhere, we use its embedding
240
- if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241
- auto DefIt = InstVecMap.find (DefInst);
242
- // Fixme (#159171): Ideally we should never miss an instruction
243
- // embedding here.
244
- // But when we have cyclic dependencies (e.g., phi
245
- // nodes), we might miss the embedding. In such cases, we fall back to
246
- // using the vocabulary embedding. This can be fixed by iterating to a
247
- // fixed-point, or by using a simple solver for the set of simultaneous
248
- // equations.
249
- // Another case when we might miss an instruction embedding is when
250
- // the operand instruction is in a different basic block that has not
251
- // been processed yet. This can be fixed by processing the basic blocks
252
- // in a topological order.
253
- if (DefIt != InstVecMap.end ())
254
- ArgEmb += DefIt->second ;
255
- else
256
- ArgEmb += Vocab[*Op];
257
- }
258
- // If the operand is not defined by an instruction, we use the vocabulary
259
- else {
260
- LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
261
- << *Op << " =" << Vocab[*Op][0 ] << " \n " );
207
+ // TODO: Handle call instructions differently.
208
+ // For now, we treat them like other instructions
209
+ Embedding ArgEmb (Dimension, 0 );
210
+ for (const auto &Op : I.operands ()) {
211
+ // If the operand is defined elsewhere, we use its embedding
212
+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
213
+ auto DefIt = InstVecMap.find (DefInst);
214
+ // Fixme (#159171): Ideally we should never miss an instruction
215
+ // embedding here.
216
+ // But when we have cyclic dependencies (e.g., phi
217
+ // nodes), we might miss the embedding. In such cases, we fall back to
218
+ // using the vocabulary embedding. This can be fixed by iterating to a
219
+ // fixed-point, or by using a simple solver for the set of simultaneous
220
+ // equations.
221
+ // Another case when we might miss an instruction embedding is when
222
+ // the operand instruction is in a different basic block that has not
223
+ // been processed yet. This can be fixed by processing the basic blocks
224
+ // in a topological order.
225
+ if (DefIt != InstVecMap.end ())
226
+ ArgEmb += DefIt->second ;
227
+ else
262
228
ArgEmb += Vocab[*Op];
263
- }
264
229
}
265
- // Create the instruction vector by combining opcode, type, and arguments
266
- // embeddings
267
- auto InstVector =
268
- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
269
- // Add compare predicate embedding as an additional operand if applicable
270
- if (const auto *IC = dyn_cast<CmpInst>(&I))
271
- InstVector += Vocab[IC->getPredicate ()];
272
- InstVecMap[&I] = InstVector;
273
- BBVector += InstVector;
230
+ // If the operand is not defined by an instruction, we use the
231
+ // vocabulary
232
+ else {
233
+ LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
234
+ << *Op << " =" << Vocab[*Op][0 ] << " \n " );
235
+ ArgEmb += Vocab[*Op];
236
+ }
274
237
}
275
- BBVecMap[&BB] = BBVector;
238
+ // Create the instruction vector by combining opcode, type, and arguments
239
+ // embeddings
240
+ auto InstVector =
241
+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
242
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
243
+ InstVector += Vocab[IC->getPredicate ()];
244
+ InstVecMap[&I] = InstVector;
245
+ return InstVector;
276
246
}
277
247
278
248
// ==----------------------------------------------------------------------===//
@@ -695,25 +665,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
695
665
Emb->getFunctionVector ().print (OS);
696
666
697
667
OS << " Basic block vectors:\n " ;
698
- const auto &BBMap = Emb->getBBVecMap ();
699
668
for (const BasicBlock &BB : F) {
700
- auto It = BBMap.find (&BB);
701
- if (It != BBMap.end ()) {
702
- OS << " Basic block: " << BB.getName () << " :\n " ;
703
- It->second .print (OS);
704
- }
669
+ OS << " Basic block: " << BB.getName () << " :\n " ;
670
+ Emb->getBBVector (BB).print (OS);
705
671
}
706
672
707
673
OS << " Instruction vectors:\n " ;
708
- const auto &InstMap = Emb->getInstVecMap ();
709
674
for (const BasicBlock &BB : F) {
710
675
for (const Instruction &I : BB) {
711
- auto It = InstMap.find (&I);
712
- if (It != InstMap.end ()) {
713
- OS << " Instruction: " ;
714
- I.print (OS);
715
- It->second .print (OS);
716
- }
676
+ OS << " Instruction: " ;
677
+ I.print (OS);
678
+ Emb->getInstVector (I).print (OS);
717
679
}
718
680
}
719
681
}
0 commit comments