@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
153153// Embedder and its subclasses
154154// ===----------------------------------------------------------------------===//
155155
156- Embedder::Embedder (const Function &F, const Vocabulary &Vocab)
157- : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
158- OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
159- FuncVector(Embedding(Dimension)) {}
160-
161156std::unique_ptr<Embedder> Embedder::create (IR2VecKind Mode, const Function &F,
162157 const Vocabulary &Vocab) {
163158 switch (Mode) {
@@ -169,110 +164,85 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
169164 return nullptr ;
170165}
171166
172- const InstEmbeddingsMap &Embedder::getInstVecMap () const {
173- if (InstVecMap.empty ())
174- computeEmbeddings ();
175- return InstVecMap;
176- }
177-
178- const BBEmbeddingsMap &Embedder::getBBVecMap () const {
179- if (BBVecMap.empty ())
180- computeEmbeddings ();
181- return BBVecMap;
182- }
183-
184- const Embedding &Embedder::getBBVector (const BasicBlock &BB) const {
185- auto It = BBVecMap.find (&BB);
186- if (It != BBVecMap.end ())
187- return It->second ;
188- computeEmbeddings (BB);
189- return BBVecMap[&BB];
190- }
167+ Embedding Embedder::computeEmbeddings () const {
168+ Embedding FuncVector (Dimension, 0.0 );
191169
192- const Embedding &Embedder::getFunctionVector () const {
193- // Currently, we always (re)compute the embeddings for the function.
194- // This is cheaper than caching the vector.
195- computeEmbeddings ();
196- return FuncVector;
197- }
198-
199- void Embedder::computeEmbeddings () const {
200170 if (F.isDeclaration ())
201- return ;
202-
203- FuncVector = Embedding (Dimension, 0.0 );
171+ return FuncVector;
204172
205173 // Consider only the basic blocks that are reachable from entry
206- for (const BasicBlock *BB : depth_first (&F)) {
207- computeEmbeddings (*BB);
208- FuncVector += BBVecMap[BB];
209- }
174+ for (const BasicBlock *BB : depth_first (&F))
175+ FuncVector += computeEmbeddings (*BB);
176+ return FuncVector;
210177}
211178
212- void SymbolicEmbedder ::computeEmbeddings (const BasicBlock &BB) const {
179+ Embedding Embedder ::computeEmbeddings (const BasicBlock &BB) const {
213180 Embedding BBVector (Dimension, 0 );
214181
215182 // We consider only the non-debug and non-pseudo instructions
216- for (const auto &I : BB.instructionsWithoutDebug ()) {
217- Embedding ArgEmb (Dimension, 0 );
218- for (const auto &Op : I.operands ())
219- ArgEmb += Vocab[*Op];
220- auto InstVector =
221- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
222- if (const auto *IC = dyn_cast<CmpInst>(&I))
223- InstVector += Vocab[IC->getPredicate ()];
224- InstVecMap[&I] = InstVector;
225- BBVector += InstVector;
226- }
227- BBVecMap[&BB] = BBVector;
228- }
229-
230- void FlowAwareEmbedder::computeEmbeddings (const BasicBlock &BB) const {
231- Embedding BBVector (Dimension, 0 );
183+ for (const auto &I : BB.instructionsWithoutDebug ())
184+ BBVector += computeEmbeddings (I);
185+ return BBVector;
186+ }
187+
188+ Embedding SymbolicEmbedder::computeEmbeddings (const Instruction &I) const {
189+ // Currently, we always (re)compute the embeddings for symbolic embedder.
190+ // This is cheaper than caching the vectors.
191+ Embedding ArgEmb (Dimension, 0 );
192+ for (const auto &Op : I.operands ())
193+ ArgEmb += Vocab[*Op];
194+ auto InstVector =
195+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
196+ if (const auto *IC = dyn_cast<CmpInst>(&I))
197+ InstVector += Vocab[IC->getPredicate ()];
198+ return InstVector;
199+ }
200+
201+ Embedding FlowAwareEmbedder::computeEmbeddings (const Instruction &I) const {
202+ // If we have already computed the embedding for this instruction, return it
203+ auto It = InstVecMap.find (&I);
204+ if (It != InstVecMap.end ())
205+ return It->second ;
232206
233- // We consider only the non-debug and non-pseudo instructions
234- for (const auto &I : BB.instructionsWithoutDebug ()) {
235- // TODO: Handle call instructions differently.
236- // For now, we treat them like other instructions
237- Embedding ArgEmb (Dimension, 0 );
238- for (const auto &Op : I.operands ()) {
239- // If the operand is defined elsewhere, we use its embedding
240- if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241- auto DefIt = InstVecMap.find (DefInst);
242- // Fixme (#159171): Ideally we should never miss an instruction
243- // embedding here.
244- // But when we have cyclic dependencies (e.g., phi
245- // nodes), we might miss the embedding. In such cases, we fall back to
246- // using the vocabulary embedding. This can be fixed by iterating to a
247- // fixed-point, or by using a simple solver for the set of simultaneous
248- // equations.
249- // Another case when we might miss an instruction embedding is when
250- // the operand instruction is in a different basic block that has not
251- // been processed yet. This can be fixed by processing the basic blocks
252- // in a topological order.
253- if (DefIt != InstVecMap.end ())
254- ArgEmb += DefIt->second ;
255- else
256- ArgEmb += Vocab[*Op];
257- }
258- // If the operand is not defined by an instruction, we use the vocabulary
259- else {
260- LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
261- << *Op << " =" << Vocab[*Op][0 ] << " \n " );
207+ // TODO: Handle call instructions differently.
208+ // For now, we treat them like other instructions
209+ Embedding ArgEmb (Dimension, 0 );
210+ for (const auto &Op : I.operands ()) {
211+ // If the operand is defined elsewhere, we use its embedding
212+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
213+ auto DefIt = InstVecMap.find (DefInst);
214+ // Fixme (#159171): Ideally we should never miss an instruction
215+ // embedding here.
216+ // But when we have cyclic dependencies (e.g., phi
217+ // nodes), we might miss the embedding. In such cases, we fall back to
218+ // using the vocabulary embedding. This can be fixed by iterating to a
219+ // fixed-point, or by using a simple solver for the set of simultaneous
220+ // equations.
221+ // Another case when we might miss an instruction embedding is when
222+ // the operand instruction is in a different basic block that has not
223+ // been processed yet. This can be fixed by processing the basic blocks
224+ // in a topological order.
225+ if (DefIt != InstVecMap.end ())
226+ ArgEmb += DefIt->second ;
227+ else
262228 ArgEmb += Vocab[*Op];
263- }
264229 }
265- // Create the instruction vector by combining opcode, type, and arguments
266- // embeddings
267- auto InstVector =
268- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
269- // Add compare predicate embedding as an additional operand if applicable
270- if (const auto *IC = dyn_cast<CmpInst>(&I))
271- InstVector += Vocab[IC->getPredicate ()];
272- InstVecMap[&I] = InstVector;
273- BBVector += InstVector;
230+ // If the operand is not defined by an instruction, we use the
231+ // vocabulary
232+ else {
233+ LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
234+ << *Op << " =" << Vocab[*Op][0 ] << " \n " );
235+ ArgEmb += Vocab[*Op];
236+ }
274237 }
275- BBVecMap[&BB] = BBVector;
238+ // Create the instruction vector by combining opcode, type, and arguments
239+ // embeddings
240+ auto InstVector =
241+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
242+ if (const auto *IC = dyn_cast<CmpInst>(&I))
243+ InstVector += Vocab[IC->getPredicate ()];
244+ InstVecMap[&I] = InstVector;
245+ return InstVector;
276246}
277247
278248// ==----------------------------------------------------------------------===//
@@ -695,25 +665,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
695665 Emb->getFunctionVector ().print (OS);
696666
697667 OS << " Basic block vectors:\n " ;
698- const auto &BBMap = Emb->getBBVecMap ();
699668 for (const BasicBlock &BB : F) {
700- auto It = BBMap.find (&BB);
701- if (It != BBMap.end ()) {
702- OS << " Basic block: " << BB.getName () << " :\n " ;
703- It->second .print (OS);
704- }
669+ OS << " Basic block: " << BB.getName () << " :\n " ;
670+ Emb->getBBVector (BB).print (OS);
705671 }
706672
707673 OS << " Instruction vectors:\n " ;
708- const auto &InstMap = Emb->getInstVecMap ();
709674 for (const BasicBlock &BB : F) {
710675 for (const Instruction &I : BB) {
711- auto It = InstMap.find (&I);
712- if (It != InstMap.end ()) {
713- OS << " Instruction: " ;
714- I.print (OS);
715- It->second .print (OS);
716- }
676+ OS << " Instruction: " ;
677+ I.print (OS);
678+ Emb->getInstVector (I).print (OS);
717679 }
718680 }
719681 }
0 commit comments