@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
153153// Embedder and its subclasses
154154// ===----------------------------------------------------------------------===//
155155
156- Embedder::Embedder (const Function &F, const Vocabulary &Vocab)
157- : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
158- OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
159- FuncVector(Embedding(Dimension)) {}
160-
161156std::unique_ptr<Embedder> Embedder::create (IR2VecKind Mode, const Function &F,
162157 const Vocabulary &Vocab) {
163158 switch (Mode) {
@@ -169,110 +164,83 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
169164 return nullptr ;
170165}
171166
172- const InstEmbeddingsMap &Embedder::getInstVecMap () const {
173- if (InstVecMap.empty ())
174- computeEmbeddings ();
175- return InstVecMap;
176- }
177-
178- const BBEmbeddingsMap &Embedder::getBBVecMap () const {
179- if (BBVecMap.empty ())
180- computeEmbeddings ();
181- return BBVecMap;
182- }
183-
184- const Embedding &Embedder::getBBVector (const BasicBlock &BB) const {
185- auto It = BBVecMap.find (&BB);
186- if (It != BBVecMap.end ())
187- return It->second ;
188- computeEmbeddings (BB);
189- return BBVecMap[&BB];
190- }
167+ Embedding Embedder::computeEmbeddings () const {
168+ Embedding FuncVector (Dimension, 0 );
191169
192- const Embedding &Embedder::getFunctionVector () const {
193- // Currently, we always (re)compute the embeddings for the function.
194- // This is cheaper than caching the vector.
195- computeEmbeddings ();
196- return FuncVector;
197- }
198-
199- void Embedder::computeEmbeddings () const {
200170 if (F.isDeclaration ())
201- return ;
202-
203- FuncVector = Embedding (Dimension, 0.0 );
171+ return FuncVector;
204172
205173 // Consider only the basic blocks that are reachable from entry
206- for (const BasicBlock *BB : depth_first (&F)) {
207- computeEmbeddings (*BB);
208- FuncVector += BBVecMap[BB];
209- }
210- }
211-
212- void SymbolicEmbedder::computeEmbeddings (const BasicBlock &BB) const {
213- Embedding BBVector (Dimension, 0 );
214-
215- // We consider only the non-debug and non-pseudo instructions
216- for (const auto &I : BB.instructionsWithoutDebug ()) {
217- Embedding ArgEmb (Dimension, 0 );
218- for (const auto &Op : I.operands ())
219- ArgEmb += Vocab[*Op];
220- auto InstVector =
221- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
222- if (const auto *IC = dyn_cast<CmpInst>(&I))
223- InstVector += Vocab[IC->getPredicate ()];
224- InstVecMap[&I] = InstVector;
225- BBVector += InstVector;
226- }
227- BBVecMap[&BB] = BBVector;
174+ for (const BasicBlock *BB : depth_first (&F))
175+ FuncVector += computeEmbeddings (*BB);
176+ return FuncVector;
228177}
229178
230- void FlowAwareEmbedder ::computeEmbeddings (const BasicBlock &BB) const {
179+ Embedding Embedder ::computeEmbeddings (const BasicBlock &BB) const {
231180 Embedding BBVector (Dimension, 0 );
181+ for (const Instruction &I : BB.instructionsWithoutDebug ())
182+ BBVector += computeEmbeddings (I);
183+ return BBVector;
184+ }
185+
186+ Embedding SymbolicEmbedder::computeEmbeddings (const Instruction &I) const {
187+ // Currently, we always (re)compute the embeddings for symbolic embedder.
188+ // This is cheaper than caching the vectors.
189+ Embedding ArgEmb (Dimension, 0 );
190+ for (const auto &Op : I.operands ())
191+ ArgEmb += Vocab[*Op];
192+ auto InstVector =
193+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
194+ if (const auto *IC = dyn_cast<CmpInst>(&I))
195+ InstVector += Vocab[IC->getPredicate ()];
196+ return InstVector;
197+ }
198+
199+ Embedding FlowAwareEmbedder::computeEmbeddings (const Instruction &I) const {
200+ // If we have already computed the embedding for this instruction, return it
201+ auto It = InstVecMap.find (&I);
202+ if (It != InstVecMap.end ())
203+ return It->second ;
232204
233- // We consider only the non-debug and non-pseudo instructions
234- for (const auto &I : BB.instructionsWithoutDebug ()) {
235- // TODO: Handle call instructions differently.
236- // For now, we treat them like other instructions
237- Embedding ArgEmb (Dimension, 0 );
238- for (const auto &Op : I.operands ()) {
239- // If the operand is defined elsewhere, we use its embedding
240- if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241- auto DefIt = InstVecMap.find (DefInst);
242- // Fixme (#159171): Ideally we should never miss an instruction
243- // embedding here.
244- // But when we have cyclic dependencies (e.g., phi
245- // nodes), we might miss the embedding. In such cases, we fall back to
246- // using the vocabulary embedding. This can be fixed by iterating to a
247- // fixed-point, or by using a simple solver for the set of simultaneous
248- // equations.
249- // Another case when we might miss an instruction embedding is when
250- // the operand instruction is in a different basic block that has not
251- // been processed yet. This can be fixed by processing the basic blocks
252- // in a topological order.
253- if (DefIt != InstVecMap.end ())
254- ArgEmb += DefIt->second ;
255- else
256- ArgEmb += Vocab[*Op];
257- }
258- // If the operand is not defined by an instruction, we use the vocabulary
259- else {
260- LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
261- << *Op << " =" << Vocab[*Op][0 ] << " \n " );
205+ // TODO: Handle call instructions differently.
206+ // For now, we treat them like other instructions
207+ Embedding ArgEmb (Dimension, 0 );
208+ for (const auto &Op : I.operands ()) {
209+ // If the operand is defined elsewhere, we use its embedding
210+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
211+ auto DefIt = InstVecMap.find (DefInst);
212+ // Fixme (#159171): Ideally we should never miss an instruction
213+ // embedding here.
214+ // But when we have cyclic dependencies (e.g., phi
215+ // nodes), we might miss the embedding. In such cases, we fall back to
216+ // using the vocabulary embedding. This can be fixed by iterating to a
217+ // fixed-point, or by using a simple solver for the set of simultaneous
218+ // equations.
219+ // Another case when we might miss an instruction embedding is when
220+ // the operand instruction is in a different basic block that has not
221+ // been processed yet. This can be fixed by processing the basic blocks
222+ // in a topological order.
223+ if (DefIt != InstVecMap.end ())
224+ ArgEmb += DefIt->second ;
225+ else
262226 ArgEmb += Vocab[*Op];
263- }
264227 }
265- // Create the instruction vector by combining opcode, type, and arguments
266- // embeddings
267- auto InstVector =
268- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
269- // Add compare predicate embedding as an additional operand if applicable
270- if (const auto *IC = dyn_cast<CmpInst>(&I))
271- InstVector += Vocab[IC->getPredicate ()];
272- InstVecMap[&I] = InstVector;
273- BBVector += InstVector;
228+ // If the operand is not defined by an instruction, we use the
229+ // vocabulary
230+ else {
231+ LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
232+ << *Op << " =" << Vocab[*Op][0 ] << " \n " );
233+ ArgEmb += Vocab[*Op];
234+ }
274235 }
275- BBVecMap[&BB] = BBVector;
236+ // Create the instruction vector by combining opcode, type, and arguments
237+ // embeddings
238+ auto InstVector =
239+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
240+ if (const auto *IC = dyn_cast<CmpInst>(&I))
241+ InstVector += Vocab[IC->getPredicate ()];
242+ InstVecMap[&I] = InstVector;
243+ return InstVector;
276244}
277245
278246// ==----------------------------------------------------------------------===//
@@ -695,25 +663,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
695663 Emb->getFunctionVector ().print (OS);
696664
697665 OS << " Basic block vectors:\n " ;
698- const auto &BBMap = Emb->getBBVecMap ();
699666 for (const BasicBlock &BB : F) {
700- auto It = BBMap.find (&BB);
701- if (It != BBMap.end ()) {
702- OS << " Basic block: " << BB.getName () << " :\n " ;
703- It->second .print (OS);
704- }
667+ OS << " Basic block: " << BB.getName () << " :\n " ;
668+ Emb->getBBVector (BB).print (OS);
705669 }
706670
707671 OS << " Instruction vectors:\n " ;
708- const auto &InstMap = Emb->getInstVecMap ();
709672 for (const BasicBlock &BB : F) {
710673 for (const Instruction &I : BB) {
711- auto It = InstMap.find (&I);
712- if (It != InstMap.end ()) {
713- OS << " Instruction: " ;
714- I.print (OS);
715- It->second .print (OS);
716- }
674+ OS << " Instruction: " ;
675+ I.print (OS);
676+ Emb->getInstVector (I).print (OS);
717677 }
718678 }
719679 }
0 commit comments