@@ -88,6 +88,7 @@ namespace {
8888struct BagStorage ;
8989struct SequenceStorage ;
9090struct RandomizedSequenceStorage ;
91+ struct InterleavedSequenceStorage ;
9192struct SetStorage ;
9293struct VirtualRegisterStorage ;
9394struct UniqueLabelStorage ;
@@ -107,8 +108,9 @@ struct LabelValue {
107108// / The abstract base class for elaborated values.
108109using ElaboratorValue =
109110 std::variant<TypedAttr, BagStorage *, bool , size_t , SequenceStorage *,
110- RandomizedSequenceStorage *, SetStorage *,
111- VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue>;
111+ RandomizedSequenceStorage *, InterleavedSequenceStorage *,
112+ SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
113+ LabelValue>;
112114
113115// NOLINTNEXTLINE(readability-identifier-naming)
114116llvm::hash_code hash_value (const LabelValue &val) {
@@ -309,6 +311,34 @@ struct RandomizedSequenceStorage {
309311 const SequenceStorage *sequence;
310312};
311313
314+ // / Storage object for interleaved '!rtg.randomized_sequence'es.
315+ struct InterleavedSequenceStorage {
316+ InterleavedSequenceStorage (SmallVector<ElaboratorValue> &&sequences,
317+ uint32_t batchSize)
318+ : sequences(std::move(sequences)), batchSize(batchSize),
319+ hashcode (llvm::hash_combine(
320+ llvm::hash_combine_range (sequences.begin(), sequences.end()),
321+ batchSize)) {}
322+
323+ explicit InterleavedSequenceStorage (RandomizedSequenceStorage *sequence)
324+ : sequences(SmallVector<ElaboratorValue>(1 , sequence)), batchSize(1 ),
325+ hashcode(llvm::hash_combine(
326+ llvm::hash_combine_range (sequences.begin(), sequences.end()),
327+ batchSize)) {}
328+
329+ bool isEqual (const InterleavedSequenceStorage *other) const {
330+ return hashcode == other->hashcode && sequences == other->sequences &&
331+ batchSize == other->batchSize ;
332+ }
333+
334+ const SmallVector<ElaboratorValue> sequences;
335+
336+ const uint32_t batchSize;
337+
338+ // The cached hashcode to avoid repeated computations.
339+ const unsigned hashcode;
340+ };
341+
312342// / Represents a unique virtual register.
313343struct VirtualRegisterStorage {
314344 VirtualRegisterStorage (ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
@@ -373,6 +403,8 @@ class Internalizer {
373403 return internedSequences;
374404 else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
375405 return internedRandomizedSequences;
406+ else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
407+ return internedInterleavedSequences;
376408 else
377409 static_assert (!sizeof (StorageTy),
378410 " no intern set available for this storage type." );
@@ -392,6 +424,9 @@ class Internalizer {
392424 DenseSet<HashedStorage<RandomizedSequenceStorage>,
393425 StorageKeyInfo<RandomizedSequenceStorage>>
394426 internedRandomizedSequences;
427+ DenseSet<HashedStorage<InterleavedSequenceStorage>,
428+ StorageKeyInfo<InterleavedSequenceStorage>>
429+ internedInterleavedSequences;
395430};
396431
397432} // namespace
@@ -438,6 +473,13 @@ static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
438473 os << " ) at " << val << " >" ;
439474}
440475
476+ static void print (InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
477+ os << " <interleaved-sequence [" ;
478+ llvm::interleaveComma (val->sequences , os,
479+ [&](const ElaboratorValue &val) { os << val; });
480+ os << " ] batch-size " << val->batchSize << " at " << val << " >" ;
481+ }
482+
441483static void print (SetStorage *val, llvm::raw_ostream &os) {
442484 os << " <set {" ;
443485 llvm::interleaveComma (val->set , os,
@@ -677,7 +719,25 @@ class Materializer {
677719 elabRequests.push (val);
678720 Value seq = builder.create <GetSequenceOp>(
679721 loc, SequenceType::get (builder.getContext (), {}), val->name );
680- return builder.create <RandomizeSequenceOp>(loc, seq);
722+ Value res = builder.create <RandomizeSequenceOp>(loc, seq);
723+ materializedValues[val] = res;
724+ return res;
725+ }
726+
727+ Value visit (InterleavedSequenceStorage *val, Location loc,
728+ std::queue<RandomizedSequenceStorage *> &elabRequests,
729+ function_ref<InFlightDiagnostic()> emitError) {
730+ SmallVector<Value> sequences;
731+ for (auto seqVal : val->sequences )
732+ sequences.push_back (materialize (seqVal, loc, elabRequests, emitError));
733+
734+ if (sequences.size () == 1 )
735+ return sequences[0 ];
736+
737+ Value res =
738+ builder.create <InterleaveSequencesOp>(loc, sequences, val->batchSize );
739+ materializedValues[val] = res;
740+ return res;
681741 }
682742
683743 Value visit (VirtualRegisterStorage *val, Location loc,
@@ -735,7 +795,6 @@ struct ElaboratorSharedState {
735795 SymbolTable &table;
736796 std::mt19937 rng;
737797 Namespace names;
738- Namespace labelNames;
739798 Internalizer internalizer;
740799
741800 // / The worklist used to keep track of the test and sequence operations to
@@ -841,27 +900,57 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
841900 auto *seq = get<SequenceStorage *>(op.getSequence ());
842901
843902 auto name = sharedState.names .newName (seq->familyName .getValue ());
844- state[op. getResult ()] =
903+ auto *randomizedSeq =
845904 sharedState.internalizer .internalize <RandomizedSequenceStorage>(
846905 name, currentContext, testState.name , seq);
906+ state[op.getResult ()] =
907+ sharedState.internalizer .internalize <InterleavedSequenceStorage>(
908+ randomizedSeq);
847909 return DeletionKind::Delete;
848910 }
849911
850- FailureOr<DeletionKind> visitOp (EmbedSequenceOp op) {
851- auto *seq = get<RandomizedSequenceStorage *>(op.getSequence ());
852- if (seq->context != currentContext) {
853- auto err = op->emitError (" attempting to place sequence " )
854- << seq->name << " derived from "
855- << seq->sequence ->familyName .getValue () << " under context "
856- << currentContext
857- << " , but it was previously randomized for context " ;
858- if (seq->context )
859- err << seq->context ;
860- else
861- err << " 'default'" ;
862- return err;
912+ FailureOr<DeletionKind> visitOp (InterleaveSequencesOp op) {
913+ SmallVector<ElaboratorValue> sequences;
914+ for (auto seq : op.getSequences ())
915+ sequences.push_back (get<InterleavedSequenceStorage *>(seq));
916+
917+ state[op.getResult ()] =
918+ sharedState.internalizer .internalize <InterleavedSequenceStorage>(
919+ std::move (sequences), op.getBatchSize ());
920+ return DeletionKind::Delete;
921+ }
922+
923+ // NOLINTNEXTLINE(misc-no-recursion)
924+ LogicalResult isValidContext (ElaboratorValue value, Operation *op) const {
925+ if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
926+ auto *seq = std::get<RandomizedSequenceStorage *>(value);
927+ if (seq->context != currentContext) {
928+ auto err = op->emitError (" attempting to place sequence " )
929+ << seq->name << " derived from "
930+ << seq->sequence ->familyName .getValue () << " under context "
931+ << currentContext
932+ << " , but it was previously randomized for context " ;
933+ if (seq->context )
934+ err << seq->context ;
935+ else
936+ err << " 'default'" ;
937+ return err;
938+ }
939+ return success ();
863940 }
864941
942+ auto *interVal = std::get<InterleavedSequenceStorage *>(value);
943+ for (auto val : interVal->sequences )
944+ if (failed (isValidContext (val, op)))
945+ return failure ();
946+ return success ();
947+ }
948+
949+ FailureOr<DeletionKind> visitOp (EmbedSequenceOp op) {
950+ auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence ());
951+ if (failed (isValidContext (seqVal, op)))
952+ return failure ();
953+
865954 return DeletionKind::Keep;
866955 }
867956
@@ -1039,7 +1128,6 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
10391128 FailureOr<DeletionKind> visitOp (LabelDeclOp op) {
10401129 auto substituted =
10411130 substituteFormatString (op.getFormatStringAttr (), op.getArgs ());
1042- sharedState.labelNames .add (substituted.getValue ());
10431131 state[op.getLabel ()] = LabelValue (substituted);
10441132 return DeletionKind::Delete;
10451133 }
@@ -1309,7 +1397,6 @@ struct ElaborationPass
13091397 void runOnOperation () override ;
13101398 void cloneTargetsIntoTests (SymbolTable &table);
13111399 LogicalResult elaborateModule (ModuleOp moduleOp, SymbolTable &table);
1312- LogicalResult inlineSequences (TestOp testOp, SymbolTable &table);
13131400};
13141401} // namespace
13151402
@@ -1407,6 +1494,8 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
14071494 auto seqOp = builder.cloneWithoutRegions (familyOp);
14081495 seqOp.getBodyRegion ().emplaceBlock ();
14091496 seqOp.setSymName (curr->name );
1497+ seqOp.setSequenceType (
1498+ SequenceType::get (builder.getContext (), ArrayRef<Type>{}));
14101499 table.insert (seqOp);
14111500 assert (seqOp.getSymName () == curr->name && " should not have been renamed" );
14121501
@@ -1425,64 +1514,5 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
14251514 materializer.finalize ();
14261515 }
14271516
1428- for (auto testOp : moduleOp.getOps <TestOp>()) {
1429- // Inline all sequences and remove the operations that place the sequences.
1430- if (failed (inlineSequences (testOp, table)))
1431- return failure ();
1432-
1433- // Convert 'rtg.label_unique_decl' to 'rtg.label_decl' by choosing a unique
1434- // name based on the set of names we collected during elaboration.
1435- for (auto labelOp :
1436- llvm::make_early_inc_range (testOp.getOps <LabelUniqueDeclOp>())) {
1437- IRRewriter rewriter (labelOp);
1438- auto newName = state.labelNames .newName (labelOp.getFormatString ());
1439- rewriter.replaceOpWithNewOp <LabelDeclOp>(labelOp, newName, ValueRange ());
1440- }
1441- }
1442-
1443- // Remove all sequences since they are not accessible from the outside and
1444- // are not needed anymore since we fully inlined them.
1445- for (auto seqOp : llvm::make_early_inc_range (moduleOp.getOps <SequenceOp>()))
1446- seqOp->erase ();
1447-
1448- return success ();
1449- }
1450-
1451- LogicalResult ElaborationPass::inlineSequences (TestOp testOp,
1452- SymbolTable &table) {
1453- OpBuilder builder (testOp);
1454- for (auto iter = testOp.getBody ()->begin ();
1455- iter != testOp.getBody ()->end ();) {
1456- auto embedOp = dyn_cast<EmbedSequenceOp>(&*iter);
1457- if (!embedOp) {
1458- ++iter;
1459- continue ;
1460- }
1461-
1462- auto randSeqOp = embedOp.getSequence ().getDefiningOp <RandomizeSequenceOp>();
1463- if (!randSeqOp)
1464- return embedOp->emitError (" sequence operand not directly defined by "
1465- " 'rtg.randomize_sequence' op" );
1466- auto getSeqOp = randSeqOp.getSequence ().getDefiningOp <GetSequenceOp>();
1467- if (!getSeqOp)
1468- return randSeqOp->emitError (
1469- " sequence operand not directly defined by 'rtg.get_sequence' op" );
1470-
1471- auto seqOp = table.lookup <SequenceOp>(getSeqOp.getSequenceAttr ());
1472-
1473- builder.setInsertionPointAfter (embedOp);
1474- IRMapping mapping;
1475- for (auto &op : *seqOp.getBody ())
1476- builder.clone (op, mapping);
1477-
1478- (iter++)->erase ();
1479-
1480- if (randSeqOp->use_empty ())
1481- randSeqOp->erase ();
1482-
1483- if (getSeqOp->use_empty ())
1484- getSeqOp->erase ();
1485- }
1486-
14871517 return success ();
14881518}
0 commit comments