Skip to content

Commit 9f33bc0

Browse files
committed
[RTG][Elaboration] Use malloc instead of IR for virtual registers and labels
1 parent 1ae382a commit 9f33bc0

File tree

1 file changed

+67
-78
lines changed

1 file changed

+67
-78
lines changed

lib/Dialect/RTG/Transforms/ElaborationPass.cpp

Lines changed: 67 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -89,56 +89,30 @@ struct BagStorage;
8989
struct SequenceStorage;
9090
struct RandomizedSequenceStorage;
9191
struct SetStorage;
92+
struct VirtualRegisterStorage;
93+
struct UniqueLabelStorage;
9294

93-
/// Represents a unique virtual register.
94-
struct VirtualRegister {
95-
VirtualRegister(uint64_t id, ArrayAttr allowedRegs)
96-
: id(id), allowedRegs(allowedRegs) {}
97-
98-
bool operator==(const VirtualRegister &other) const {
99-
assert(
100-
id != other.id ||
101-
allowedRegs == other.allowedRegs &&
102-
"instances with the same ID must have the same allowed registers");
103-
return id == other.id;
104-
}
105-
106-
// The ID of this virtual register.
107-
uint64_t id;
108-
109-
// The list of fixed registers allowed to be selected for this virtual
110-
// register.
111-
ArrayAttr allowedRegs;
112-
};
113-
95+
/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
96+
/// a label declaration instead of calling the builtin dialect constant
97+
/// materializer.
11498
struct LabelValue {
115-
LabelValue(StringAttr name, uint64_t id = 0) : name(name), id(id) {}
99+
LabelValue(StringAttr name) : name(name) {}
116100

117-
bool operator==(const LabelValue &other) const {
118-
return name == other.name && id == other.id;
119-
}
101+
bool operator==(const LabelValue &other) const { return name == other.name; }
120102

121-
/// The label name. For unique labels, this is just the prefix.
103+
/// The label name.
122104
StringAttr name;
123-
124-
/// Standard label declarations always have id=0
125-
uint64_t id;
126105
};
127106

128107
/// The abstract base class for elaborated values.
129108
using ElaboratorValue =
130109
std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
131-
RandomizedSequenceStorage *, SetStorage *, VirtualRegister,
132-
LabelValue>;
133-
134-
// NOLINTNEXTLINE(readability-identifier-naming)
135-
llvm::hash_code hash_value(const VirtualRegister &val) {
136-
return llvm::hash_value(val.id);
137-
}
110+
RandomizedSequenceStorage *, SetStorage *,
111+
VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue>;
138112

139113
// NOLINTNEXTLINE(readability-identifier-naming)
140114
llvm::hash_code hash_value(const LabelValue &val) {
141-
return llvm::hash_combine(val.id, val.name);
115+
return llvm::hash_value(val.name);
142116
}
143117

144118
// NOLINTNEXTLINE(readability-identifier-naming)
@@ -164,32 +138,16 @@ struct DenseMapInfo<bool> {
164138

165139
static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
166140
};
167-
168-
template <>
169-
struct DenseMapInfo<VirtualRegister> {
170-
static inline VirtualRegister getEmptyKey() {
171-
return VirtualRegister(0, ArrayAttr());
172-
}
173-
static inline VirtualRegister getTombstoneKey() {
174-
return VirtualRegister(~0, ArrayAttr());
175-
}
176-
static unsigned getHashValue(const VirtualRegister &val) {
177-
return llvm::hash_combine(val.id, val.allowedRegs);
178-
}
179-
180-
static bool isEqual(const VirtualRegister &lhs, const VirtualRegister &rhs) {
181-
return lhs == rhs;
182-
}
183-
};
184-
185141
template <>
186142
struct DenseMapInfo<LabelValue> {
187-
static inline LabelValue getEmptyKey() { return LabelValue(StringAttr(), 0); }
143+
static inline LabelValue getEmptyKey() {
144+
return DenseMapInfo<StringAttr>::getEmptyKey();
145+
}
188146
static inline LabelValue getTombstoneKey() {
189-
return LabelValue(StringAttr(), ~0);
147+
return DenseMapInfo<StringAttr>::getTombstoneKey();
190148
}
191149
static unsigned getHashValue(const LabelValue &val) {
192-
return llvm::hash_combine(val.name, val.id);
150+
return hash_value(val);
193151
}
194152

195153
static bool isEqual(const LabelValue &lhs, const LabelValue &rhs) {
@@ -351,6 +309,28 @@ struct RandomizedSequenceStorage {
351309
const SequenceStorage *sequence;
352310
};
353311

312+
/// Represents a unique virtual register.
313+
struct VirtualRegisterStorage {
314+
VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
315+
316+
// NOTE: we don't need an 'isEqual' function and 'hashcode' here because
317+
// VirtualRegisters are never internalized.
318+
319+
// The list of fixed registers allowed to be selected for this virtual
320+
// register.
321+
const ArrayAttr allowedRegs;
322+
};
323+
324+
struct UniqueLabelStorage {
325+
UniqueLabelStorage(StringAttr name) : name(name) {}
326+
327+
// NOTE: we don't need an 'isEqual' function and 'hashcode' here because
328+
// VirtualRegisters are never internalized.
329+
330+
/// The label name. For unique labels, this is just the prefix.
331+
const StringAttr name;
332+
};
333+
354334
/// An 'Internalizer' object internalizes storages and takes ownership of them.
355335
/// When the initializer object is destroyed, all owned storages are also
356336
/// deallocated and thus must not be accessed anymore.
@@ -375,6 +355,12 @@ class Internalizer {
375355
return storagePtr;
376356
}
377357

358+
template <typename StorageTy, typename... Args>
359+
StorageTy *create(Args &&...args) {
360+
return new (allocator.Allocate<StorageTy>())
361+
StorageTy(std::forward<Args>(args)...);
362+
}
363+
378364
private:
379365
template <typename StorageTy>
380366
DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
@@ -459,12 +445,16 @@ static void print(SetStorage *val, llvm::raw_ostream &os) {
459445
os << "} at " << val << ">";
460446
}
461447

462-
static void print(const VirtualRegister &val, llvm::raw_ostream &os) {
463-
os << "<virtual-register " << val.id << " " << val.allowedRegs << ">";
448+
static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
449+
os << "<virtual-register " << val << " " << val->allowedRegs << ">";
450+
}
451+
452+
static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
453+
os << "<unique-label " << val << " " << val->name << ">";
464454
}
465455

466456
static void print(const LabelValue &val, llvm::raw_ostream &os) {
467-
os << "<label " << val.id << " " << val.name << ">";
457+
os << "<label " << val.name << ">";
468458
}
469459

470460
static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
@@ -690,24 +680,26 @@ class Materializer {
690680
return builder.create<RandomizeSequenceOp>(loc, seq);
691681
}
692682

693-
Value visit(const VirtualRegister &val, Location loc,
683+
Value visit(VirtualRegisterStorage *val, Location loc,
694684
std::queue<RandomizedSequenceStorage *> &elabRequests,
695685
function_ref<InFlightDiagnostic()> emitError) {
696-
auto res = builder.create<VirtualRegisterOp>(loc, val.allowedRegs);
686+
Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
697687
materializedValues[val] = res;
698688
return res;
699689
}
700690

701-
Value visit(const LabelValue &val, Location loc,
691+
Value visit(UniqueLabelStorage *val, Location loc,
702692
std::queue<RandomizedSequenceStorage *> &elabRequests,
703693
function_ref<InFlightDiagnostic()> emitError) {
704-
if (val.id == 0) {
705-
auto res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
706-
materializedValues[val] = res;
707-
return res;
708-
}
694+
Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
695+
materializedValues[val] = res;
696+
return res;
697+
}
709698

710-
auto res = builder.create<LabelUniqueDeclOp>(loc, val.name, ValueRange());
699+
Value visit(const LabelValue &val, Location loc,
700+
std::queue<RandomizedSequenceStorage *> &elabRequests,
701+
function_ref<InFlightDiagnostic()> emitError) {
702+
Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
711703
materializedValues[val] = res;
712704
return res;
713705
}
@@ -749,9 +741,6 @@ struct ElaboratorSharedState {
749741
/// The worklist used to keep track of the test and sequence operations to
750742
/// make sure they are processed top-down (BFS traversal).
751743
std::queue<RandomizedSequenceStorage *> worklist;
752-
753-
uint64_t virtualRegisterID = 0;
754-
uint64_t uniqueLabelID = 1;
755744
};
756745

757746
/// A collection of state per RTG test.
@@ -1023,8 +1012,9 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
10231012
}
10241013

10251014
FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
1026-
state[op.getResult()] = VirtualRegister(sharedState.virtualRegisterID++,
1027-
op.getAllowedRegsAttr());
1015+
state[op.getResult()] =
1016+
sharedState.internalizer.create<VirtualRegisterStorage>(
1017+
op.getAllowedRegsAttr());
10281018
return DeletionKind::Delete;
10291019
}
10301020

@@ -1055,9 +1045,8 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
10551045
}
10561046

10571047
FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
1058-
state[op.getLabel()] = LabelValue(
1059-
substituteFormatString(op.getFormatStringAttr(), op.getArgs()),
1060-
sharedState.uniqueLabelID++);
1048+
state[op.getLabel()] = sharedState.internalizer.create<UniqueLabelStorage>(
1049+
substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
10611050
return DeletionKind::Delete;
10621051
}
10631052

0 commit comments

Comments
 (0)