33
33
using namespace mlir ;
34
34
using namespace mlir ::pulse;
35
35
36
- uint64_t SchedulePortPass::processCall (Operation * module ,
37
- CallSequenceOp &callSequenceOp ) {
36
+ uint64_t SchedulePortPass::processCall (CallSequenceOp &callSequenceOp ,
37
+ bool updateNestedSequences ) {
38
38
39
39
INDENT_DEBUG (" ==== processCall - start ===================\n " );
40
40
INDENT_DUMP (callSequenceOp.dump ());
41
41
INDENT_DEBUG (" =============================================\n " );
42
42
43
+ // check for nested sequence
44
+ auto parentSequence = callSequenceOp->getParentOfType <SequenceOp>();
45
+ if (!updateNestedSequences && parentSequence)
46
+ return 0 ;
47
+
43
48
// walk into region and check arguments
44
49
// look for sequence def match
45
50
auto callee = callSequenceOp.getCallee ();
46
- auto sequenceOp =
47
- dyn_cast<SequenceOp>( SymbolTable::lookupSymbolIn ( module , callee));
48
- if (!sequenceOp ) {
51
+ auto sequenceOpIter = sequenceOps. find (callee);
52
+
53
+ if (sequenceOpIter == sequenceOps. end () ) {
49
54
callSequenceOp->emitError ()
50
55
<< " Unable to find callee symbol " << callee << " ." ;
51
56
signalPassFailure ();
52
57
}
53
- uint64_t calleeDuration = processSequence (sequenceOp);
58
+
59
+ auto sequenceOp = sequenceOpIter->second ;
60
+
61
+ uint64_t calleeDuration;
62
+ if (updateNestedSequences)
63
+ calleeDuration = updateSequence (sequenceOp);
64
+ else
65
+ calleeDuration = processSequence (sequenceOp);
66
+ PulseOpSchedulingInterface::setDuration (callSequenceOp, calleeDuration);
54
67
55
68
INDENT_DEBUG (" ==== processCall - end ====================\n " );
56
69
INDENT_DUMP (callSequenceOp.dump ());
@@ -60,8 +73,6 @@ uint64_t SchedulePortPass::processCall(Operation *module,
60
73
61
74
uint64_t SchedulePortPass::processSequence (SequenceOp sequenceOp) {
62
75
63
- // TODO: Consider returning overall length of sequence to help schedule
64
- // across sequences
65
76
mlir::OpBuilder builder (sequenceOp);
66
77
67
78
uint32_t numMixedFrames = 0 ;
@@ -91,6 +102,42 @@ uint64_t SchedulePortPass::processSequence(SequenceOp sequenceOp) {
91
102
return maxTime;
92
103
}
93
104
105
+ uint64_t SchedulePortPass::updateSequence (SequenceOp sequenceOp) {
106
+
107
+ uint64_t updateDelta = 0 ;
108
+ int64_t returnTimepoint = 0 ;
109
+ for (Region ®ion : sequenceOp->getRegions ()) {
110
+ for (Block &block : region.getBlocks ()) {
111
+ for (Operation &op : block.getOperations ()) {
112
+ int64_t timepoint = 0 ;
113
+
114
+ auto existingTimepoint = PulseOpSchedulingInterface::getTimepoint (&op);
115
+ if (existingTimepoint.hasValue ()) {
116
+ timepoint = existingTimepoint.getValue () + updateDelta;
117
+ PulseOpSchedulingInterface::setTimepoint (&op, timepoint);
118
+ }
119
+
120
+ if (auto castOp = dyn_cast<ReturnOp>(op)) {
121
+ returnTimepoint = timepoint;
122
+ } else if (auto castOp = dyn_cast<CallSequenceOp>(op)) {
123
+ // a nested sequence should only have a duration if it has been
124
+ // updated by this method already
125
+ if (!castOp->hasAttr (" pulse.duration" )) {
126
+ uint64_t calleeDuration =
127
+ processCall (castOp, /* updateNested*/ true );
128
+ PulseOpSchedulingInterface::setDuration (castOp, calleeDuration);
129
+ updateDelta += calleeDuration;
130
+ } else {
131
+ updateDelta +=
132
+ castOp->getAttrOfType <IntegerAttr>(" pulse.duration" ).getInt ();
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+ return returnTimepoint;
139
+ }
140
+
94
141
SchedulePortPass::mixedFrameMap_t
95
142
SchedulePortPass::buildMixedFrameMap (SequenceOp &sequenceOp,
96
143
uint32_t &numMixedFrames) {
@@ -141,6 +188,17 @@ SchedulePortPass::buildMixedFrameMap(SequenceOp &sequenceOp,
141
188
auto index = blockArg.getArgNumber ();
142
189
143
190
mixedFrameSequences[index].push_back (&op);
191
+ } else if (auto castOp = dyn_cast<CallSequenceOp>(op)) {
192
+ // add a call sequence to all mixedFrameSequences for
193
+ // mixedFrames passed to it
194
+ for (auto operand : castOp.operands ()) {
195
+ auto operandType = operand.getType ();
196
+ if (operandType.isa <MixedFrameType>()) {
197
+ auto blockArg = operand.cast <BlockArgument>();
198
+ auto index = blockArg.getArgNumber ();
199
+ mixedFrameSequences[index].push_back (&op);
200
+ }
201
+ }
144
202
}
145
203
}
146
204
}
@@ -161,6 +219,11 @@ void SchedulePortPass::addTimepoints(mlir::OpBuilder &builder,
161
219
for (const auto &index : mixedFrameSequences) {
162
220
int64_t currentTimepoint = 0 ;
163
221
for (auto *op : index.second ) {
222
+ auto existingTimepoint = PulseOpSchedulingInterface::getTimepoint (op);
223
+ if (existingTimepoint.hasValue ())
224
+ if (existingTimepoint.getValue () > currentTimepoint)
225
+ currentTimepoint = existingTimepoint.getValue ();
226
+
164
227
// set attribute on op with current timepoint
165
228
PulseOpSchedulingInterface::setTimepoint (op, currentTimepoint);
166
229
@@ -203,8 +266,12 @@ void SchedulePortPass::sortOpsByTimepoint(SequenceOp &sequenceOp) {
203
266
!isa<arith::ConstantIntOp>(op2))
204
267
return true ;
205
268
206
- if (!op1.hasTrait <mlir::pulse::HasTargetFrame>() ||
207
- !op2.hasTrait <mlir::pulse::HasTargetFrame>())
269
+ bool testOp1 = (op1.hasTrait <mlir::pulse::HasTargetFrame>() ||
270
+ isa<CallSequenceOp>(op1));
271
+ bool testOp2 = (op2.hasTrait <mlir::pulse::HasTargetFrame>() ||
272
+ isa<CallSequenceOp>(op2));
273
+
274
+ if (!testOp1 || !testOp2)
208
275
return false ;
209
276
210
277
llvm::Optional<int64_t > currentTimepoint =
@@ -233,9 +300,18 @@ void SchedulePortPass::runOnOperation() {
233
300
234
301
Operation *module = getOperation ();
235
302
303
+ module ->walk (
304
+ [&](mlir::pulse::SequenceOp op) { sequenceOps[op.sym_name ()] = op; });
305
+
236
306
INDENT_DEBUG (" ===== SchedulePortPass - start ==========\n " );
237
307
238
- module ->walk ([&](CallSequenceOp op) { processCall (module , op); });
308
+ // assign
309
+ module ->walk ([&](CallSequenceOp op) {
310
+ processCall (op, /* updateNestedSequences*/ false );
311
+ });
312
+ module ->walk ([&](CallSequenceOp op) {
313
+ processCall (op, /* updateNestedSequences*/ true );
314
+ });
239
315
240
316
INDENT_DEBUG (" ===== SchedulePortPass - end ===========\n " );
241
317
0 commit comments