Skip to content

Commit 035c13f

Browse files
committed
[RTG][Elaboration] Support interleave_sequences, factor our sequence inlining and label resolution
1 parent 9f33bc0 commit 035c13f

File tree

10 files changed

+478
-155
lines changed

10 files changed

+478
-155
lines changed

include/circt/Dialect/RTG/IR/RTGVisitors.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class RTGOpVisitor {
4747
RandomNumberInRangeOp,
4848
// Sequences
4949
SequenceOp, GetSequenceOp, SubstituteSequenceOp,
50-
RandomizeSequenceOp, EmbedSequenceOp,
50+
RandomizeSequenceOp, EmbedSequenceOp, InterleaveSequencesOp,
5151
// Sets
5252
SetCreateOp, SetSelectRandomOp, SetDifferenceOp, SetUnionOp,
5353
SetSizeOp>([&](auto expr) -> ResultType {
@@ -86,6 +86,7 @@ class RTGOpVisitor {
8686
HANDLE(GetSequenceOp, Unhandled);
8787
HANDLE(SubstituteSequenceOp, Unhandled);
8888
HANDLE(RandomizeSequenceOp, Unhandled);
89+
HANDLE(InterleaveSequencesOp, Unhandled);
8990
HANDLE(EmbedSequenceOp, Unhandled);
9091
HANDLE(RandomNumberInRangeOp, Unhandled);
9192
HANDLE(OnContextOp, Unhandled);

include/circt/Dialect/RTG/Transforms/RTGPasses.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,22 @@ def EmitRTGISAAssemblyPass : Pass<"rtg-emit-isa-assembly", "mlir::ModuleOp"> {
6464
];
6565
}
6666

67+
def InlineSequencesPass : Pass<"rtg-inline-sequences", "mlir::ModuleOp"> {
68+
let summary = "inline and interleave sequences";
69+
let description = [{
70+
Inline all sequences into tests and remove the 'rtg.sequence' operations.
71+
Also computes and materializes all interleaved sequences
72+
('interleave_sequences' operation).
73+
}];
74+
75+
let statistics = [
76+
Statistic<"numSequencesInlined", "num-sequences-inlined",
77+
"Number of sequences inlined into another sequence or test.">,
78+
Statistic<"numSequencesInterleaved", "num-sequences-interleaved",
79+
"Number of sequences interleaved with another sequence.">,
80+
];
81+
}
82+
6783
def LinearScanRegisterAllocationPass : Pass<
6884
"rtg-linear-scan-register-allocation", "rtg::TestOp"> {
6985

@@ -81,4 +97,18 @@ def LinearScanRegisterAllocationPass : Pass<
8197
];
8298
}
8399

100+
def LowerUniqueLabelsPass : Pass<"rtg-lower-unique-labels", "mlir::ModuleOp"> {
101+
let summary = "lower label_unique_decl to label_decl operations";
102+
let description = [{
103+
This pass lowers label_unique_decl operations to label_decl operations by
104+
creating a unique label string based on all the existing unique and
105+
non-unique label declarations in the module.
106+
}];
107+
108+
let statistics = [
109+
Statistic<"numLabelsLowered", "num-labels-lowered",
110+
"Number of unique labels lowered to regular label declarations.">,
111+
];
112+
}
113+
84114
#endif // CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_TD

lib/Dialect/RTG/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
add_circt_dialect_library(CIRCTRTGTransforms
22
ElaborationPass.cpp
33
EmitRTGISAAssemblyPass.cpp
4+
InlineSequencesPass.cpp
45
LinearScanRegisterAllocationPass.cpp
6+
LowerUniqueLabelsPass.cpp
57

68
DEPENDS
79
CIRCTRTGTransformsIncGen

lib/Dialect/RTG/Transforms/ElaborationPass.cpp

Lines changed: 109 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ namespace {
8888
struct BagStorage;
8989
struct SequenceStorage;
9090
struct RandomizedSequenceStorage;
91+
struct InterleavedSequenceStorage;
9192
struct SetStorage;
9293
struct VirtualRegisterStorage;
9394
struct UniqueLabelStorage;
@@ -107,8 +108,9 @@ struct LabelValue {
107108
/// The abstract base class for elaborated values.
108109
using ElaboratorValue =
109110
std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
110-
RandomizedSequenceStorage *, SetStorage *,
111-
VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue>;
111+
RandomizedSequenceStorage *, InterleavedSequenceStorage *,
112+
SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
113+
LabelValue>;
112114

113115
// NOLINTNEXTLINE(readability-identifier-naming)
114116
llvm::hash_code hash_value(const LabelValue &val) {
@@ -309,6 +311,34 @@ struct RandomizedSequenceStorage {
309311
const SequenceStorage *sequence;
310312
};
311313

314+
/// Storage object for interleaved '!rtg.randomized_sequence'es.
315+
struct InterleavedSequenceStorage {
316+
InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
317+
uint32_t batchSize)
318+
: sequences(std::move(sequences)), batchSize(batchSize),
319+
hashcode(llvm::hash_combine(
320+
llvm::hash_combine_range(sequences.begin(), sequences.end()),
321+
batchSize)) {}
322+
323+
explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
324+
: sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
325+
hashcode(llvm::hash_combine(
326+
llvm::hash_combine_range(sequences.begin(), sequences.end()),
327+
batchSize)) {}
328+
329+
bool isEqual(const InterleavedSequenceStorage *other) const {
330+
return hashcode == other->hashcode && sequences == other->sequences &&
331+
batchSize == other->batchSize;
332+
}
333+
334+
const SmallVector<ElaboratorValue> sequences;
335+
336+
const uint32_t batchSize;
337+
338+
// The cached hashcode to avoid repeated computations.
339+
const unsigned hashcode;
340+
};
341+
312342
/// Represents a unique virtual register.
313343
struct VirtualRegisterStorage {
314344
VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
@@ -373,6 +403,8 @@ class Internalizer {
373403
return internedSequences;
374404
else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
375405
return internedRandomizedSequences;
406+
else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
407+
return internedInterleavedSequences;
376408
else
377409
static_assert(!sizeof(StorageTy),
378410
"no intern set available for this storage type.");
@@ -392,6 +424,9 @@ class Internalizer {
392424
DenseSet<HashedStorage<RandomizedSequenceStorage>,
393425
StorageKeyInfo<RandomizedSequenceStorage>>
394426
internedRandomizedSequences;
427+
DenseSet<HashedStorage<InterleavedSequenceStorage>,
428+
StorageKeyInfo<InterleavedSequenceStorage>>
429+
internedInterleavedSequences;
395430
};
396431

397432
} // namespace
@@ -438,6 +473,13 @@ static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
438473
os << ") at " << val << ">";
439474
}
440475

476+
static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
477+
os << "<interleaved-sequence [";
478+
llvm::interleaveComma(val->sequences, os,
479+
[&](const ElaboratorValue &val) { os << val; });
480+
os << "] batch-size " << val->batchSize << " at " << val << ">";
481+
}
482+
441483
static void print(SetStorage *val, llvm::raw_ostream &os) {
442484
os << "<set {";
443485
llvm::interleaveComma(val->set, os,
@@ -677,7 +719,25 @@ class Materializer {
677719
elabRequests.push(val);
678720
Value seq = builder.create<GetSequenceOp>(
679721
loc, SequenceType::get(builder.getContext(), {}), val->name);
680-
return builder.create<RandomizeSequenceOp>(loc, seq);
722+
Value res = builder.create<RandomizeSequenceOp>(loc, seq);
723+
materializedValues[val] = res;
724+
return res;
725+
}
726+
727+
Value visit(InterleavedSequenceStorage *val, Location loc,
728+
std::queue<RandomizedSequenceStorage *> &elabRequests,
729+
function_ref<InFlightDiagnostic()> emitError) {
730+
SmallVector<Value> sequences;
731+
for (auto seqVal : val->sequences)
732+
sequences.push_back(materialize(seqVal, loc, elabRequests, emitError));
733+
734+
if (sequences.size() == 1)
735+
return sequences[0];
736+
737+
Value res =
738+
builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
739+
materializedValues[val] = res;
740+
return res;
681741
}
682742

683743
Value visit(VirtualRegisterStorage *val, Location loc,
@@ -735,7 +795,6 @@ struct ElaboratorSharedState {
735795
SymbolTable &table;
736796
std::mt19937 rng;
737797
Namespace names;
738-
Namespace labelNames;
739798
Internalizer internalizer;
740799

741800
/// The worklist used to keep track of the test and sequence operations to
@@ -841,27 +900,57 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
841900
auto *seq = get<SequenceStorage *>(op.getSequence());
842901

843902
auto name = sharedState.names.newName(seq->familyName.getValue());
844-
state[op.getResult()] =
903+
auto *randomizedSeq =
845904
sharedState.internalizer.internalize<RandomizedSequenceStorage>(
846905
name, currentContext, testState.name, seq);
906+
state[op.getResult()] =
907+
sharedState.internalizer.internalize<InterleavedSequenceStorage>(
908+
randomizedSeq);
847909
return DeletionKind::Delete;
848910
}
849911

850-
FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
851-
auto *seq = get<RandomizedSequenceStorage *>(op.getSequence());
852-
if (seq->context != currentContext) {
853-
auto err = op->emitError("attempting to place sequence ")
854-
<< seq->name << " derived from "
855-
<< seq->sequence->familyName.getValue() << " under context "
856-
<< currentContext
857-
<< ", but it was previously randomized for context ";
858-
if (seq->context)
859-
err << seq->context;
860-
else
861-
err << "'default'";
862-
return err;
912+
FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
913+
SmallVector<ElaboratorValue> sequences;
914+
for (auto seq : op.getSequences())
915+
sequences.push_back(get<InterleavedSequenceStorage *>(seq));
916+
917+
state[op.getResult()] =
918+
sharedState.internalizer.internalize<InterleavedSequenceStorage>(
919+
std::move(sequences), op.getBatchSize());
920+
return DeletionKind::Delete;
921+
}
922+
923+
// NOLINTNEXTLINE(misc-no-recursion)
924+
LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
925+
if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
926+
auto *seq = std::get<RandomizedSequenceStorage *>(value);
927+
if (seq->context != currentContext) {
928+
auto err = op->emitError("attempting to place sequence ")
929+
<< seq->name << " derived from "
930+
<< seq->sequence->familyName.getValue() << " under context "
931+
<< currentContext
932+
<< ", but it was previously randomized for context ";
933+
if (seq->context)
934+
err << seq->context;
935+
else
936+
err << "'default'";
937+
return err;
938+
}
939+
return success();
863940
}
864941

942+
auto *interVal = std::get<InterleavedSequenceStorage *>(value);
943+
for (auto val : interVal->sequences)
944+
if (failed(isValidContext(val, op)))
945+
return failure();
946+
return success();
947+
}
948+
949+
FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
950+
auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
951+
if (failed(isValidContext(seqVal, op)))
952+
return failure();
953+
865954
return DeletionKind::Keep;
866955
}
867956

@@ -1039,7 +1128,6 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
10391128
FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
10401129
auto substituted =
10411130
substituteFormatString(op.getFormatStringAttr(), op.getArgs());
1042-
sharedState.labelNames.add(substituted.getValue());
10431131
state[op.getLabel()] = LabelValue(substituted);
10441132
return DeletionKind::Delete;
10451133
}
@@ -1309,7 +1397,6 @@ struct ElaborationPass
13091397
void runOnOperation() override;
13101398
void cloneTargetsIntoTests(SymbolTable &table);
13111399
LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
1312-
LogicalResult inlineSequences(TestOp testOp, SymbolTable &table);
13131400
};
13141401
} // namespace
13151402

@@ -1407,6 +1494,8 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
14071494
auto seqOp = builder.cloneWithoutRegions(familyOp);
14081495
seqOp.getBodyRegion().emplaceBlock();
14091496
seqOp.setSymName(curr->name);
1497+
seqOp.setSequenceType(
1498+
SequenceType::get(builder.getContext(), ArrayRef<Type>{}));
14101499
table.insert(seqOp);
14111500
assert(seqOp.getSymName() == curr->name && "should not have been renamed");
14121501

@@ -1425,64 +1514,5 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
14251514
materializer.finalize();
14261515
}
14271516

1428-
for (auto testOp : moduleOp.getOps<TestOp>()) {
1429-
// Inline all sequences and remove the operations that place the sequences.
1430-
if (failed(inlineSequences(testOp, table)))
1431-
return failure();
1432-
1433-
// Convert 'rtg.label_unique_decl' to 'rtg.label_decl' by choosing a unique
1434-
// name based on the set of names we collected during elaboration.
1435-
for (auto labelOp :
1436-
llvm::make_early_inc_range(testOp.getOps<LabelUniqueDeclOp>())) {
1437-
IRRewriter rewriter(labelOp);
1438-
auto newName = state.labelNames.newName(labelOp.getFormatString());
1439-
rewriter.replaceOpWithNewOp<LabelDeclOp>(labelOp, newName, ValueRange());
1440-
}
1441-
}
1442-
1443-
// Remove all sequences since they are not accessible from the outside and
1444-
// are not needed anymore since we fully inlined them.
1445-
for (auto seqOp : llvm::make_early_inc_range(moduleOp.getOps<SequenceOp>()))
1446-
seqOp->erase();
1447-
1448-
return success();
1449-
}
1450-
1451-
LogicalResult ElaborationPass::inlineSequences(TestOp testOp,
1452-
SymbolTable &table) {
1453-
OpBuilder builder(testOp);
1454-
for (auto iter = testOp.getBody()->begin();
1455-
iter != testOp.getBody()->end();) {
1456-
auto embedOp = dyn_cast<EmbedSequenceOp>(&*iter);
1457-
if (!embedOp) {
1458-
++iter;
1459-
continue;
1460-
}
1461-
1462-
auto randSeqOp = embedOp.getSequence().getDefiningOp<RandomizeSequenceOp>();
1463-
if (!randSeqOp)
1464-
return embedOp->emitError("sequence operand not directly defined by "
1465-
"'rtg.randomize_sequence' op");
1466-
auto getSeqOp = randSeqOp.getSequence().getDefiningOp<GetSequenceOp>();
1467-
if (!getSeqOp)
1468-
return randSeqOp->emitError(
1469-
"sequence operand not directly defined by 'rtg.get_sequence' op");
1470-
1471-
auto seqOp = table.lookup<SequenceOp>(getSeqOp.getSequenceAttr());
1472-
1473-
builder.setInsertionPointAfter(embedOp);
1474-
IRMapping mapping;
1475-
for (auto &op : *seqOp.getBody())
1476-
builder.clone(op, mapping);
1477-
1478-
(iter++)->erase();
1479-
1480-
if (randSeqOp->use_empty())
1481-
randSeqOp->erase();
1482-
1483-
if (getSeqOp->use_empty())
1484-
getSeqOp->erase();
1485-
}
1486-
14871517
return success();
14881518
}

0 commit comments

Comments
 (0)