@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
8585 return *this ;
8686}
8787
88+ Embedding &Embedding::operator *=(double Factor) {
89+ std::transform (this ->begin (), this ->end (), this ->begin (),
90+ [Factor](double Elem) { return Elem * Factor; });
91+ return *this ;
92+ }
93+
8894Embedding &Embedding::scaleAndAdd (const Embedding &Src, float Factor) {
8995 assert (this ->size () == Src.size () && " Vectors must have the same dimension" );
9096 for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101107 return true ;
102108}
103109
110+ void Embedding::print (raw_ostream &OS) const {
111+ OS << " [" ;
112+ for (const auto &Elem : Data)
113+ OS << " " << format (" %.2f" , Elem) << " " ;
114+ OS << " ]\n " ;
115+ }
116+
104117// ==----------------------------------------------------------------------===//
105118// Embedder and its subclasses
106119// ===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196209 for (const auto &I : BB.instructionsWithoutDebug ()) {
197210 Embedding InstVector (Dimension, 0 );
198211
199- const auto OpcVec = lookupVocab (I.getOpcodeName ());
200- InstVector.scaleAndAdd (OpcVec, OpcWeight);
201-
202212 // FIXME: Currently lookups are string based. Use numeric Keys
203213 // for efficiency.
204- const auto Type = I.getType ();
205- const auto TypeVec = getTypeEmbedding (Type);
206- InstVector.scaleAndAdd (TypeVec, TypeWeight);
207-
214+ InstVector += lookupVocab (I.getOpcodeName ());
215+ InstVector += getTypeEmbedding (I.getType ());
208216 for (const auto &Op : I.operands ()) {
209- const auto OperandVec = getOperandEmbedding (Op.get ());
210- InstVector.scaleAndAdd (OperandVec, ArgWeight);
217+ InstVector += getOperandEmbedding (Op.get ());
211218 }
212219 InstVecMap[&I] = InstVector;
213220 BBVector += InstVector;
@@ -251,6 +258,43 @@ bool IR2VecVocabResult::invalidate(
251258 return !(PAC.preservedWhenStateless ());
252259}
253260
261+ Error IR2VecVocabAnalysis::parseVocabSection (
262+ StringRef Key, const json::Value &ParsedVocabValue,
263+ ir2vec::Vocab &TargetVocab, unsigned &Dim) {
264+ json::Path::Root Path (" " );
265+ const json::Object *RootObj = ParsedVocabValue.getAsObject ();
266+ if (!RootObj)
267+ return createStringError (errc::invalid_argument,
268+ " JSON root is not an object" );
269+
270+ const json::Value *SectionValue = RootObj->get (Key);
271+ if (!SectionValue)
272+ return createStringError (errc::invalid_argument,
273+ " Missing '" + std::string (Key) +
274+ " ' section in vocabulary file" );
275+ if (!json::fromJSON (*SectionValue, TargetVocab, Path))
276+ return createStringError (errc::illegal_byte_sequence,
277+ " Unable to parse '" + std::string (Key) +
278+ " ' section from vocabulary" );
279+
280+ Dim = TargetVocab.begin ()->second .size ();
281+ if (Dim == 0 )
282+ return createStringError (errc::illegal_byte_sequence,
283+ " Dimension of '" + std::string (Key) +
284+ " ' section of the vocabulary is zero" );
285+
286+ if (!std::all_of (TargetVocab.begin (), TargetVocab.end (),
287+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
288+ return Entry.second .size () == Dim;
289+ }))
290+ return createStringError (
291+ errc::illegal_byte_sequence,
292+ " All vectors in the '" + std::string (Key) +
293+ " ' section of the vocabulary are not of the same dimension" );
294+
295+ return Error::success ();
296+ };
297+
254298// FIXME: Make this optional. We can avoid file reads
255299// by auto-generating a default vocabulary during the build time.
256300Error IR2VecVocabAnalysis::readVocabulary () {
@@ -259,32 +303,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259303 return createFileError (VocabFile, BufOrError.getError ());
260304
261305 auto Content = BufOrError.get ()->getBuffer ();
262- json::Path::Root Path ( " " );
306+
263307 Expected<json::Value> ParsedVocabValue = json::parse (Content);
264308 if (!ParsedVocabValue)
265309 return ParsedVocabValue.takeError ();
266310
267- bool Res = json::fromJSON (*ParsedVocabValue, Vocabulary, Path);
268- if (!Res)
269- return createStringError (errc::illegal_byte_sequence,
270- " Unable to parse the vocabulary" );
311+ ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
312+ unsigned OpcodeDim = 0 , TypeDim = 0 , ArgDim = 0 ;
313+ if (auto Err = parseVocabSection (" Opcodes" , *ParsedVocabValue, OpcodeVocab,
314+ OpcodeDim))
315+ return Err;
271316
272- if (Vocabulary. empty ())
273- return createStringError (errc::illegal_byte_sequence,
274- " Vocabulary is empty " ) ;
317+ if (auto Err =
318+ parseVocabSection ( " Types " , *ParsedVocabValue, TypeVocab, TypeDim))
319+ return Err ;
275320
276- unsigned Dim = Vocabulary.begin ()->second .size ();
277- if (Dim == 0 )
321+ if (auto Err =
322+ parseVocabSection (" Arguments" , *ParsedVocabValue, ArgVocab, ArgDim))
323+ return Err;
324+
325+ if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278326 return createStringError (errc::illegal_byte_sequence,
279- " Dimension of vocabulary is zero " );
327+ " Vocabulary sections have different dimensions " );
280328
281- if (!std::all_of (Vocabulary.begin (), Vocabulary.end (),
282- [Dim](const std::pair<StringRef, Embedding> &Entry) {
283- return Entry.second .size () == Dim;
284- }))
285- return createStringError (
286- errc::illegal_byte_sequence,
287- " All vectors in the vocabulary are not of the same dimension" );
329+ auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
330+ for (auto &Entry : Vocab)
331+ Entry.second *= Weight;
332+ };
333+ scaleVocabSection (OpcodeVocab, OpcWeight);
334+ scaleVocabSection (TypeVocab, TypeWeight);
335+ scaleVocabSection (ArgVocab, ArgWeight);
336+
337+ Vocabulary.insert (OpcodeVocab.begin (), OpcodeVocab.end ());
338+ Vocabulary.insert (TypeVocab.begin (), TypeVocab.end ());
339+ Vocabulary.insert (ArgVocab.begin (), ArgVocab.end ());
288340
289341 return Error::success ();
290342}
@@ -304,7 +356,6 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304356IR2VecVocabAnalysis::Result
305357IR2VecVocabAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
306358 auto Ctx = &M.getContext ();
307- // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
308359 // If vocabulary is already populated by the constructor, use it.
309360 if (!Vocabulary.empty ())
310361 return IR2VecVocabResult (std::move (Vocabulary));
@@ -323,16 +374,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323374}
324375
325376// ==----------------------------------------------------------------------===//
326- // IR2VecPrinterPass
377+ // Printer Passes
327378// ===----------------------------------------------------------------------===//
328379
329- void IR2VecPrinterPass::printVector (const Embedding &Vec) const {
330- OS << " [" ;
331- for (const auto &Elem : Vec)
332- OS << " " << format (" %.2f" , Elem) << " " ;
333- OS << " ]\n " ;
334- }
335-
336380PreservedAnalyses IR2VecPrinterPass::run (Module &M,
337381 ModuleAnalysisManager &MAM) {
338382 auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
@@ -353,15 +397,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353397
354398 OS << " IR2Vec embeddings for function " << F.getName () << " :\n " ;
355399 OS << " Function vector: " ;
356- printVector ( Emb->getFunctionVector ());
400+ Emb->getFunctionVector (). print (OS );
357401
358402 OS << " Basic block vectors:\n " ;
359403 const auto &BBMap = Emb->getBBVecMap ();
360404 for (const BasicBlock &BB : F) {
361405 auto It = BBMap.find (&BB);
362406 if (It != BBMap.end ()) {
363407 OS << " Basic block: " << BB.getName () << " :\n " ;
364- printVector ( It->second );
408+ It->second . print (OS );
365409 }
366410 }
367411
@@ -373,10 +417,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373417 if (It != InstMap.end ()) {
374418 OS << " Instruction: " ;
375419 I.print (OS);
376- printVector ( It->second );
420+ It->second . print (OS );
377421 }
378422 }
379423 }
380424 }
381425 return PreservedAnalyses::all ();
382426}
427+
428+ PreservedAnalyses IR2VecVocabPrinterPass::run (Module &M,
429+ ModuleAnalysisManager &MAM) {
430+ auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
431+ assert (IR2VecVocabResult.isValid () && " IR2Vec Vocabulary is invalid" );
432+
433+ auto Vocab = IR2VecVocabResult.getVocabulary ();
434+ for (const auto &Entry : Vocab) {
435+ OS << " Key: " << Entry.first << " : " ;
436+ Entry.second .print (OS);
437+ }
438+
439+ return PreservedAnalyses::all ();
440+ }
0 commit comments