Skip to content

Commit ed1d954

Browse files
authored
[IR2Vec] Refactor vocabulary to use section-based storage (#158376)
Refactored IR2Vec vocabulary and introduced IR (semantics) agnostic `VocabStorage` - `Vocabulary` *has-a* `VocabStorage` - `Vocabulary` deals with LLVM IR specific entities. This would help in efficient reuse of parts of the logic for MIR. - Storage uses a section-based approach instead of a flat vector, improving organization and access patterns.
1 parent 4aaf6d1 commit ed1d954

File tree

6 files changed

+648
-194
lines changed

6 files changed

+648
-194
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 177 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "llvm/Support/JSON.h"
4646
#include <array>
4747
#include <map>
48+
#include <optional>
4849

4950
namespace llvm {
5051

@@ -144,6 +145,73 @@ struct Embedding {
144145
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145146
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146147

148+
/// Generic storage class for section-based vocabularies.
149+
/// VocabStorage provides a generic foundation for storing and accessing
150+
/// embeddings organized into sections.
151+
class VocabStorage {
152+
private:
153+
/// Section-based storage
154+
std::vector<std::vector<Embedding>> Sections;
155+
156+
const size_t TotalSize;
157+
const unsigned Dimension;
158+
159+
public:
160+
/// Default constructor creates empty storage (invalid state)
161+
VocabStorage() : Sections(), TotalSize(0), Dimension(0) {}
162+
163+
/// Create a VocabStorage with pre-organized section data
164+
VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
165+
166+
VocabStorage(VocabStorage &&) = default;
167+
VocabStorage &operator=(VocabStorage &&) = delete;
168+
169+
VocabStorage(const VocabStorage &) = delete;
170+
VocabStorage &operator=(const VocabStorage &) = delete;
171+
172+
/// Get total number of entries across all sections
173+
size_t size() const { return TotalSize; }
174+
175+
/// Get number of sections
176+
unsigned getNumSections() const {
177+
return static_cast<unsigned>(Sections.size());
178+
}
179+
180+
/// Section-based access: Storage[sectionId][localIndex]
181+
const std::vector<Embedding> &operator[](unsigned SectionId) const {
182+
assert(SectionId < Sections.size() && "Invalid section ID");
183+
return Sections[SectionId];
184+
}
185+
186+
/// Get vocabulary dimension
187+
unsigned getDimension() const { return Dimension; }
188+
189+
/// Check if vocabulary is valid (has data)
190+
bool isValid() const { return TotalSize > 0; }
191+
192+
/// Iterator support for section-based access
193+
class const_iterator {
194+
const VocabStorage *Storage;
195+
unsigned SectionId = 0;
196+
size_t LocalIndex = 0;
197+
198+
public:
199+
const_iterator(const VocabStorage *Storage, unsigned SectionId,
200+
size_t LocalIndex)
201+
: Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202+
203+
LLVM_ABI const Embedding &operator*() const;
204+
LLVM_ABI const_iterator &operator++();
205+
LLVM_ABI bool operator==(const const_iterator &Other) const;
206+
LLVM_ABI bool operator!=(const const_iterator &Other) const;
207+
};
208+
209+
const_iterator begin() const { return const_iterator(this, 0, 0); }
210+
const_iterator end() const {
211+
return const_iterator(this, getNumSections(), 0);
212+
}
213+
};
214+
147215
/// Class for storing and accessing the IR2Vec vocabulary.
148216
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
149217
/// seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164232
class Vocabulary {
165233
friend class llvm::IR2VecVocabAnalysis;
166234

167-
// Vocabulary Slot Layout:
235+
// Vocabulary Layout:
168236
// +----------------+------------------------------------------------------+
169237
// | Entity Type | Index Range |
170238
// +----------------+------------------------------------------------------+
@@ -180,8 +248,16 @@ class Vocabulary {
180248
// and improves learning. Operands include Comparison predicates
181249
// (ICmp/FCmp) along with other operand types. This can be extended to
182250
// include other specializations in future.
183-
using VocabVector = std::vector<ir2vec::Embedding>;
184-
VocabVector Vocab;
251+
enum class Section : unsigned {
252+
Opcodes = 0,
253+
CanonicalTypes = 1,
254+
Operands = 2,
255+
Predicates = 3,
256+
MaxSections
257+
};
258+
259+
// Use section-based storage for better organization and efficiency
260+
VocabStorage Storage;
185261

186262
static constexpr unsigned NumICmpPredicates =
187263
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
@@ -233,10 +309,23 @@ class Vocabulary {
233309
NumICmpPredicates + NumFCmpPredicates;
234310

235311
Vocabulary() = default;
236-
LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
312+
LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {}
313+
314+
Vocabulary(const Vocabulary &) = delete;
315+
Vocabulary &operator=(const Vocabulary &) = delete;
316+
317+
Vocabulary(Vocabulary &&) = default;
318+
Vocabulary &operator=(Vocabulary &&Other) = delete;
319+
320+
LLVM_ABI bool isValid() const {
321+
return Storage.size() == NumCanonicalEntries;
322+
}
323+
324+
LLVM_ABI unsigned getDimension() const {
325+
assert(isValid() && "IR2Vec Vocabulary is invalid");
326+
return Storage.getDimension();
327+
}
237328

238-
LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
239-
LLVM_ABI unsigned getDimension() const;
240329
/// Total number of entries (opcodes + canonicalized types + operand kinds +
241330
/// predicates)
242331
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
@@ -245,59 +334,91 @@ class Vocabulary {
245334
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
246335

247336
/// Function to get vocabulary key for a given TypeID
248-
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
337+
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID) {
338+
return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
339+
}
249340

250341
/// Function to get vocabulary key for a given OperandKind
251-
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
342+
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind) {
343+
unsigned Index = static_cast<unsigned>(Kind);
344+
assert(Index < MaxOperandKinds && "Invalid OperandKind");
345+
return OperandKindNames[Index];
346+
}
252347

253348
/// Function to classify an operand into OperandKind
254349
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
255350

256351
/// Function to get vocabulary key for a given predicate
257352
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
258353

259-
/// Functions to return the slot index or position of a given Opcode, TypeID,
260-
/// or OperandKind in the vocabulary.
261-
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
262-
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
263-
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
264-
LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
354+
/// Functions to return flat index
355+
LLVM_ABI static unsigned getIndex(unsigned Opcode) {
356+
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
357+
return Opcode - 1; // Convert to zero-based index
358+
}
359+
360+
LLVM_ABI static unsigned getIndex(Type::TypeID TypeID) {
361+
assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
362+
return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
363+
}
364+
365+
LLVM_ABI static unsigned getIndex(const Value &Op) {
366+
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
367+
assert(Index < MaxOperandKinds && "Invalid OperandKind");
368+
return OperandBaseOffset + Index;
369+
}
370+
371+
LLVM_ABI static unsigned getIndex(CmpInst::Predicate P) {
372+
return PredicateBaseOffset + getPredicateLocalIndex(P);
373+
}
265374

266375
/// Accessors to get the embedding for a given entity.
267-
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
268-
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
269-
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
270-
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
376+
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const {
377+
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
378+
return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
379+
}
380+
381+
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeID) const {
382+
assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
383+
unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID));
384+
return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
385+
}
386+
387+
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const {
388+
unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg));
389+
assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind");
390+
return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex];
391+
}
392+
393+
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const {
394+
unsigned LocalIndex = getPredicateLocalIndex(P);
395+
return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex];
396+
}
271397

272398
/// Const Iterator type aliases
273-
using const_iterator = VocabVector::const_iterator;
399+
using const_iterator = VocabStorage::const_iterator;
400+
274401
const_iterator begin() const {
275402
assert(isValid() && "IR2Vec Vocabulary is invalid");
276-
return Vocab.begin();
403+
return Storage.begin();
277404
}
278405

279-
const_iterator cbegin() const {
280-
assert(isValid() && "IR2Vec Vocabulary is invalid");
281-
return Vocab.cbegin();
282-
}
406+
const_iterator cbegin() const { return begin(); }
283407

284408
const_iterator end() const {
285409
assert(isValid() && "IR2Vec Vocabulary is invalid");
286-
return Vocab.end();
410+
return Storage.end();
287411
}
288412

289-
const_iterator cend() const {
290-
assert(isValid() && "IR2Vec Vocabulary is invalid");
291-
return Vocab.cend();
292-
}
413+
const_iterator cend() const { return end(); }
293414

294415
/// Returns the string key for a given index position in the vocabulary.
295416
/// This is useful for debugging or printing the vocabulary. Do not use this
296417
/// for embedding generation as string based lookups are inefficient.
297418
LLVM_ABI static StringRef getStringKey(unsigned Pos);
298419

299420
/// Create a dummy vocabulary for testing purposes.
300-
LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1);
421+
LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1);
301422

302423
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
303424
ModuleAnalysisManager::Invalidator &Inv) const;
@@ -306,12 +427,16 @@ class Vocabulary {
306427
constexpr static unsigned NumCanonicalEntries =
307428
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
308429

309-
// Base offsets for slot layout to simplify index computation
430+
// Base offsets for flat index computation
310431
constexpr static unsigned OperandBaseOffset =
311432
MaxOpcodes + MaxCanonicalTypeIDs;
312433
constexpr static unsigned PredicateBaseOffset =
313434
OperandBaseOffset + MaxOperandKinds;
314435

436+
/// Functions for predicate index calculations
437+
static unsigned getPredicateLocalIndex(CmpInst::Predicate P);
438+
static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);
439+
315440
/// String mappings for CanonicalTypeID values
316441
static constexpr StringLiteral CanonicalTypeNames[] = {
317442
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
@@ -358,15 +483,26 @@ class Vocabulary {
358483

359484
/// Function to get vocabulary key for canonical type by enum
360485
LLVM_ABI static StringRef
361-
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
486+
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
487+
unsigned Index = static_cast<unsigned>(CType);
488+
assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
489+
return CanonicalTypeNames[Index];
490+
}
362491

363492
/// Function to convert TypeID to CanonicalTypeID
364-
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
493+
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID) {
494+
unsigned Index = static_cast<unsigned>(TypeID);
495+
assert(Index < MaxTypeIDs && "Invalid TypeID");
496+
return TypeIDMapping[Index];
497+
}
365498

366499
/// Function to get the predicate enum value for a given index. Index is
367500
/// relative to the predicates section of the vocabulary. E.g., Index 0
368501
/// corresponds to the first predicate.
369-
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
502+
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index) {
503+
assert(Index < MaxPredicateKinds && "Invalid predicate index");
504+
return getPredicateFromLocalIndex(Index);
505+
}
370506
};
371507

372508
/// Embedder provides the interface to generate embeddings (vector
@@ -459,22 +595,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
459595
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
460596
/// its corresponding embedding.
461597
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
462-
using VocabVector = std::vector<ir2vec::Embedding>;
463598
using VocabMap = std::map<std::string, ir2vec::Embedding>;
464-
VocabMap OpcVocab, TypeVocab, ArgVocab;
465-
VocabVector Vocab;
599+
std::optional<ir2vec::VocabStorage> Vocab;
466600

467-
Error readVocabulary();
601+
Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
602+
VocabMap &ArgVocab);
468603
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
469604
VocabMap &TargetVocab, unsigned &Dim);
470-
void generateNumMappedVocab();
605+
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
606+
VocabMap &ArgVocab);
471607
void emitError(Error Err, LLVMContext &Ctx);
472608

473609
public:
474610
LLVM_ABI static AnalysisKey Key;
475611
IR2VecVocabAnalysis() = default;
476-
LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
477-
LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
612+
LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
613+
: Vocab(std::move(Vocab)) {}
478614
using Result = ir2vec::Vocabulary;
479615
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
480616
};

0 commit comments

Comments
 (0)