Skip to content

Commit 9b7b4c1

Browse files
authored
Update SchedulePort and RemoveUnusedArgument Passes for Nested Sequences (#206)
Updates the SchedulePort and RemoveUnusedArgument passes to include support for nested sequences.
1 parent 2e4a78a commit 9b7b4c1

File tree

4 files changed

+126
-13
lines changed

4 files changed

+126
-13
lines changed

include/Dialect/Pulse/Transforms/SchedulePort.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,18 @@ class SchedulePortPass
4949
private:
5050
using mixedFrameMap_t = std::map<uint32_t, std::vector<Operation *>>;
5151

52-
uint64_t processCall(Operation *module, CallSequenceOp &callSequenceOp);
52+
uint64_t processCall(CallSequenceOp &callSequenceOp,
53+
bool updateNestedSequences);
5354
uint64_t processSequence(SequenceOp sequenceOp);
55+
uint64_t updateSequence(SequenceOp sequenceOp);
5456

5557
mixedFrameMap_t buildMixedFrameMap(SequenceOp &sequenceOp,
5658
uint32_t &numMixedFrames);
5759

5860
void addTimepoints(mlir::OpBuilder &builder,
5961
mixedFrameMap_t &mixedFrameSequences, int64_t &maxTime);
6062
void sortOpsByTimepoint(SequenceOp &sequenceOp);
63+
llvm::StringMap<mlir::pulse::SequenceOp> sequenceOps;
6164
};
6265
} // namespace mlir::pulse
6366

lib/Dialect/Pulse/Transforms/RemoveUnusedArguments.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ struct RemoveUnusedArgumentsPattern
6666
LLVM_DEBUG(argumentResult.value().dump());
6767

6868
auto *argOp = callSequenceOp.getOperand(index).getDefiningOp();
69-
testEraseList.push_back(argOp);
69+
if (argOp)
70+
testEraseList.push_back(argOp);
7071
}
7172
}
7273

@@ -78,6 +79,22 @@ struct RemoveUnusedArgumentsPattern
7879
sequenceOp.eraseArguments(argIndicesBV);
7980
callSequenceOp->eraseOperands(argIndicesBV);
8081

82+
// check for other CallSequenceOps calling the same sequence
83+
auto moduleOp = sequenceOp->getParentOfType<mlir::ModuleOp>();
84+
assert(moduleOp && "Operation outside of a Module");
85+
moduleOp->walk([&](pulse::CallSequenceOp op) {
86+
if (op == callSequenceOp)
87+
return;
88+
if (op.callee() != sequenceOp.sym_name())
89+
return;
90+
// verify that the sequence and the new callSequenceOp are in
91+
// the same module
92+
auto checkModuleOp = op->getParentOfType<mlir::ModuleOp>();
93+
if (checkModuleOp != moduleOp)
94+
return;
95+
op->eraseOperands(argIndicesBV);
96+
});
97+
8198
// remove defining ops if the have no usage
8299
for (auto *argOp : testEraseList)
83100
if (argOp->use_empty())

lib/Dialect/Pulse/Transforms/SchedulePort.cpp

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,37 @@
3333
using namespace mlir;
3434
using namespace mlir::pulse;
3535

36-
uint64_t SchedulePortPass::processCall(Operation *module,
37-
CallSequenceOp &callSequenceOp) {
36+
uint64_t SchedulePortPass::processCall(CallSequenceOp &callSequenceOp,
37+
bool updateNestedSequences) {
3838

3939
INDENT_DEBUG("==== processCall - start ===================\n");
4040
INDENT_DUMP(callSequenceOp.dump());
4141
INDENT_DEBUG("=============================================\n");
4242

43+
// check for nested sequence
44+
auto parentSequence = callSequenceOp->getParentOfType<SequenceOp>();
45+
if (!updateNestedSequences && parentSequence)
46+
return 0;
47+
4348
// walk into region and check arguments
4449
// look for sequence def match
4550
auto callee = callSequenceOp.getCallee();
46-
auto sequenceOp =
47-
dyn_cast<SequenceOp>(SymbolTable::lookupSymbolIn(module, callee));
48-
if (!sequenceOp) {
51+
auto sequenceOpIter = sequenceOps.find(callee);
52+
53+
if (sequenceOpIter == sequenceOps.end()) {
4954
callSequenceOp->emitError()
5055
<< "Unable to find callee symbol " << callee << ".";
5156
signalPassFailure();
5257
}
53-
uint64_t calleeDuration = processSequence(sequenceOp);
58+
59+
auto sequenceOp = sequenceOpIter->second;
60+
61+
uint64_t calleeDuration;
62+
if (updateNestedSequences)
63+
calleeDuration = updateSequence(sequenceOp);
64+
else
65+
calleeDuration = processSequence(sequenceOp);
66+
PulseOpSchedulingInterface::setDuration(callSequenceOp, calleeDuration);
5467

5568
INDENT_DEBUG("==== processCall - end ====================\n");
5669
INDENT_DUMP(callSequenceOp.dump());
@@ -60,8 +73,6 @@ uint64_t SchedulePortPass::processCall(Operation *module,
6073

6174
uint64_t SchedulePortPass::processSequence(SequenceOp sequenceOp) {
6275

63-
// TODO: Consider returning overall length of sequence to help schedule
64-
// across sequences
6576
mlir::OpBuilder builder(sequenceOp);
6677

6778
uint32_t numMixedFrames = 0;
@@ -91,6 +102,42 @@ uint64_t SchedulePortPass::processSequence(SequenceOp sequenceOp) {
91102
return maxTime;
92103
}
93104

105+
uint64_t SchedulePortPass::updateSequence(SequenceOp sequenceOp) {
106+
107+
uint64_t updateDelta = 0;
108+
int64_t returnTimepoint = 0;
109+
for (Region &region : sequenceOp->getRegions()) {
110+
for (Block &block : region.getBlocks()) {
111+
for (Operation &op : block.getOperations()) {
112+
int64_t timepoint = 0;
113+
114+
auto existingTimepoint = PulseOpSchedulingInterface::getTimepoint(&op);
115+
if (existingTimepoint.hasValue()) {
116+
timepoint = existingTimepoint.getValue() + updateDelta;
117+
PulseOpSchedulingInterface::setTimepoint(&op, timepoint);
118+
}
119+
120+
if (auto castOp = dyn_cast<ReturnOp>(op)) {
121+
returnTimepoint = timepoint;
122+
} else if (auto castOp = dyn_cast<CallSequenceOp>(op)) {
123+
// a nested sequence should only have a duration if it has been
124+
// updated by this method already
125+
if (!castOp->hasAttr("pulse.duration")) {
126+
uint64_t calleeDuration =
127+
processCall(castOp, /*updateNested*/ true);
128+
PulseOpSchedulingInterface::setDuration(castOp, calleeDuration);
129+
updateDelta += calleeDuration;
130+
} else {
131+
updateDelta +=
132+
castOp->getAttrOfType<IntegerAttr>("pulse.duration").getInt();
133+
}
134+
}
135+
}
136+
}
137+
}
138+
return returnTimepoint;
139+
}
140+
94141
SchedulePortPass::mixedFrameMap_t
95142
SchedulePortPass::buildMixedFrameMap(SequenceOp &sequenceOp,
96143
uint32_t &numMixedFrames) {
@@ -141,6 +188,17 @@ SchedulePortPass::buildMixedFrameMap(SequenceOp &sequenceOp,
141188
auto index = blockArg.getArgNumber();
142189

143190
mixedFrameSequences[index].push_back(&op);
191+
} else if (auto castOp = dyn_cast<CallSequenceOp>(op)) {
192+
// add a call sequence to all mixedFrameSequences for
193+
// mixedFrames passed to it
194+
for (auto operand : castOp.operands()) {
195+
auto operandType = operand.getType();
196+
if (operandType.isa<MixedFrameType>()) {
197+
auto blockArg = operand.cast<BlockArgument>();
198+
auto index = blockArg.getArgNumber();
199+
mixedFrameSequences[index].push_back(&op);
200+
}
201+
}
144202
}
145203
}
146204
}
@@ -161,6 +219,11 @@ void SchedulePortPass::addTimepoints(mlir::OpBuilder &builder,
161219
for (const auto &index : mixedFrameSequences) {
162220
int64_t currentTimepoint = 0;
163221
for (auto *op : index.second) {
222+
auto existingTimepoint = PulseOpSchedulingInterface::getTimepoint(op);
223+
if (existingTimepoint.hasValue())
224+
if (existingTimepoint.getValue() > currentTimepoint)
225+
currentTimepoint = existingTimepoint.getValue();
226+
164227
// set attribute on op with current timepoint
165228
PulseOpSchedulingInterface::setTimepoint(op, currentTimepoint);
166229

@@ -203,8 +266,12 @@ void SchedulePortPass::sortOpsByTimepoint(SequenceOp &sequenceOp) {
203266
!isa<arith::ConstantIntOp>(op2))
204267
return true;
205268

206-
if (!op1.hasTrait<mlir::pulse::HasTargetFrame>() ||
207-
!op2.hasTrait<mlir::pulse::HasTargetFrame>())
269+
bool testOp1 = (op1.hasTrait<mlir::pulse::HasTargetFrame>() ||
270+
isa<CallSequenceOp>(op1));
271+
bool testOp2 = (op2.hasTrait<mlir::pulse::HasTargetFrame>() ||
272+
isa<CallSequenceOp>(op2));
273+
274+
if (!testOp1 || !testOp2)
208275
return false;
209276

210277
llvm::Optional<int64_t> currentTimepoint =
@@ -233,9 +300,18 @@ void SchedulePortPass::runOnOperation() {
233300

234301
Operation *module = getOperation();
235302

303+
module->walk(
304+
[&](mlir::pulse::SequenceOp op) { sequenceOps[op.sym_name()] = op; });
305+
236306
INDENT_DEBUG("===== SchedulePortPass - start ==========\n");
237307

238-
module->walk([&](CallSequenceOp op) { processCall(module, op); });
308+
// assign
309+
module->walk([&](CallSequenceOp op) {
310+
processCall(op, /*updateNestedSequences*/ false);
311+
});
312+
module->walk([&](CallSequenceOp op) {
313+
processCall(op, /*updateNestedSequences*/ true);
314+
});
239315

240316
INDENT_DEBUG("===== SchedulePortPass - end ===========\n");
241317

test/Dialect/Pulse/Transforms/schedule-port-acquire.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,20 @@ module @acquire_0 attributes {quir.nodeId = 7 : i32, quir.nodeType = "acquire",
5151
// CHECK: pulse.return {pulse.timepoint = 18096 : i64} %{{.*}} : i1
5252
pulse.return %0 : i1
5353
}
54+
pulse.sequence @seq_1(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame) -> (i1, i1) {
55+
%c1000_i32 = arith.constant 1000 : i32
56+
pulse.delay(%arg0, %c1000_i32) : (!pulse.mixed_frame, i32)
57+
// CHECK-NOT: pulse.delay(%[[ARG1]], %c1000_i32) : (!pulse.mixed_frame, i32)
58+
%0 = pulse.call_sequence @seq_0(%arg0, %arg1, %arg2, %arg3, %arg4) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> i1
59+
// CHECK: %0 = pulse.call_sequence @seq_0(%arg0, %arg1, %arg2, %arg3, %arg4) {pulse.duration = 18096 : i64, pulse.timepoint = 1000 : i64}
60+
pulse.delay(%arg0, %c1000_i32) : (!pulse.mixed_frame, i32)
61+
// CHECK-NOT: pulse.delay(%[[ARG1]], %c1000_i32) : (!pulse.mixed_frame, i32)
62+
%1 = pulse.call_sequence @seq_0(%arg0, %arg1, %arg2, %arg3, %arg4) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> i1
63+
// CHECK: %1 = pulse.call_sequence @seq_0(%arg0, %arg1, %arg2, %arg3, %arg4) {pulse.duration = 18096 : i64, pulse.timepoint = 20096 : i64}
64+
pulse.delay(%arg0, %c1000_i32) : (!pulse.mixed_frame, i32)
65+
// CHECK-NOT: pulse.delay(%[[ARG1]], %c1000_i32) : (!pulse.mixed_frame, i32)
66+
pulse.return %0, %1 : i1, i1
67+
}
5468
func @main() -> i32 attributes {quir.classicalOnly = false} {
5569
%c0_i32 = arith.constant 0 : i32
5670
%2 = "pulse.create_port"() {uid = "p0"} : () -> !pulse.port
@@ -63,6 +77,9 @@ module @acquire_0 attributes {quir.nodeId = 7 : i32, quir.nodeType = "acquire",
6377
%13 = "pulse.create_port"() {uid = "p3"} : () -> !pulse.port
6478
%15 = "pulse.mix_frame"(%13) {uid = "mf0-p3"} : (!pulse.port) -> !pulse.mixed_frame
6579
%16 = pulse.call_sequence @seq_0(%4, %6, %9, %12, %15) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> i1
80+
// CHECK: {{.*}} = pulse.call_sequence @seq_0(%1, %2, %4, %6, %8) {pulse.duration = 18096 : i64}
81+
%17:2 = pulse.call_sequence @seq_1(%4, %6, %9, %12, %15) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1,i1)
82+
// CHECK: {{.*}}:2 = pulse.call_sequence @seq_1(%1, %2, %4, %6, %8) {pulse.duration = 39192 : i64}
6683
return %c0_i32 : i32
6784
}
6885
}

0 commit comments

Comments
 (0)