45
45
#include " llvm/Support/JSON.h"
46
46
#include < array>
47
47
#include < map>
48
+ #include < optional>
48
49
49
50
namespace llvm {
50
51
@@ -144,6 +145,73 @@ struct Embedding {
144
145
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145
146
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146
147
148
+ // / Generic storage class for section-based vocabularies.
149
+ // / VocabStorage provides a generic foundation for storing and accessing
150
+ // / embeddings organized into sections.
151
+ class VocabStorage {
152
+ private:
153
+ // / Section-based storage
154
+ std::vector<std::vector<Embedding>> Sections;
155
+
156
+ const size_t TotalSize;
157
+ const unsigned Dimension;
158
+
159
+ public:
160
+ // / Default constructor creates empty storage (invalid state)
161
+ VocabStorage () : Sections(), TotalSize(0 ), Dimension(0 ) {}
162
+
163
+ // / Create a VocabStorage with pre-organized section data
164
+ VocabStorage (std::vector<std::vector<Embedding>> &&SectionData);
165
+
166
+ VocabStorage (VocabStorage &&) = default ;
167
+ VocabStorage &operator =(VocabStorage &&) = delete ;
168
+
169
+ VocabStorage (const VocabStorage &) = delete ;
170
+ VocabStorage &operator =(const VocabStorage &) = delete ;
171
+
172
+ // / Get total number of entries across all sections
173
+ size_t size () const { return TotalSize; }
174
+
175
+ // / Get number of sections
176
+ unsigned getNumSections () const {
177
+ return static_cast <unsigned >(Sections.size ());
178
+ }
179
+
180
+ // / Section-based access: Storage[sectionId][localIndex]
181
+ const std::vector<Embedding> &operator [](unsigned SectionId) const {
182
+ assert (SectionId < Sections.size () && " Invalid section ID" );
183
+ return Sections[SectionId];
184
+ }
185
+
186
+ // / Get vocabulary dimension
187
+ unsigned getDimension () const { return Dimension; }
188
+
189
+ // / Check if vocabulary is valid (has data)
190
+ bool isValid () const { return TotalSize > 0 ; }
191
+
192
+ // / Iterator support for section-based access
193
+ class const_iterator {
194
+ const VocabStorage *Storage;
195
+ unsigned SectionId = 0 ;
196
+ size_t LocalIndex = 0 ;
197
+
198
+ public:
199
+ const_iterator (const VocabStorage *Storage, unsigned SectionId,
200
+ size_t LocalIndex)
201
+ : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202
+
203
+ LLVM_ABI const Embedding &operator *() const ;
204
+ LLVM_ABI const_iterator &operator ++();
205
+ LLVM_ABI bool operator ==(const const_iterator &Other) const ;
206
+ LLVM_ABI bool operator !=(const const_iterator &Other) const ;
207
+ };
208
+
209
+ const_iterator begin () const { return const_iterator (this , 0 , 0 ); }
210
+ const_iterator end () const {
211
+ return const_iterator (this , getNumSections (), 0 );
212
+ }
213
+ };
214
+
147
215
// / Class for storing and accessing the IR2Vec vocabulary.
148
216
// / The Vocabulary class manages seed embeddings for LLVM IR entities. The
149
217
// / seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164
232
class Vocabulary {
165
233
friend class llvm ::IR2VecVocabAnalysis;
166
234
167
- // Vocabulary Slot Layout:
235
+ // Vocabulary Layout:
168
236
// +----------------+------------------------------------------------------+
169
237
// | Entity Type | Index Range |
170
238
// +----------------+------------------------------------------------------+
@@ -180,8 +248,16 @@ class Vocabulary {
180
248
// and improves learning. Operands include Comparison predicates
181
249
// (ICmp/FCmp) along with other operand types. This can be extended to
182
250
// include other specializations in future.
183
- using VocabVector = std::vector<ir2vec::Embedding>;
184
- VocabVector Vocab;
251
+ enum class Section : unsigned {
252
+ Opcodes = 0 ,
253
+ CanonicalTypes = 1 ,
254
+ Operands = 2 ,
255
+ Predicates = 3 ,
256
+ MaxSections
257
+ };
258
+
259
+ // Use section-based storage for better organization and efficiency
260
+ VocabStorage Storage;
185
261
186
262
static constexpr unsigned NumICmpPredicates =
187
263
static_cast <unsigned >(CmpInst::LAST_ICMP_PREDICATE) -
@@ -233,10 +309,23 @@ class Vocabulary {
233
309
NumICmpPredicates + NumFCmpPredicates;
234
310
235
311
Vocabulary () = default ;
236
- LLVM_ABI Vocabulary (VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
312
+ LLVM_ABI Vocabulary (VocabStorage &&Storage) : Storage(std::move(Storage)) {}
313
+
314
+ Vocabulary (const Vocabulary &) = delete ;
315
+ Vocabulary &operator =(const Vocabulary &) = delete ;
316
+
317
+ Vocabulary (Vocabulary &&) = default ;
318
+ Vocabulary &operator =(Vocabulary &&Other) = delete ;
319
+
320
+ LLVM_ABI bool isValid () const {
321
+ return Storage.size () == NumCanonicalEntries;
322
+ }
323
+
324
+ LLVM_ABI unsigned getDimension () const {
325
+ assert (isValid () && " IR2Vec Vocabulary is invalid" );
326
+ return Storage.getDimension ();
327
+ }
237
328
238
- LLVM_ABI bool isValid () const { return Vocab.size () == NumCanonicalEntries; };
239
- LLVM_ABI unsigned getDimension () const ;
240
329
// / Total number of entries (opcodes + canonicalized types + operand kinds +
241
330
// / predicates)
242
331
static constexpr size_t getCanonicalSize () { return NumCanonicalEntries; }
@@ -245,59 +334,91 @@ class Vocabulary {
245
334
LLVM_ABI static StringRef getVocabKeyForOpcode (unsigned Opcode);
246
335
247
336
// / Function to get vocabulary key for a given TypeID
248
- LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
337
+ LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID) {
338
+ return getVocabKeyForCanonicalTypeID (getCanonicalTypeID (TypeID));
339
+ }
249
340
250
341
// / Function to get vocabulary key for a given OperandKind
251
- LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind);
342
+ LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind) {
343
+ unsigned Index = static_cast <unsigned >(Kind);
344
+ assert (Index < MaxOperandKinds && " Invalid OperandKind" );
345
+ return OperandKindNames[Index];
346
+ }
252
347
253
348
// / Function to classify an operand into OperandKind
254
349
LLVM_ABI static OperandKind getOperandKind (const Value *Op);
255
350
256
351
// / Function to get vocabulary key for a given predicate
257
352
LLVM_ABI static StringRef getVocabKeyForPredicate (CmpInst::Predicate P);
258
353
259
- // / Functions to return the slot index or position of a given Opcode, TypeID,
260
- // / or OperandKind in the vocabulary.
261
- LLVM_ABI static unsigned getSlotIndex (unsigned Opcode);
262
- LLVM_ABI static unsigned getSlotIndex (Type::TypeID TypeID);
263
- LLVM_ABI static unsigned getSlotIndex (const Value &Op);
264
- LLVM_ABI static unsigned getSlotIndex (CmpInst::Predicate P);
354
+ // / Functions to return flat index
355
+ LLVM_ABI static unsigned getIndex (unsigned Opcode) {
356
+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
357
+ return Opcode - 1 ; // Convert to zero-based index
358
+ }
359
+
360
+ LLVM_ABI static unsigned getIndex (Type::TypeID TypeID) {
361
+ assert (static_cast <unsigned >(TypeID) < MaxTypeIDs && " Invalid type ID" );
362
+ return MaxOpcodes + static_cast <unsigned >(getCanonicalTypeID (TypeID));
363
+ }
364
+
365
+ LLVM_ABI static unsigned getIndex (const Value &Op) {
366
+ unsigned Index = static_cast <unsigned >(getOperandKind (&Op));
367
+ assert (Index < MaxOperandKinds && " Invalid OperandKind" );
368
+ return OperandBaseOffset + Index;
369
+ }
370
+
371
+ LLVM_ABI static unsigned getIndex (CmpInst::Predicate P) {
372
+ return PredicateBaseOffset + getPredicateLocalIndex (P);
373
+ }
265
374
266
375
// / Accessors to get the embedding for a given entity.
267
- LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
268
- LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
269
- LLVM_ABI const ir2vec::Embedding &operator [](const Value &Arg) const ;
270
- LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const ;
376
+ LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const {
377
+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
378
+ return Storage[static_cast <unsigned >(Section::Opcodes)][Opcode - 1 ];
379
+ }
380
+
381
+ LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeID) const {
382
+ assert (static_cast <unsigned >(TypeID) < MaxTypeIDs && " Invalid type ID" );
383
+ unsigned LocalIndex = static_cast <unsigned >(getCanonicalTypeID (TypeID));
384
+ return Storage[static_cast <unsigned >(Section::CanonicalTypes)][LocalIndex];
385
+ }
386
+
387
+ LLVM_ABI const ir2vec::Embedding &operator [](const Value &Arg) const {
388
+ unsigned LocalIndex = static_cast <unsigned >(getOperandKind (&Arg));
389
+ assert (LocalIndex < MaxOperandKinds && " Invalid OperandKind" );
390
+ return Storage[static_cast <unsigned >(Section::Operands)][LocalIndex];
391
+ }
392
+
393
+ LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const {
394
+ unsigned LocalIndex = getPredicateLocalIndex (P);
395
+ return Storage[static_cast <unsigned >(Section::Predicates)][LocalIndex];
396
+ }
271
397
272
398
// / Const Iterator type aliases
273
- using const_iterator = VocabVector::const_iterator;
399
+ using const_iterator = VocabStorage::const_iterator;
400
+
274
401
const_iterator begin () const {
275
402
assert (isValid () && " IR2Vec Vocabulary is invalid" );
276
- return Vocab .begin ();
403
+ return Storage .begin ();
277
404
}
278
405
279
- const_iterator cbegin () const {
280
- assert (isValid () && " IR2Vec Vocabulary is invalid" );
281
- return Vocab.cbegin ();
282
- }
406
+ const_iterator cbegin () const { return begin (); }
283
407
284
408
const_iterator end () const {
285
409
assert (isValid () && " IR2Vec Vocabulary is invalid" );
286
- return Vocab .end ();
410
+ return Storage .end ();
287
411
}
288
412
289
- const_iterator cend () const {
290
- assert (isValid () && " IR2Vec Vocabulary is invalid" );
291
- return Vocab.cend ();
292
- }
413
+ const_iterator cend () const { return end (); }
293
414
294
415
// / Returns the string key for a given index position in the vocabulary.
295
416
// / This is useful for debugging or printing the vocabulary. Do not use this
296
417
// / for embedding generation as string based lookups are inefficient.
297
418
LLVM_ABI static StringRef getStringKey (unsigned Pos);
298
419
299
420
// / Create a dummy vocabulary for testing purposes.
300
- LLVM_ABI static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
421
+ LLVM_ABI static VocabStorage createDummyVocabForTest (unsigned Dim = 1 );
301
422
302
423
LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
303
424
ModuleAnalysisManager::Invalidator &Inv) const ;
@@ -306,12 +427,16 @@ class Vocabulary {
306
427
constexpr static unsigned NumCanonicalEntries =
307
428
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
308
429
309
- // Base offsets for slot layout to simplify index computation
430
+ // Base offsets for flat index computation
310
431
constexpr static unsigned OperandBaseOffset =
311
432
MaxOpcodes + MaxCanonicalTypeIDs;
312
433
constexpr static unsigned PredicateBaseOffset =
313
434
OperandBaseOffset + MaxOperandKinds;
314
435
436
+ // / Functions for predicate index calculations
437
+ static unsigned getPredicateLocalIndex (CmpInst::Predicate P);
438
+ static CmpInst::Predicate getPredicateFromLocalIndex (unsigned LocalIndex);
439
+
315
440
// / String mappings for CanonicalTypeID values
316
441
static constexpr StringLiteral CanonicalTypeNames[] = {
317
442
" FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
@@ -358,15 +483,26 @@ class Vocabulary {
358
483
359
484
// / Function to get vocabulary key for canonical type by enum
360
485
LLVM_ABI static StringRef
361
- getVocabKeyForCanonicalTypeID (CanonicalTypeID CType);
486
+ getVocabKeyForCanonicalTypeID (CanonicalTypeID CType) {
487
+ unsigned Index = static_cast <unsigned >(CType);
488
+ assert (Index < MaxCanonicalTypeIDs && " Invalid CanonicalTypeID" );
489
+ return CanonicalTypeNames[Index];
490
+ }
362
491
363
492
// / Function to convert TypeID to CanonicalTypeID
364
- LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID);
493
+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID) {
494
+ unsigned Index = static_cast <unsigned >(TypeID);
495
+ assert (Index < MaxTypeIDs && " Invalid TypeID" );
496
+ return TypeIDMapping[Index];
497
+ }
365
498
366
499
// / Function to get the predicate enum value for a given index. Index is
367
500
// / relative to the predicates section of the vocabulary. E.g., Index 0
368
501
// / corresponds to the first predicate.
369
- LLVM_ABI static CmpInst::Predicate getPredicate (unsigned Index);
502
+ LLVM_ABI static CmpInst::Predicate getPredicate (unsigned Index) {
503
+ assert (Index < MaxPredicateKinds && " Invalid predicate index" );
504
+ return getPredicateFromLocalIndex (Index);
505
+ }
370
506
};
371
507
372
508
// / Embedder provides the interface to generate embeddings (vector
@@ -459,22 +595,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
459
595
// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
460
596
// / its corresponding embedding.
461
597
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
462
- using VocabVector = std::vector<ir2vec::Embedding>;
463
598
using VocabMap = std::map<std::string, ir2vec::Embedding>;
464
- VocabMap OpcVocab, TypeVocab, ArgVocab;
465
- VocabVector Vocab;
599
+ std::optional<ir2vec::VocabStorage> Vocab;
466
600
467
- Error readVocabulary ();
601
+ Error readVocabulary (VocabMap &OpcVocab, VocabMap &TypeVocab,
602
+ VocabMap &ArgVocab);
468
603
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
469
604
VocabMap &TargetVocab, unsigned &Dim);
470
- void generateNumMappedVocab ();
605
+ void generateVocabStorage (VocabMap &OpcVocab, VocabMap &TypeVocab,
606
+ VocabMap &ArgVocab);
471
607
void emitError (Error Err, LLVMContext &Ctx);
472
608
473
609
public:
474
610
LLVM_ABI static AnalysisKey Key;
475
611
IR2VecVocabAnalysis () = default ;
476
- LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector & Vocab);
477
- LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector && Vocab);
612
+ LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::VocabStorage && Vocab)
613
+ : Vocab(std::move(Vocab)) {}
478
614
using Result = ir2vec::Vocabulary;
479
615
LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
480
616
};
0 commit comments