@@ -155,8 +155,8 @@ void Embedding::print(raw_ostream &OS) const {
155155
156156Embedder::Embedder (const Function &F, const Vocabulary &Vocab)
157157 : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
158- OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
159- FuncVector(Embedding(Dimension)) { }
158+ OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
159+ }
160160
161161std::unique_ptr<Embedder> Embedder::create (IR2VecKind Mode, const Function &F,
162162 const Vocabulary &Vocab) {
@@ -169,112 +169,104 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
169169 return nullptr ;
170170}
171171
172- const InstEmbeddingsMap &Embedder::getInstVecMap () const {
173- if (InstVecMap.empty ())
174- computeEmbeddings ();
175- return InstVecMap;
172+ Embedding Embedder::getInstVector (const Instruction &I) const {
173+ return computeEmbeddings (I);
176174}
177175
178- const BBEmbeddingsMap &Embedder::getBBVecMap () const {
179- if (BBVecMap.empty ())
180- computeEmbeddings ();
181- return BBVecMap;
176+ Embedding Embedder::getBBVector (const BasicBlock &BB) const {
177+ return computeEmbeddings (BB);
182178}
183179
184- const Embedding &Embedder::getBBVector (const BasicBlock &BB) const {
185- auto It = BBVecMap.find (&BB);
186- if (It != BBVecMap.end ())
187- return It->second ;
188- computeEmbeddings (BB);
189- return BBVecMap[&BB];
190- }
191-
192- const Embedding &Embedder::getFunctionVector () const {
180+ Embedding Embedder::getFunctionVector () const {
193181 // Currently, we always (re)compute the embeddings for the function.
194182 // This is cheaper than caching the vector.
195- computeEmbeddings ();
196- return FuncVector;
183+ return computeEmbeddings ();
197184}
198185
199- void Embedder::computeEmbeddings () const {
186+ Embedding Embedder::computeEmbeddings () const {
200187 if (F.isDeclaration ())
201- return ;
188+ return Embedding (Dimension, 0.0 ) ;
202189
203- FuncVector = Embedding (Dimension, 0.0 );
190+ Embedding FuncVector (Dimension, 0.0 );
204191
205192 // Consider only the basic blocks that are reachable from entry
206193 for (const BasicBlock *BB : depth_first (&F)) {
207- computeEmbeddings (*BB);
208- FuncVector += BBVecMap[BB];
194+ FuncVector += computeEmbeddings (*BB);
209195 }
196+ return FuncVector;
210197}
211198
212- void SymbolicEmbedder ::computeEmbeddings (const BasicBlock &BB) const {
199+ Embedding Embedder ::computeEmbeddings (const BasicBlock &BB) const {
213200 Embedding BBVector (Dimension, 0 );
214201
215202 // We consider only the non-debug and non-pseudo instructions
216203 for (const auto &I : BB.instructionsWithoutDebug ()) {
217- Embedding ArgEmb (Dimension, 0 );
218- for (const auto &Op : I.operands ())
219- ArgEmb += Vocab[*Op];
220- auto InstVector =
221- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
222- if (const auto *IC = dyn_cast<CmpInst>(&I))
223- InstVector += Vocab[IC->getPredicate ()];
224- InstVecMap[&I] = InstVector;
225- BBVector += InstVector;
204+ BBVector += computeEmbeddings (I);
226205 }
227- BBVecMap[&BB] = BBVector;
206+ return BBVector;
228207}
229208
230- void FlowAwareEmbedder::computeEmbeddings (const BasicBlock &BB) const {
231- Embedding BBVector (Dimension, 0 );
209+ Embedding SymbolicEmbedder::computeEmbeddings (const Instruction &I) const {
210+ Embedding ArgEmb (Dimension, 0 );
211+ for (const auto &Op : I.operands ())
212+ ArgEmb += Vocab[*Op];
213+ auto InstVector =
214+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
215+ if (const auto *IC = dyn_cast<CmpInst>(&I))
216+ InstVector += Vocab[IC->getPredicate ()];
217+ return InstVector;
218+ }
232219
233- // We consider only the non-debug and non-pseudo instructions
234- for (const auto &I : BB.instructionsWithoutDebug ()) {
235- // TODO: Handle call instructions differently.
236- // For now, we treat them like other instructions
237- Embedding ArgEmb (Dimension, 0 );
238- for (const auto &Op : I.operands ()) {
239- // If the operand is defined elsewhere, we use its embedding
240- if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
241- auto DefIt = InstVecMap.find (DefInst);
242- // Fixme (#159171): Ideally we should never miss an instruction
243- // embedding here.
244- // But when we have cyclic dependencies (e.g., phi
245- // nodes), we might miss the embedding. In such cases, we fall back to
246- // using the vocabulary embedding. This can be fixed by iterating to a
247- // fixed-point, or by using a simple solver for the set of simultaneous
248- // equations.
249- // Another case when we might miss an instruction embedding is when
250- // the operand instruction is in a different basic block that has not
251- // been processed yet. This can be fixed by processing the basic blocks
252- // in a topological order.
253- if (DefIt != InstVecMap.end ())
254- ArgEmb += DefIt->second ;
255- else
256- ArgEmb += Vocab[*Op];
257- }
258- // If the operand is not defined by an instruction, we use the vocabulary
259- else {
260- LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
261- << *Op << " =" << Vocab[*Op][0 ] << " \n " );
220+ Embedding FlowAwareEmbedder::computeEmbeddings (const Instruction &I) const {
221+ // If we have already computed the embedding for this instruction, return it
222+ auto It = InstVecMap.find (&I);
223+ if (It != InstVecMap.end ())
224+ return It->second ;
225+
226+ // TODO: Handle call instructions differently.
227+ // For now, we treat them like other instructions
228+ Embedding ArgEmb (Dimension, 0 );
229+ for (const auto &Op : I.operands ()) {
230+ // If the operand is defined elsewhere, we use its embedding
231+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
232+ auto DefIt = InstVecMap.find (DefInst);
233+ // Fixme (#159171): Ideally we should never miss an instruction
234+ // embedding here.
235+ // But when we have cyclic dependencies (e.g., phi
236+ // nodes), we might miss the embedding. In such cases, we fall back to
237+ // using the vocabulary embedding. This can be fixed by iterating to a
238+ // fixed-point, or by using a simple solver for the set of simultaneous
239+ // equations.
240+ // Another case when we might miss an instruction embedding is when
241+ // the operand instruction is in a different basic block that has not
242+ // been processed yet. This can be fixed by processing the basic blocks
243+ // in a topological order.
244+ if (DefIt != InstVecMap.end ())
245+ ArgEmb += DefIt->second ;
246+ else
262247 ArgEmb += Vocab[*Op];
263- }
264248 }
265- // Create the instruction vector by combining opcode, type, and arguments
266- // embeddings
267- auto InstVector =
268- Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
269- // Add compare predicate embedding as an additional operand if applicable
270- if (const auto *IC = dyn_cast<CmpInst>(&I))
271- InstVector += Vocab[IC->getPredicate ()];
272- InstVecMap[&I] = InstVector;
273- BBVector += InstVector;
249+ // If the operand is not defined by an instruction, we use the
250+ // vocabulary
251+ else {
252+ LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
253+ << *Op << " =" << Vocab[*Op][0 ] << " \n " );
254+ ArgEmb += Vocab[*Op];
255+ }
274256 }
275- BBVecMap[&BB] = BBVector;
257+ // Create the instruction vector by combining opcode, type, and arguments
258+ // embeddings
259+ auto InstVector =
260+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
261+ // Add compare predicate embedding as an additional operand if applicable
262+ if (const auto *IC = dyn_cast<CmpInst>(&I))
263+ InstVector += Vocab[IC->getPredicate ()];
264+ InstVecMap[&I] = InstVector;
265+ return InstVector;
276266}
277267
268+ void FlowAwareEmbedder::invalidateEmbeddings () { InstVecMap.clear (); }
269+
278270// ==----------------------------------------------------------------------===//
279271// VocabStorage
280272// ===----------------------------------------------------------------------===//
@@ -695,25 +687,19 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
695687 Emb->getFunctionVector ().print (OS);
696688
697689 OS << " Basic block vectors:\n " ;
698- const auto &BBMap = Emb->getBBVecMap ();
699690 for (const BasicBlock &BB : F) {
700- auto It = BBMap.find (&BB);
701- if (It != BBMap.end ()) {
702- OS << " Basic block: " << BB.getName () << " :\n " ;
703- It->second .print (OS);
704- }
691+ auto BBVector = Emb->getBBVector (BB);
692+ OS << " Basic block: " << BB.getName () << " :\n " ;
693+ BBVector.print (OS);
705694 }
706695
707696 OS << " Instruction vectors:\n " ;
708- const auto &InstMap = Emb->getInstVecMap ();
709697 for (const BasicBlock &BB : F) {
710698 for (const Instruction &I : BB) {
711- auto It = InstMap.find (&I);
712- if (It != InstMap.end ()) {
713- OS << " Instruction: " ;
714- I.print (OS);
715- It->second .print (OS);
716- }
699+ auto InstVector = Emb->getInstVector (I);
700+ OS << " Instruction: " ;
701+ I.print (OS);
702+ InstVector.print (OS);
717703 }
718704 }
719705 }
0 commit comments