Skip to content

Commit f6665f1

Browse files
committed
[RTG][Elaboration] Add support for context operations
1 parent b2a9f52 commit f6665f1

File tree

2 files changed

+239
-12
lines changed

2 files changed

+239
-12
lines changed

lib/Dialect/RTG/Transforms/ElaborationPass.cpp

Lines changed: 123 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//
1414
//===----------------------------------------------------------------------===//
1515

16+
#include "circt/Dialect/RTG/IR/RTGAttributes.h"
1617
#include "circt/Dialect/RTG/IR/RTGOps.h"
1718
#include "circt/Dialect/RTG/IR/RTGVisitors.h"
1819
#include "circt/Dialect/RTG/Transforms/RTGPasses.h"
@@ -323,12 +324,16 @@ struct SequenceStorage {
323324

324325
/// Storage object for an '!rtg.randomized_sequence'.
325326
struct RandomizedSequenceStorage {
326-
RandomizedSequenceStorage(StringRef name, SequenceStorage *sequence)
327-
: hashcode(llvm::hash_combine(name, sequence)), name(name),
328-
sequence(sequence) {}
327+
RandomizedSequenceStorage(StringRef name,
328+
ContextResourceAttrInterface context,
329+
StringAttr test, SequenceStorage *sequence)
330+
: hashcode(llvm::hash_combine(name, context, test, sequence)), name(name),
331+
context(context), test(test), sequence(sequence) {}
329332

330333
bool isEqual(const RandomizedSequenceStorage *other) const {
331-
return hashcode == other->hashcode && sequence == other->sequence;
334+
return hashcode == other->hashcode && name == other->name &&
335+
context == other->context && test == other->test &&
336+
sequence == other->sequence;
332337
}
333338

334339
// The cached hashcode to avoid repeated computations.
@@ -337,6 +342,12 @@ struct RandomizedSequenceStorage {
337342
// The name of this fully substituted and elaborated sequence.
338343
const StringRef name;
339344

345+
// The context under which this sequence is placed.
346+
const ContextResourceAttrInterface context;
347+
348+
// The test in which this sequence is placed.
349+
const StringAttr test;
350+
340351
const SequenceStorage *sequence;
341352
};
342353

@@ -434,7 +445,8 @@ static void print(SequenceStorage *val, llvm::raw_ostream &os) {
434445

435446
static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
436447
os << "<randomized-sequence @" << val->name << " derived from @"
437-
<< val->sequence->familyName.getValue() << "(";
448+
<< val->sequence->familyName.getValue() << " under context "
449+
<< val->context << " in test " << val->test << "(";
438450
llvm::interleaveComma(val->sequence->args, os,
439451
[&](const ElaboratorValue &val) { os << val; });
440452
os << ") at " << val << ">";
@@ -560,6 +572,11 @@ class Materializer {
560572
op->erase();
561573
}
562574

575+
template <typename OpTy, typename... Args>
576+
OpTy create(Location location, Args &&...args) {
577+
return builder.create<OpTy>(location, std::forward<Args>(args)...);
578+
}
579+
563580
private:
564581
void deleteOpsUntil(function_ref<bool(Block::iterator)> stop) {
565582
auto ip = builder.getInsertionPoint();
@@ -737,14 +754,29 @@ struct ElaboratorSharedState {
737754
uint64_t uniqueLabelID = 1;
738755
};
739756

757+
/// A collection of state per RTG test.
758+
struct TestState {
759+
/// The name of the test.
760+
StringAttr name;
761+
762+
/// The context switches registered for this test.
763+
MapVector<
764+
std::pair<ContextResourceAttrInterface, ContextResourceAttrInterface>,
765+
SequenceStorage *>
766+
contextSwitches;
767+
};
768+
740769
/// Interprets the IR to perform and lower the represented randomizations.
741770
class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
742771
public:
743772
using RTGBase = RTGOpVisitor<Elaborator, FailureOr<DeletionKind>>;
744773
using RTGBase::visitOp;
745774

746-
Elaborator(ElaboratorSharedState &sharedState, Materializer &materializer)
747-
: sharedState(sharedState), materializer(materializer) {}
775+
Elaborator(ElaboratorSharedState &sharedState, TestState &testState,
776+
Materializer &materializer,
777+
ContextResourceAttrInterface currentContext = {})
778+
: sharedState(sharedState), testState(testState),
779+
materializer(materializer), currentContext(currentContext) {}
748780

749781
template <typename ValueTy>
750782
inline ValueTy get(Value val) const {
@@ -821,12 +853,26 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
821853

822854
auto name = sharedState.names.newName(seq->familyName.getValue());
823855
state[op.getResult()] =
824-
sharedState.internalizer.internalize<RandomizedSequenceStorage>(name,
825-
seq);
856+
sharedState.internalizer.internalize<RandomizedSequenceStorage>(
857+
name, currentContext, testState.name, seq);
826858
return DeletionKind::Delete;
827859
}
828860

829861
FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
862+
auto *seq = get<RandomizedSequenceStorage *>(op.getSequence());
863+
if (seq->context != currentContext) {
864+
auto err = op->emitError("attempting to place sequence ")
865+
<< seq->name << " derived from "
866+
<< seq->sequence->familyName.getValue() << " under context "
867+
<< currentContext
868+
<< ", but it was previously randomized for context ";
869+
if (seq->context)
870+
err << seq->context;
871+
else
872+
err << "'default'";
873+
return err;
874+
}
875+
830876
return DeletionKind::Keep;
831877
}
832878

@@ -1036,6 +1082,60 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
10361082
return DeletionKind::Delete;
10371083
}
10381084

1085+
FailureOr<DeletionKind> visitOp(OnContextOp op) {
1086+
ContextResourceAttrInterface from = currentContext,
1087+
to = cast<ContextResourceAttrInterface>(
1088+
get<TypedAttr>(op.getContext()));
1089+
if (!currentContext)
1090+
from = DefaultContextAttr::get(op->getContext(), to.getType());
1091+
1092+
auto emitError = [&]() {
1093+
auto diag = op.emitError();
1094+
diag.attachNote(op.getLoc())
1095+
<< "while materializing value for context switching for " << op;
1096+
return diag;
1097+
};
1098+
1099+
if (from == to) {
1100+
Value seqVal = materializer.materialize(
1101+
get<SequenceStorage *>(op.getSequence()), op.getLoc(),
1102+
sharedState.worklist, emitError);
1103+
Value randSeqVal =
1104+
materializer.create<RandomizeSequenceOp>(op.getLoc(), seqVal);
1105+
materializer.create<EmbedSequenceOp>(op.getLoc(), randSeqVal);
1106+
return DeletionKind::Delete;
1107+
}
1108+
1109+
// Switch to the desired context.
1110+
auto *iter = testState.contextSwitches.find({from, to});
1111+
// NOTE: we could think about supporting context switching via intermediate
1112+
// context, i.e., treat it as a transitive relation.
1113+
if (iter == testState.contextSwitches.end())
1114+
return op->emitError("no context transition registered to switch from ")
1115+
<< from << " to " << to;
1116+
1117+
auto familyName = iter->second->familyName;
1118+
SmallVector<ElaboratorValue> args{from, to,
1119+
get<SequenceStorage *>(op.getSequence())};
1120+
auto *seq = sharedState.internalizer.internalize<SequenceStorage>(
1121+
familyName, std::move(args));
1122+
auto *randSeq =
1123+
sharedState.internalizer.internalize<RandomizedSequenceStorage>(
1124+
sharedState.names.newName(familyName.getValue()), to,
1125+
testState.name, seq);
1126+
Value seqVal = materializer.materialize(randSeq, op.getLoc(),
1127+
sharedState.worklist, emitError);
1128+
materializer.create<EmbedSequenceOp>(op.getLoc(), seqVal);
1129+
1130+
return DeletionKind::Delete;
1131+
}
1132+
1133+
FailureOr<DeletionKind> visitOp(ContextSwitchOp op) {
1134+
testState.contextSwitches[{op.getFromAttr(), op.getToAttr()}] =
1135+
get<SequenceStorage *>(op.getSequence());
1136+
return DeletionKind::Delete;
1137+
}
1138+
10391139
FailureOr<DeletionKind> visitOp(scf::IfOp op) {
10401140
bool cond = get<bool>(op.getCondition());
10411141
auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
@@ -1193,12 +1293,18 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
11931293
// State to be shared between all elaborator instances.
11941294
ElaboratorSharedState &sharedState;
11951295

1296+
// State to a specific RTG test and the sequences placed within it.
1297+
TestState &testState;
1298+
11961299
// Allows us to materialize ElaboratorValues to the IR operations necessary to
11971300
// obtain an SSA value representing that elaborated value.
11981301
Materializer &materializer;
11991302

12001303
// A map from SSA values to a pointer of an interned elaborator value.
12011304
DenseMap<Value, ElaboratorValue> state;
1305+
1306+
// The current context we are elaborating under.
1307+
ContextResourceAttrInterface currentContext;
12021308
};
12031309
} // namespace
12041310

@@ -1282,11 +1388,14 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
12821388

12831389
// Initialize the worklist with the test ops since they cannot be placed by
12841390
// other ops.
1391+
DenseMap<StringAttr, TestState> testStates;
12851392
for (auto testOp : moduleOp.getOps<TestOp>()) {
12861393
LLVM_DEBUG(llvm::dbgs()
12871394
<< "\n=== Elaborating test @" << testOp.getSymName() << "\n\n");
12881395
Materializer materializer(OpBuilder::atBlockBegin(testOp.getBody()));
1289-
Elaborator elaborator(state, materializer);
1396+
testStates[testOp.getSymNameAttr()].name = testOp.getSymNameAttr();
1397+
Elaborator elaborator(state, testStates[testOp.getSymNameAttr()],
1398+
materializer);
12901399
if (failed(elaborator.elaborate(testOp.getBodyRegion())))
12911400
return failure();
12921401

@@ -1314,10 +1423,12 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
13141423

13151424
LLVM_DEBUG(llvm::dbgs()
13161425
<< "\n=== Elaborating sequence family @" << familyOp.getSymName()
1317-
<< " into @" << seqOp.getSymName() << "\n\n");
1426+
<< " into @" << seqOp.getSymName() << " under context "
1427+
<< curr->context << "\n\n");
13181428

13191429
Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody()));
1320-
Elaborator elaborator(state, materializer);
1430+
Elaborator elaborator(state, testStates[curr->test], materializer,
1431+
curr->context);
13211432
if (failed(elaborator.elaborate(familyOp.getBodyRegion(),
13221433
curr->sequence->args)))
13231434
return failure();

test/Dialect/RTG/Transform/elaboration.mlir

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,78 @@ rtg.test @randomIntegers : !rtg.dict<> {
389389
func.call @dummy2(%1) : (index) -> ()
390390
}
391391

392+
// CHECK-LABEL: rtg.test @contexts_contextCpu
393+
rtg.test @contexts : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
394+
^bb0(%cpu0: !rtgtest.cpu, %cpu1: !rtgtest.cpu):
395+
// CHECK-NEXT: rtg.label_decl "label0"
396+
// CHECK-NEXT: rtg.label
397+
// CHECK-NEXT: rtg.label_decl "label5"
398+
// CHECK-NEXT: rtg.label
399+
// CHECK-NEXT: rtg.label_decl "label2"
400+
// CHECK-NEXT: rtg.label
401+
// CHECK-NEXT: rtg.label_decl "label7"
402+
// CHECK-NEXT: rtg.label
403+
// CHECK-NEXT: rtg.label_decl "label4"
404+
// CHECK-NEXT: rtg.label
405+
// CHECK-NEXT: rtg.label_decl "label8"
406+
// CHECK-NEXT: rtg.label
407+
// CHECK-NEXT: rtg.label_decl "label3"
408+
// CHECK-NEXT: rtg.label
409+
// CHECK-NEXT: rtg.label_decl "label6"
410+
// CHECK-NEXT: rtg.label
411+
// CHECK-NEXT: rtg.label_decl "label1"
412+
// CHECK-NEXT: rtg.label
413+
%0 = rtg.get_sequence @cpuSeq : !rtg.sequence<!rtgtest.cpu>
414+
%1 = rtg.substitute_sequence %0(%cpu1) : !rtg.sequence<!rtgtest.cpu>
415+
%l0 = rtg.label_decl "label0"
416+
rtg.label local %l0
417+
rtg.on_context %cpu0, %1 : !rtgtest.cpu
418+
%l1 = rtg.label_decl "label1"
419+
rtg.label local %l1
420+
}
421+
422+
rtg.target @contextCpu : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
423+
%cpu0 = rtgtest.cpu_decl <0>
424+
%cpu1 = rtgtest.cpu_decl <1>
425+
%0 = rtg.get_sequence @switchCpuSeq : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
426+
%1 = rtg.get_sequence @switchNestedCpuSeq : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
427+
rtg.context_switch #rtg.default : !rtgtest.cpu -> #rtgtest.cpu<0> : !rtgtest.cpu, %0 : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
428+
rtg.context_switch #rtgtest.cpu<0> : !rtgtest.cpu -> #rtgtest.cpu<1> : !rtgtest.cpu, %1 : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
429+
rtg.yield %cpu0, %cpu1 : !rtgtest.cpu, !rtgtest.cpu
430+
}
431+
432+
rtg.sequence @cpuSeq(%cpu: !rtgtest.cpu) {
433+
%l2 = rtg.label_decl "label2"
434+
rtg.label local %l2
435+
%0 = rtg.get_sequence @nestedCpuSeq : !rtg.sequence
436+
rtg.on_context %cpu, %0 : !rtgtest.cpu
437+
%l3 = rtg.label_decl "label3"
438+
rtg.label local %l3
439+
}
440+
441+
rtg.sequence @nestedCpuSeq() {
442+
%l4 = rtg.label_decl "label4"
443+
rtg.label local %l4
444+
}
445+
446+
rtg.sequence @switchCpuSeq(%parent: !rtgtest.cpu, %child: !rtgtest.cpu, %seq: !rtg.sequence) {
447+
%l5 = rtg.label_decl "label5"
448+
rtg.label local %l5
449+
%0 = rtg.randomize_sequence %seq
450+
rtg.embed_sequence %0
451+
%l6 = rtg.label_decl "label6"
452+
rtg.label local %l6
453+
}
454+
455+
rtg.sequence @switchNestedCpuSeq(%parent: !rtgtest.cpu, %child: !rtgtest.cpu, %seq: !rtg.sequence) {
456+
%l7 = rtg.label_decl "label7"
457+
rtg.label local %l7
458+
%0 = rtg.randomize_sequence %seq
459+
rtg.embed_sequence %0
460+
%l8 = rtg.label_decl "label8"
461+
rtg.label local %l8
462+
}
463+
392464
// -----
393465

394466
rtg.test @nestedRegionsNotSupported : !rtg.dict<> {
@@ -424,3 +496,47 @@ rtg.test @randomIntegers : !rtg.dict<> {
424496
%0 = rtg.random_number_in_range [%c5, %c5)
425497
func.call @dummy2(%0) : (index) -> ()
426498
}
499+
500+
// -----
501+
502+
rtg.sequence @seq0(%seq: !rtg.randomized_sequence) {
503+
// expected-error @below {{attempting to place sequence seq1_0 derived from seq1 under context #rtgtest.cpu<0> : !rtgtest.cpu, but it was previously randomized for context 'default'}}
504+
rtg.embed_sequence %seq
505+
}
506+
rtg.sequence @seq1() { }
507+
rtg.sequence @seq(%arg0: !rtgtest.cpu, %arg1: !rtgtest.cpu, %seq: !rtg.sequence) {
508+
%0 = rtg.randomize_sequence %seq
509+
rtg.embed_sequence %0
510+
}
511+
512+
rtg.target @invalidRandomizationTarget : !rtg.dict<cpu: !rtgtest.cpu> {
513+
%0 = rtg.get_sequence @seq : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
514+
rtg.context_switch #rtg.default : !rtgtest.cpu -> #rtgtest.cpu<0>, %0 : !rtg.sequence<!rtgtest.cpu, !rtgtest.cpu, !rtg.sequence>
515+
%1 = rtgtest.cpu_decl <0>
516+
rtg.yield %1 : !rtgtest.cpu
517+
}
518+
519+
rtg.test @invalidRandomization : !rtg.dict<cpu: !rtgtest.cpu> {
520+
^bb0(%cpu: !rtgtest.cpu):
521+
%0 = rtg.get_sequence @seq1 : !rtg.sequence
522+
%1 = rtg.randomize_sequence %0
523+
%2 = rtg.get_sequence @seq0 : !rtg.sequence<!rtg.randomized_sequence>
524+
%3 = rtg.substitute_sequence %2(%1) : !rtg.sequence<!rtg.randomized_sequence>
525+
rtg.on_context %cpu, %3 : !rtgtest.cpu
526+
}
527+
528+
// -----
529+
530+
rtg.sequence @seq() {}
531+
532+
rtg.target @target : !rtg.dict<cpu: !rtgtest.cpu> {
533+
%0 = rtgtest.cpu_decl <0>
534+
rtg.yield %0 : !rtgtest.cpu
535+
}
536+
537+
rtg.test @contextSwitchNotAvailable : !rtg.dict<cpu: !rtgtest.cpu> {
538+
^bb0(%cpu: !rtgtest.cpu):
539+
%0 = rtg.get_sequence @seq : !rtg.sequence
540+
// expected-error @below {{no context transition registered to switch from #rtg.default : !rtgtest.cpu to #rtgtest.cpu<0> : !rtgtest.cpu}}
541+
rtg.on_context %cpu, %0 : !rtgtest.cpu
542+
}

0 commit comments

Comments
 (0)