1010//
1111// ===----------------------------------------------------------------------===//
1212
13+ #include " circt/Dialect/RTG/IR/RTGISAAssemblyOpInterfaces.h"
1314#include " circt/Dialect/RTG/IR/RTGOps.h"
15+ #include " circt/Dialect/RTG/IR/RTGVisitors.h"
1416#include " circt/Dialect/RTG/Transforms/RTGPasses.h"
1517#include " mlir/IR/IRMapping.h"
18+ #include " llvm/Support/Debug.h"
1619
1720namespace circt {
1821namespace rtg {
@@ -25,6 +28,8 @@ using namespace mlir;
2528using namespace circt ;
2629using namespace circt ::rtg;
2730
31+ #define DEBUG_TYPE " rtg-inline-sequences"
32+
2833// ===----------------------------------------------------------------------===//
2934// Inline Sequences Pass
3035// ===----------------------------------------------------------------------===//
@@ -35,62 +40,183 @@ struct InlineSequencesPass
3540 using Base::Base;
3641
3742 void runOnOperation () override ;
38- LogicalResult inlineSequences (TestOp testOp, SymbolTable &table);
3943};
40- } // namespace
4144
42- LogicalResult InlineSequencesPass::inlineSequences (TestOp testOp,
43- SymbolTable &table) {
44- OpBuilder builder (testOp);
45- for (auto iter = testOp.getBody ()->begin ();
46- iter != testOp.getBody ()->end ();) {
47- auto embedOp = dyn_cast<EmbedSequenceOp>(&*iter);
48- if (!embedOp) {
49- ++iter;
50- continue ;
45+ // / Enum to indicate to the visitor driver whether the operation should be
46+ // / deleted.
47+ enum class DeletionKind { Delete, Keep };
48+
49+ // / The SequenceInliner computes sequence interleavings and inlines them.
50+ struct SequenceInliner
51+ : public RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>> {
52+ using RTGOpVisitor<SequenceInliner, FailureOr<DeletionKind>>::visitOp;
53+
54+ SequenceInliner (ModuleOp moduleOp) : table(moduleOp) {}
55+
56+ LogicalResult inlineSequences (TestOp testOp);
57+ void materializeInterleavedSequence (Value value, ArrayRef<Block *> blocks,
58+ uint32_t batchSize);
59+
60+ // Visitor methods
61+
62+ FailureOr<DeletionKind> visitOp (InterleaveSequencesOp op) {
63+ SmallVector<Block *> blocks;
64+ for (auto [i, seq] : llvm::enumerate (op.getSequences ())) {
65+ auto *block = materializedSequences.lookup (seq);
66+ if (!block)
67+ return op->emitError ()
68+ << " sequence operand #" << i
69+ << " could not be resolved; it was likely produced by an op or "
70+ " block argument not supported by this pass" ;
71+
72+ blocks.push_back (block);
5173 }
5274
53- auto randSeqOp = embedOp.getSequence ().getDefiningOp <RandomizeSequenceOp>();
54- if (!randSeqOp)
55- return embedOp->emitError (" sequence operand not directly defined by "
56- " 'rtg.randomize_sequence' op" );
57- auto getSeqOp = randSeqOp.getSequence ().getDefiningOp <GetSequenceOp>();
58- if (!getSeqOp)
59- return randSeqOp->emitError (
60- " sequence operand not directly defined by 'rtg.get_sequence' op" );
75+ LLVM_DEBUG (llvm::dbgs ()
76+ << " - Computing sequence interleaving: " << op << " \n " );
77+
78+ materializeInterleavedSequence (op.getInterleavedSequence (), blocks,
79+ op.getBatchSize ());
80+ return DeletionKind::Delete;
81+ }
82+
83+ FailureOr<DeletionKind> visitOp (GetSequenceOp op) {
84+ auto seqOp = table.lookup <SequenceOp>(op.getSequenceAttr ());
85+ if (!seqOp)
86+ return op->emitError () << " referenced sequence not found" ;
6187
62- auto seqOp = table.lookup <SequenceOp>(getSeqOp.getSequenceAttr ());
88+ LLVM_DEBUG (llvm::dbgs () << " - Registering existing sequence: "
89+ << op.getSequence () << " \n " );
6390
64- builder.setInsertionPointAfter (embedOp);
91+ materializedSequences[op.getResult ()] = seqOp.getBody ();
92+ return DeletionKind::Delete;
93+ }
94+
95+ FailureOr<DeletionKind> visitOp (RandomizeSequenceOp op) {
96+ LLVM_DEBUG (llvm::dbgs () << " - Randomize sequence: " << op << " \n " );
97+
98+ auto *block = materializedSequences.lookup (op.getSequence ());
99+ if (!block)
100+ return op->emitError () << " sequence operand could not be resolved; it "
101+ " was likely produced by an op or block "
102+ " argument not supported by this pass" ;
103+
104+ materializedSequences[op.getResult ()] = block;
105+ return DeletionKind::Delete;
106+ }
107+
108+ FailureOr<DeletionKind> visitOp (EmbedSequenceOp op) {
109+ LLVM_DEBUG (llvm::dbgs () << " - Inlining sequence: " << op << " \n " );
110+
111+ auto *block = materializedSequences.lookup (op.getSequence ());
112+ if (!block)
113+ return op->emitError () << " sequence operand could not be resolved; it "
114+ " was likely produced by an op or block "
115+ " argument not supported by this pass" ;
116+
117+ OpBuilder builder (op);
118+ builder.setInsertionPointAfter (op);
65119 IRMapping mapping;
66- for (auto &op : *seqOp. getBody () )
120+ for (auto &op : *block )
67121 builder.clone (op, mapping);
68122
69- (iter++)-> erase () ;
123+ ++numSequencesInlined ;
70124
71- if (randSeqOp-> use_empty ())
72- randSeqOp-> erase ();
125+ return DeletionKind::Delete;
126+ }
73127
74- if (getSeqOp->use_empty ())
75- getSeqOp->erase ();
128+ FailureOr<DeletionKind> visitUnhandledOp (Operation *op) {
129+ return DeletionKind::Keep;
130+ }
76131
77- ++numSequencesInlined;
132+ FailureOr<DeletionKind> visitExternalOp (Operation *op) {
133+ return DeletionKind::Keep;
134+ }
135+
136+ SymbolTable table;
137+ DenseMap<Value, Block *> materializedSequences;
138+ SmallVector<std::unique_ptr<Block>> blockStorage;
139+ size_t numSequencesInlined = 0 ;
140+ size_t numSequencesInterleaved = 0 ;
141+ };
142+
143+ } // namespace
144+
145+ void SequenceInliner::materializeInterleavedSequence (Value value,
146+ ArrayRef<Block *> blocks,
147+ uint32_t batchSize) {
148+ auto interleavedBlock = std::make_unique<Block>();
149+ IRMapping mapping;
150+ OpBuilder builder (value.getContext ());
151+ builder.setInsertionPointToStart (interleavedBlock.get ());
152+
153+ SmallVector<Block::iterator> iters (blocks.size ());
154+ for (auto [i, block] : llvm::enumerate (blocks))
155+ iters[i] = block->begin ();
156+
157+ llvm::BitVector finishedBlocks (blocks.size ());
158+ unsigned i = 0 ;
159+ while (!finishedBlocks.all ()) {
160+ if (finishedBlocks[i]) {
161+ i = (i + 1 ) % blocks.size ();
162+ continue ;
163+ }
164+
165+ for (unsigned k = 0 ; k < batchSize;) {
166+ if (iters[i] == blocks[i]->end ()) {
167+ finishedBlocks.set (i);
168+ break ;
169+ }
170+
171+ auto *op = builder.clone (*iters[i], mapping);
172+ if (isa<InstructionOpInterface>(op))
173+ ++k;
174+
175+ ++iters[i];
176+ }
177+
178+ i = (i + 1 ) % blocks.size ();
78179 }
79180
181+ materializedSequences[value] = interleavedBlock.get ();
182+ blockStorage.emplace_back (std::move (interleavedBlock));
183+ numSequencesInterleaved += blocks.size ();
184+ }
185+
186+ LogicalResult SequenceInliner::inlineSequences (TestOp testOp) {
187+ LLVM_DEBUG (llvm::dbgs () << " \n === Processing test @" << testOp.getSymName ()
188+ << " \n\n " );
189+
190+ SmallVector<Operation *> toDelete;
191+ for (auto &op : *testOp.getBody ()) {
192+ auto result = dispatchOpVisitor (&op);
193+ if (failed (result))
194+ return failure ();
195+
196+ if (*result == DeletionKind::Delete)
197+ toDelete.push_back (&op);
198+ }
199+
200+ for (auto *op : llvm::reverse (toDelete))
201+ op->erase ();
202+
80203 return success ();
81204}
82205
83206void InlineSequencesPass::runOnOperation () {
84207 auto moduleOp = getOperation ();
85- SymbolTable table (moduleOp);
208+ SequenceInliner inliner (moduleOp);
86209
87210 // Inline all sequences and remove the operations that place the sequences.
88211 for (auto testOp : moduleOp.getOps <TestOp>())
89- if (failed (inlineSequences (testOp, table )))
90- return ;
212+ if (failed (inliner. inlineSequences (testOp)))
213+ return signalPassFailure () ;
91214
92215 // Remove all sequences since they are not accessible from the outside and
93216 // are not needed anymore since we fully inlined them.
94217 for (auto seqOp : llvm::make_early_inc_range (moduleOp.getOps <SequenceOp>()))
95218 seqOp->erase ();
219+
220+ numSequencesInlined = inliner.numSequencesInlined ;
221+ numSequencesInterleaved = inliner.numSequencesInterleaved ;
96222}
0 commit comments