@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
8585 return *this ;
8686}
8787
88+ Embedding &Embedding::operator *=(double Factor) {
89+ std::transform (this ->begin (), this ->end (), this ->begin (),
90+ [Factor](double Elem) { return Elem * Factor; });
91+ return *this ;
92+ }
93+
8894Embedding &Embedding::scaleAndAdd (const Embedding &Src, float Factor) {
8995 assert (this ->size () == Src.size () && " Vectors must have the same dimension" );
9096 for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101107 return true ;
102108}
103109
110+ void Embedding::print (raw_ostream &OS) const {
111+ OS << " [" ;
112+ for (const auto &Elem : Data)
113+ OS << " " << format (" %.2f" , Elem) << " " ;
114+ OS << " ]\n " ;
115+ }
116+
104117// ==----------------------------------------------------------------------===//
105118// Embedder and its subclasses
106119// ===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196209 for (const auto &I : BB.instructionsWithoutDebug ()) {
197210 Embedding InstVector (Dimension, 0 );
198211
199- const auto OpcVec = lookupVocab (I.getOpcodeName ());
200- InstVector.scaleAndAdd (OpcVec, OpcWeight);
201-
202212 // FIXME: Currently lookups are string based. Use numeric Keys
203213 // for efficiency.
204- const auto Type = I.getType ();
205- const auto TypeVec = getTypeEmbedding (Type);
206- InstVector.scaleAndAdd (TypeVec, TypeWeight);
207-
214+ InstVector += lookupVocab (I.getOpcodeName ());
215+ InstVector += getTypeEmbedding (I.getType ());
208216 for (const auto &Op : I.operands ()) {
209- const auto OperandVec = getOperandEmbedding (Op.get ());
210- InstVector.scaleAndAdd (OperandVec, ArgWeight);
217+ InstVector += getOperandEmbedding (Op.get ());
211218 }
212219 InstVecMap[&I] = InstVector;
213220 BBVector += InstVector;
@@ -251,6 +258,46 @@ bool IR2VecVocabResult::invalidate(
251258 return !(PAC.preservedWhenStateless ());
252259}
253260
261+ Error IR2VecVocabAnalysis::parseVocabSection (const char *Key,
262+ const json::Value ParsedVocabValue,
263+ ir2vec::Vocab &TargetVocab,
264+ unsigned &Dim) {
265+ assert (Key && " Key cannot be null" );
266+
267+ json::Path::Root Path (" " );
268+ const json::Object *RootObj = ParsedVocabValue.getAsObject ();
269+ if (!RootObj)
270+ return createStringError (errc::invalid_argument,
271+ " JSON root is not an object" );
272+
273+ const json::Value *SectionValue = RootObj->get (Key);
274+ if (!SectionValue)
275+ return createStringError (errc::invalid_argument,
276+ " Missing '" + std::string (Key) +
277+ " ' section in vocabulary file" );
278+ if (!json::fromJSON (*SectionValue, TargetVocab, Path))
279+ return createStringError (errc::illegal_byte_sequence,
280+ " Unable to parse '" + std::string (Key) +
281+ " ' section from vocabulary" );
282+
283+ Dim = TargetVocab.begin ()->second .size ();
284+ if (Dim == 0 )
285+ return createStringError (errc::illegal_byte_sequence,
286+ " Dimension of '" + std::string (Key) +
287+ " ' section of the vocabulary is zero" );
288+
289+ if (!std::all_of (TargetVocab.begin (), TargetVocab.end (),
290+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
291+ return Entry.second .size () == Dim;
292+ }))
293+ return createStringError (
294+ errc::illegal_byte_sequence,
295+ " All vectors in the '" + std::string (Key) +
296+ " ' section of the vocabulary are not of the same dimension" );
297+
298+ return Error::success ();
299+ };
300+
254301// FIXME: Make this optional. We can avoid file reads
255302// by auto-generating a default vocabulary during the build time.
256303Error IR2VecVocabAnalysis::readVocabulary () {
@@ -259,32 +306,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259306 return createFileError (VocabFile, BufOrError.getError ());
260307
261308 auto Content = BufOrError.get ()->getBuffer ();
262- json::Path::Root Path ( " " );
309+
263310 Expected<json::Value> ParsedVocabValue = json::parse (Content);
264311 if (!ParsedVocabValue)
265312 return ParsedVocabValue.takeError ();
266313
267- bool Res = json::fromJSON (*ParsedVocabValue, Vocabulary, Path);
268- if (!Res)
269- return createStringError (errc::illegal_byte_sequence,
270- " Unable to parse the vocabulary" );
314+ ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
315+ unsigned OpcodeDim, TypeDim, ArgDim;
316+ if (auto Err = parseVocabSection (" Opcodes" , *ParsedVocabValue, OpcodeVocab,
317+ OpcodeDim))
318+ return Err;
271319
272- if (Vocabulary. empty ())
273- return createStringError (errc::illegal_byte_sequence,
274- " Vocabulary is empty " ) ;
320+ if (auto Err =
321+ parseVocabSection ( " Types " , *ParsedVocabValue, TypeVocab, TypeDim))
322+ return Err ;
275323
276- unsigned Dim = Vocabulary.begin ()->second .size ();
277- if (Dim == 0 )
324+ if (auto Err =
325+ parseVocabSection (" Arguments" , *ParsedVocabValue, ArgVocab, ArgDim))
326+ return Err;
327+
328+ if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278329 return createStringError (errc::illegal_byte_sequence,
279- " Dimension of vocabulary is zero " );
330+ " Vocabulary sections have different dimensions " );
280331
281- if (!std::all_of (Vocabulary.begin (), Vocabulary.end (),
282- [Dim](const std::pair<StringRef, Embedding> &Entry) {
283- return Entry.second .size () == Dim;
284- }))
285- return createStringError (
286- errc::illegal_byte_sequence,
287- " All vectors in the vocabulary are not of the same dimension" );
332+ auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
333+ for (auto &Entry : Vocab)
334+ Entry.second *= Weight;
335+ };
336+ scaleVocabSection (OpcodeVocab, OpcWeight);
337+ scaleVocabSection (TypeVocab, TypeWeight);
338+ scaleVocabSection (ArgVocab, ArgWeight);
339+
340+ Vocabulary.insert (OpcodeVocab.begin (), OpcodeVocab.end ());
341+ Vocabulary.insert (TypeVocab.begin (), TypeVocab.end ());
342+ Vocabulary.insert (ArgVocab.begin (), ArgVocab.end ());
288343
289344 return Error::success ();
290345}
@@ -304,7 +359,7 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304359IR2VecVocabAnalysis::Result
305360IR2VecVocabAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
306361 auto Ctx = &M.getContext ();
307- // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
362+
308363 // If vocabulary is already populated by the constructor, use it.
309364 if (!Vocabulary.empty ())
310365 return IR2VecVocabResult (std::move (Vocabulary));
@@ -323,16 +378,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323378}
324379
325380// ==----------------------------------------------------------------------===//
326- // IR2VecPrinterPass
381+ // Printer Passes
327382// ===----------------------------------------------------------------------===//
328383
329- void IR2VecPrinterPass::printVector (const Embedding &Vec) const {
330- OS << " [" ;
331- for (const auto &Elem : Vec)
332- OS << " " << format (" %.2f" , Elem) << " " ;
333- OS << " ]\n " ;
334- }
335-
336384PreservedAnalyses IR2VecPrinterPass::run (Module &M,
337385 ModuleAnalysisManager &MAM) {
338386 auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
@@ -353,15 +401,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353401
354402 OS << " IR2Vec embeddings for function " << F.getName () << " :\n " ;
355403 OS << " Function vector: " ;
356- printVector ( Emb->getFunctionVector ());
404+ Emb->getFunctionVector (). print (OS );
357405
358406 OS << " Basic block vectors:\n " ;
359407 const auto &BBMap = Emb->getBBVecMap ();
360408 for (const BasicBlock &BB : F) {
361409 auto It = BBMap.find (&BB);
362410 if (It != BBMap.end ()) {
363411 OS << " Basic block: " << BB.getName () << " :\n " ;
364- printVector ( It->second );
412+ It->second . print (OS );
365413 }
366414 }
367415
@@ -373,10 +421,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373421 if (It != InstMap.end ()) {
374422 OS << " Instruction: " ;
375423 I.print (OS);
376- printVector ( It->second );
424+ It->second . print (OS );
377425 }
378426 }
379427 }
380428 }
381429 return PreservedAnalyses::all ();
382430}
431+
432+ PreservedAnalyses IR2VecVocabPrinterPass::run (Module &M,
433+ ModuleAnalysisManager &MAM) {
434+ auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
435+ assert (IR2VecVocabResult.isValid () && " IR2Vec Vocabulary is invalid" );
436+
437+ auto Vocab = IR2VecVocabResult.getVocabulary ();
438+ for (const auto &Entry : Vocab) {
439+ OS << " Key: " << Entry.first << " : " ;
440+ Entry.second .print (OS);
441+ }
442+
443+ return PreservedAnalyses::all ();
444+ }
0 commit comments