Skip to content

Commit 77fcae2

Browse files
committed
[RTG][InlineSequences] Support interleave_sequences
1 parent 035c13f commit 77fcae2

File tree

2 files changed

+242
-36
lines changed

2 files changed

+242
-36
lines changed

lib/Dialect/RTG/Transforms/InlineSequencesPass.cpp

Lines changed: 157 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "circt/Dialect/RTG/IR/RTGISAAssemblyOpInterfaces.h"
1314
#include "circt/Dialect/RTG/IR/RTGOps.h"
15+
#include "circt/Dialect/RTG/IR/RTGVisitors.h"
1416
#include "circt/Dialect/RTG/Transforms/RTGPasses.h"
1517
#include "mlir/IR/IRMapping.h"
18+
#include "llvm/Support/Debug.h"
1619

1720
namespace circt {
1821
namespace rtg {
@@ -25,6 +28,8 @@ using namespace mlir;
2528
using namespace circt;
2629
using namespace circt::rtg;
2730

31+
#define DEBUG_TYPE "rtg-inline-sequences"
32+
2833
//===----------------------------------------------------------------------===//
2934
// Inline Sequences Pass
3035
//===----------------------------------------------------------------------===//
@@ -35,62 +40,183 @@ struct InlineSequencesPass
3540
using Base::Base;
3641

3742
void runOnOperation() override;
38-
LogicalResult inlineSequences(TestOp testOp, SymbolTable &table);
3943
};
40-
} // namespace
4144

42-
LogicalResult InlineSequencesPass::inlineSequences(TestOp testOp,
43-
SymbolTable &table) {
44-
OpBuilder builder(testOp);
45-
for (auto iter = testOp.getBody()->begin();
46-
iter != testOp.getBody()->end();) {
47-
auto embedOp = dyn_cast<EmbedSequenceOp>(&*iter);
48-
if (!embedOp) {
49-
++iter;
50-
continue;
45+
/// Enum to indicate to the visitor driver whether the operation should be
46+
/// deleted.
47+
enum class DeletionKind { Delete, Keep };
48+
49+
/// The SequenceInliner computes sequence interleavings and inlines them.
50+
struct SequenceInliner
51+
: public RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>> {
52+
using RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>>::visitOp;
53+
54+
SequenceInliner(ModuleOp moduleOp) : table(moduleOp) {}
55+
56+
LogicalResult inlineSequences(TestOp testOp);
57+
void materializeInterleavedSequence(Value value, ArrayRef<Block *> blocks,
58+
uint32_t batchSize);
59+
60+
// Visitor methods
61+
62+
FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
63+
SmallVector<Block *> blocks;
64+
for (auto [i, seq] : llvm::enumerate(op.getSequences())) {
65+
auto *block = materializedSequences.lookup(seq);
66+
if (!block)
67+
return op->emitError()
68+
<< "sequence operand #" << i
69+
<< " could not be resolved; it was likely produced by an op or "
70+
"block argument not supported by this pass";
71+
72+
blocks.push_back(block);
5173
}
5274

53-
auto randSeqOp = embedOp.getSequence().getDefiningOp<RandomizeSequenceOp>();
54-
if (!randSeqOp)
55-
return embedOp->emitError("sequence operand not directly defined by "
56-
"'rtg.randomize_sequence' op");
57-
auto getSeqOp = randSeqOp.getSequence().getDefiningOp<GetSequenceOp>();
58-
if (!getSeqOp)
59-
return randSeqOp->emitError(
60-
"sequence operand not directly defined by 'rtg.get_sequence' op");
75+
LLVM_DEBUG(llvm::dbgs()
76+
<< " - Computing sequence interleaving: " << op << "\n");
77+
78+
materializeInterleavedSequence(op.getInterleavedSequence(), blocks,
79+
op.getBatchSize());
80+
return DeletionKind::Delete;
81+
}
82+
83+
FailureOr<DeletionKind> visitOp(GetSequenceOp op) {
84+
auto seqOp = table.lookup<SequenceOp>(op.getSequenceAttr());
85+
if (!seqOp)
86+
return op->emitError() << "referenced sequence not found";
6187

62-
auto seqOp = table.lookup<SequenceOp>(getSeqOp.getSequenceAttr());
88+
LLVM_DEBUG(llvm::dbgs() << " - Registering existing sequence: "
89+
<< op.getSequence() << "\n");
6390

64-
builder.setInsertionPointAfter(embedOp);
91+
materializedSequences[op.getResult()] = seqOp.getBody();
92+
return DeletionKind::Delete;
93+
}
94+
95+
FailureOr<DeletionKind> visitOp(RandomizeSequenceOp op) {
96+
LLVM_DEBUG(llvm::dbgs() << " - Randomize sequence: " << op << "\n");
97+
98+
auto *block = materializedSequences.lookup(op.getSequence());
99+
if (!block)
100+
return op->emitError() << "sequence operand could not be resolved; it "
101+
"was likely produced by an op or block "
102+
"argument not supported by this pass";
103+
104+
materializedSequences[op.getResult()] = block;
105+
return DeletionKind::Delete;
106+
}
107+
108+
FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
109+
LLVM_DEBUG(llvm::dbgs() << " - Inlining sequence: " << op << "\n");
110+
111+
auto *block = materializedSequences.lookup(op.getSequence());
112+
if (!block)
113+
return op->emitError() << "sequence operand could not be resolved; it "
114+
"was likely produced by an op or block "
115+
"argument not supported by this pass";
116+
117+
OpBuilder builder(op);
118+
builder.setInsertionPointAfter(op);
65119
IRMapping mapping;
66-
for (auto &op : *seqOp.getBody())
120+
for (auto &op : *block)
67121
builder.clone(op, mapping);
68122

69-
(iter++)->erase();
123+
++numSequencesInlined;
70124

71-
if (randSeqOp->use_empty())
72-
randSeqOp->erase();
125+
return DeletionKind::Delete;
126+
}
73127

74-
if (getSeqOp->use_empty())
75-
getSeqOp->erase();
128+
FailureOr<DeletionKind> visitUnhandledOp(Operation *op) {
129+
return DeletionKind::Keep;
130+
}
76131

77-
++numSequencesInlined;
132+
FailureOr<DeletionKind> visitExternalOp(Operation *op) {
133+
return DeletionKind::Keep;
134+
}
135+
136+
SymbolTable table;
137+
DenseMap<Value, Block *> materializedSequences;
138+
SmallVector<std::unique_ptr<Block>> blockStorage;
139+
size_t numSequencesInlined = 0;
140+
size_t numSequencesInterleaved = 0;
141+
};
142+
143+
} // namespace
144+
145+
void SequenceInliner::materializeInterleavedSequence(Value value,
146+
ArrayRef<Block *> blocks,
147+
uint32_t batchSize) {
148+
auto interleavedBlock = std::make_unique<Block>();
149+
IRMapping mapping;
150+
OpBuilder builder(value.getContext());
151+
builder.setInsertionPointToStart(interleavedBlock.get());
152+
153+
SmallVector<Block::iterator> iters(blocks.size());
154+
for (auto [i, block] : llvm::enumerate(blocks))
155+
iters[i] = block->begin();
156+
157+
llvm::BitVector finishedBlocks(blocks.size());
158+
unsigned i = 0;
159+
while (!finishedBlocks.all()) {
160+
if (finishedBlocks[i]) {
161+
i = (i + 1) % blocks.size();
162+
continue;
163+
}
164+
165+
for (unsigned k = 0; k < batchSize;) {
166+
if (iters[i] == blocks[i]->end()) {
167+
finishedBlocks.set(i);
168+
break;
169+
}
170+
171+
auto *op = builder.clone(*iters[i], mapping);
172+
if (isa<InstructionOpInterface>(op))
173+
++k;
174+
175+
++iters[i];
176+
}
177+
178+
i = (i + 1) % blocks.size();
78179
}
79180

181+
materializedSequences[value] = interleavedBlock.get();
182+
blockStorage.emplace_back(std::move(interleavedBlock));
183+
numSequencesInterleaved += blocks.size();
184+
}
185+
186+
LogicalResult SequenceInliner::inlineSequences(TestOp testOp) {
187+
LLVM_DEBUG(llvm::dbgs() << "\n=== Processing test @" << testOp.getSymName()
188+
<< "\n\n");
189+
190+
SmallVector<Operation *> toDelete;
191+
for (auto &op : *testOp.getBody()) {
192+
auto result = dispatchOpVisitor(&op);
193+
if (failed(result))
194+
return failure();
195+
196+
if (*result == DeletionKind::Delete)
197+
toDelete.push_back(&op);
198+
}
199+
200+
for (auto *op : llvm::reverse(toDelete))
201+
op->erase();
202+
80203
return success();
81204
}
82205

83206
void InlineSequencesPass::runOnOperation() {
84207
auto moduleOp = getOperation();
85-
SymbolTable table(moduleOp);
208+
SequenceInliner inliner(moduleOp);
86209

87210
// Inline all sequences and remove the operations that place the sequences.
88211
for (auto testOp : moduleOp.getOps<TestOp>())
89-
if (failed(inlineSequences(testOp, table)))
90-
return;
212+
if (failed(inliner.inlineSequences(testOp)))
213+
return signalPassFailure();
91214

92215
// Remove all sequences since they are not accessible from the outside and
93216
// are not needed anymore since we fully inlined them.
94217
for (auto seqOp : llvm::make_early_inc_range(moduleOp.getOps<SequenceOp>()))
95218
seqOp->erase();
219+
220+
numSequencesInlined = inliner.numSequencesInlined;
221+
numSequencesInterleaved = inliner.numSequencesInterleaved;
96222
}

test/Dialect/RTG/Transform/inline-sequences.mlir

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,113 @@
33
// CHECK-NOT: rtg.sequence
44
rtg.sequence @seq0() {
55
rtgtest.rv32i.ebreak
6+
rtgtest.rv32i.ebreak
67
}
78

8-
// CHECK-LABEL: @interleaveSequences
9-
rtg.test @interleaveSequences() {
9+
rtg.sequence @seq1() {
10+
rtgtest.rv32i.ecall
11+
rtgtest.rv32i.ecall
12+
}
13+
14+
// CHECK-LABEL: @inlineSequences
15+
rtg.test @inlineSequences() {
1016
// CHECK-NEXT: rtgtest.rv32i.ecall
1117
// CHECK-NEXT: rtgtest.rv32i.ebreak
18+
// CHECK-NEXT: rtgtest.rv32i.ebreak
1219
// CHECK-NEXT: rtgtest.rv32i.ecall
1320
// CHECK-NEXT: }
14-
1521
%0 = rtg.get_sequence @seq0 : !rtg.sequence
1622
%1 = rtg.randomize_sequence %0
1723
rtgtest.rv32i.ecall
1824
rtg.embed_sequence %1
1925
rtgtest.rv32i.ecall
2026
}
2127

28+
// CHECK-LABEL: @interleaveSequences
29+
rtg.test @interleaveSequences() {
30+
%0 = rtg.get_sequence @seq0 : !rtg.sequence
31+
%1 = rtg.get_sequence @seq1 : !rtg.sequence
32+
%2 = rtg.randomize_sequence %0
33+
%3 = rtg.randomize_sequence %1
34+
35+
// CHECK-NEXT: rtgtest.rv32i.ebreak
36+
// CHECK-NEXT: rtgtest.rv32i.ebreak
37+
// CHECK-NEXT: rtgtest.rv32i.ecall
38+
// CHECK-NEXT: rtgtest.rv32i.ecall
39+
%4 = rtg.interleave_sequences %2, %3 batch 2
40+
rtg.embed_sequence %4
41+
42+
// CHECK-NEXT: rtgtest.rv32i.ebreak
43+
// CHECK-NEXT: rtgtest.rv32i.ecall
44+
// CHECK-NEXT: rtgtest.rv32i.ebreak
45+
// CHECK-NEXT: rtgtest.rv32i.ecall
46+
%5 = rtg.interleave_sequences %2, %3
47+
rtg.embed_sequence %5
48+
49+
// CHECK-NEXT: rtgtest.rv32i.ebreak
50+
// CHECK-NEXT: rtgtest.rv32i.ecall
51+
// CHECK-NEXT: rtgtest.rv32i.ecall
52+
// CHECK-NEXT: rtgtest.rv32i.ecall
53+
// CHECK-NEXT: rtgtest.rv32i.ebreak
54+
// CHECK-NEXT: rtgtest.rv32i.ecall
55+
%6 = rtg.interleave_sequences %2, %3
56+
%7 = rtg.interleave_sequences %6, %3
57+
rtg.embed_sequence %7
58+
59+
// CHECK-NEXT: }
60+
}
61+
62+
rtg.sequence @nested0() {
63+
%ra = rtg.fixed_reg #rtgtest.ra
64+
%sp = rtg.fixed_reg #rtgtest.s0
65+
%imm = rtgtest.immediate #rtgtest.imm12<1>
66+
rtgtest.rv32i.jalr %ra, %sp, %imm
67+
}
68+
69+
rtg.sequence @nested1() {
70+
%0 = rtg.get_sequence @nested0 : !rtg.sequence
71+
%1 = rtg.randomize_sequence %0
72+
rtg.embed_sequence %1
73+
%ra = rtg.fixed_reg #rtgtest.ra
74+
%sp = rtg.fixed_reg #rtgtest.sp
75+
%imm = rtgtest.immediate #rtgtest.imm12<0>
76+
rtgtest.rv32i.jalr %ra, %sp, %imm
77+
}
78+
79+
// CHECK-LABEL: @nestedSequences()
80+
rtg.test @nestedSequences() {
81+
// CHECK-NEXT: [[RA0:%.+]] = rtg.fixed_reg #rtgtest.ra : !rtgtest.ireg
82+
// CHECK-NEXT: [[S0:%.+]] = rtg.fixed_reg #rtgtest.s0 : !rtgtest.ireg
83+
// CHECK-NEXT: [[IMM1:%.+]] = rtgtest.immediate #rtgtest.imm12<1> : !rtgtest.imm12
84+
// CHECK-NEXT: rtgtest.rv32i.jalr [[RA0]], [[S0]], [[IMM1]]
85+
// CHECK-NEXT: [[RA1:%.+]] = rtg.fixed_reg #rtgtest.ra : !rtgtest.ireg
86+
// CHECK-NEXT: [[SP:%.+]] = rtg.fixed_reg #rtgtest.sp : !rtgtest.ireg
87+
// CHECK-NEXT: [[IMM0:%.+]] = rtgtest.immediate #rtgtest.imm12<0> : !rtgtest.imm12
88+
// CHECK-NEXT: rtgtest.rv32i.jalr [[RA1]], [[SP]], [[IMM0]]
89+
%0 = rtg.get_sequence @nested1 : !rtg.sequence
90+
%1 = rtg.randomize_sequence %0
91+
rtg.embed_sequence %1
92+
}
93+
2294
// -----
2395

2496
rtg.test @test0(%seq : !rtg.randomized_sequence) {
25-
// expected-error @below {{sequence operand not directly defined by 'rtg.randomize_sequence' op}}
97+
// expected-error @below {{sequence operand could not be resolved; it was likely produced by an op or block argument not supported by this pass}}
2698
rtg.embed_sequence %seq
2799
}
28100

29101
// -----
30102

31103
rtg.test @test0(%seq : !rtg.sequence) {
32-
// expected-error @below {{sequence operand not directly defined by 'rtg.get_sequence' op}}
104+
// expected-error @below {{sequence operand could not be resolved; it was likely produced by an op or block argument not supported by this pass}}
33105
%0 = rtg.randomize_sequence %seq
34106
rtg.embed_sequence %0
35107
}
108+
109+
// -----
110+
111+
rtg.test @test0(%seq0 : !rtg.randomized_sequence, %seq1 : !rtg.randomized_sequence) {
112+
// expected-error @below {{sequence operand #0 could not be resolved; it was likely produced by an op or block argument not supported by this pass}}
113+
%0 = rtg.interleave_sequences %seq0, %seq1
114+
rtg.embed_sequence %0
115+
}

0 commit comments

Comments
 (0)