Skip to content

Commit c2637fe

Browse files
reza-jgithub-actions[bot]taalexander
authored
Pulse ALAP scheduling (#183)
This PR adds pulse alap scheduling for pulse sequences of quantum gates inside a circuit, based on the availability of involved ports --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Thomas Alexander <[email protected]>
1 parent 503eb93 commit c2637fe

File tree

11 files changed

+300
-286
lines changed

11 files changed

+300
-286
lines changed

include/Dialect/Pulse/IR/PulseInterfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ llvm::Optional<uint64_t> getSetupLatency(mlir::Operation *op);
4242
void setSetupLatency(mlir::Operation *op, uint64_t setupLatency);
4343
llvm::Expected<uint64_t> getDuration(mlir::Operation *op,
4444
mlir::Operation *callSequenceOp = nullptr);
45+
llvm::Expected<mlir::ArrayAttr> getPorts(mlir::Operation *op);
4546
void setDuration(mlir::Operation *op, uint64_t duration);
4647

4748
} // namespace mlir::pulse::interfaces_impl

include/Dialect/Pulse/IR/PulseInterfaces.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,28 @@ def PulseOpSchedulingInterface : OpInterface<"PulseOpSchedulingInterface"> {
100100
return PulseOpSchedulingInterface::setDuration($_op, other);
101101
}]
102102
>,
103+
InterfaceMethod<
104+
/*desc=*/"Get the ports of a pulse operation",
105+
/*retTy=*/"::llvm::Expected<mlir::ArrayAttr>",
106+
/*methodName=*/"getPorts",
107+
/*args=*/(ins),
108+
/*methodBody=*/[{}],
109+
/*defaultImplementation=*/[{
110+
// By default, return the pulse.argPorts attribute
111+
return PulseOpSchedulingInterface::getPorts($_op);
112+
}]
113+
>,
103114
];
104115

105116
let extraSharedClassDeclaration = [{
106117
static llvm::Optional<int64_t> getTimepoint(mlir::Operation *op) {
107118
return interfaces_impl::getTimepoint(op);
108119
}
109120

121+
static llvm::Expected<mlir::ArrayAttr> getPorts(mlir::Operation *op) {
122+
return interfaces_impl::getPorts(op);
123+
}
124+
110125
static void setTimepoint(mlir::Operation *op, int64_t timepoint) {
111126
return interfaces_impl::setTimepoint(op, timepoint);
112127
}

include/Dialect/Pulse/IR/PulseOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ def Pulse_CallSequenceOp : Pulse_Op<"call_sequence", [CallOpInterface, MemRefsNo
712712

713713
def Pulse_SequenceOp : Pulse_Op<"sequence", [
714714
AutomaticAllocationScope, CallableOpInterface,
715+
DeclareOpInterfaceMethods<PulseOpSchedulingInterface, ["getDuration"]>,
715716
FunctionOpInterface, IsolatedFromAbove, Symbol, SequenceAllowed
716717
]> {
717718
let summary = "An operation with a name containing a single `SSACFG` region corresponding to a pulse sequence execution";

include/Dialect/Pulse/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "Dialect/Pulse/Transforms/MergeDelays.h"
2828
#include "Dialect/Pulse/Transforms/RemoveUnusedArguments.h"
2929
#include "Dialect/Pulse/Transforms/SchedulePort.h"
30+
#include "Dialect/Pulse/Transforms/Scheduling.h"
3031
#include "mlir/Pass/Pass.h"
3132
#include "mlir/Pass/PassManager.h"
3233
#include "mlir/Transforms/Passes.h"
Lines changed: 55 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- Scheduling.h - Add absolute timing to defcal calls. ------*- C++ -*-===//
1+
//===- scheduling.h --- quantum circuits pulse scheduling -------*- C++ -*-===//
22
//
33
// (C) Copyright IBM 2023.
44
//
@@ -14,72 +14,60 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616
///
17-
/// This file declares the pass for adding absolute timing to defcal calls.
17+
/// This file implements the pass for scheduling the quantum circuits at pulse
18+
/// level, based on the availability of involved ports
1819
///
1920
//===----------------------------------------------------------------------===//
2021

21-
//#ifndef PULSE_SCHEDULING_H
22-
//#define PULSE_SCHEDULING_H
23-
24-
//#include <unordered_map>
25-
//#include <unordered_set>
26-
27-
//#include "Dialect/Pulse/IR/PulseOps.h"
28-
//#include "mlir/Pass/Pass.h"
29-
30-
// namespace mlir::pulse {
31-
32-
//// This pass applies absolute timing to each relevant Pulse IR instruction.
33-
//// Timing is calculated on a per frame basis.
34-
///*** Steps:
35-
// * 1. Identify each defcal gate call.
36-
// * 2. Find associated defcal body.
37-
// * 3. Compute and store duration of each waveform and initialze time on each
38-
// *frame. 4. For each play/delay instruction, increment the frame timing. Add
39-
// *time attribute to instruction. For each barrier instruction, resolve to
40-
// delays *on push forward basis (frames will be delayed to maximum time amongst
41-
// all *frames).
42-
// ***/
43-
// struct SchedulingPass : public PassWrapper<SchedulingPass, OperationPass<>> {
44-
45-
// std::unordered_set<uint>
46-
// scheduledDefCals; // hashes of defcal's that have already been scheduled
47-
// std::unordered_map<uint, uint>
48-
// pulseDurations; // mapping of waveform hashes to durations
49-
// std::unordered_map<uint, uint>
50-
// frameTimes; // mapping of frame hashes to time on that frame
51-
52-
// // Hash an operation based on the result.
53-
// auto getResultHash(Operation *op) -> uint;
54-
55-
// // Check if the pulse hash is cached in pulse durations.
56-
// // If it is cached, the hash will be found in pulseDurations.
57-
// auto pulseCached(llvm::hash_code hash) -> bool;
58-
59-
// // Get hash and time of a frame as a std::pair
60-
// auto getFrameHashAndTime(mlir::Value &frame) -> std::pair<uint, uint>;
61-
62-
// // Get the maximum time among a set of frames
63-
// auto getMaxTime(mlir::OperandRange &frames) -> uint;
64-
65-
// // Process each operation in the defcal
66-
// template <class WaveformOp>
67-
// void processOp(WaveformOp &wfrOp);
68-
69-
// void processOp(Frame_CreateOp &frameOp);
70-
71-
// void processOp(DelayOp &delayOp);
72-
// void processOp(BarrierOp &barrierOp);
73-
// void processOp(PlayOp &playOp);
74-
// void processOp(CaptureOp &captureOp);
75-
76-
// // Schedule the defcal
77-
// void schedule(Operation *defCalOp);
78-
79-
// // Entry point for the pass
80-
// void runOnOperation() override;
81-
82-
//}; // end struct SchedulingPass
83-
//} // namespace mlir::pulse
84-
85-
//#endif // PULSE_SCHEDULING_H
22+
#ifndef SCHEDULING_PULSE_SEQUENCES_H
23+
#define SCHEDULING_PULSE_SEQUENCES_H
24+
25+
#include "Dialect/Pulse/IR/PulseOps.h"
26+
27+
#include "mlir/IR/MLIRContext.h"
28+
#include "mlir/Pass/Pass.h"
29+
30+
namespace mlir::pulse {
31+
32+
struct quantumCircuitPulseSchedulingPass
33+
: public PassWrapper<quantumCircuitPulseSchedulingPass,
34+
OperationPass<ModuleOp>> {
35+
public:
36+
enum SchedulingMethod { ALAP, ASAP };
37+
SchedulingMethod SCHEDULING_METHOD = ALAP;
38+
39+
// this pass can optionally receive an string specifying the scheduling
40+
// method; default method is alap scheduling
41+
quantumCircuitPulseSchedulingPass() = default;
42+
quantumCircuitPulseSchedulingPass(
43+
const quantumCircuitPulseSchedulingPass &pass)
44+
: PassWrapper(pass) {}
45+
quantumCircuitPulseSchedulingPass(SchedulingMethod inSchedulingMethod) {
46+
SCHEDULING_METHOD = inSchedulingMethod;
47+
}
48+
49+
void runOnOperation() override;
50+
51+
llvm::StringRef getArgument() const override;
52+
llvm::StringRef getDescription() const override;
53+
54+
// optionally, one can override the scheduling method with this option
55+
Option<std::string> schedulingMethod{
56+
*this, "scheduling-method",
57+
llvm::cl::desc("an string to specify scheduling method"),
58+
llvm::cl::value_desc("filename"), llvm::cl::init("")};
59+
60+
private:
61+
// map to keep track of next availability of ports
62+
std::map<std::string, int> portNameToNextAvailabilityMap;
63+
64+
void scheduleAlap(mlir::pulse::CallSequenceOp quantumCircuitCallSequenceOp);
65+
int getNextAvailableTimeOfPorts(mlir::ArrayAttr ports);
66+
void updatePortAvailabilityMap(mlir::ArrayAttr ports,
67+
int updatedAvailableTime);
68+
static mlir::pulse::SequenceOp
69+
getSequenceOp(mlir::pulse::CallSequenceOp callSequenceOp);
70+
};
71+
} // namespace mlir::pulse
72+
73+
#endif // SCHEDULING_PULSE_SEQUENCES_H

lib/Conversion/QUIRToPulse/LoadPulseCals.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::BarrierOp barrierOp,
253253
void LoadPulseCalsPass::loadPulseCals(mlir::quir::DelayOp delayOp,
254254
CallCircuitOp callCircuitOp,
255255
FuncOp funcOp) {
256-
257256
OpBuilder builder(funcOp.body());
258257

259258
std::vector<Value> qubitOperands;
@@ -264,7 +263,7 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::DelayOp delayOp,
264263
delayOp->setAttr("pulse.calName", builder.getStringAttr(gateMangledName));
265264
if (pulseCalsNameToSequenceMap.find(gateMangledName) !=
266265
pulseCalsNameToSequenceMap.end()) {
267-
// found a pulse calibration for the barrier gate
266+
// found a pulse calibration for the delay gate
268267
addPulseCalToModule(funcOp, pulseCalsNameToSequenceMap[gateMangledName]);
269268
return;
270269
}

lib/Conversion/QUIRToPulse/QUIRToPulse.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,29 @@ void QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp callCircuitOp,
143143
convertedPulseSequenceOpReturnTypes.push_back(type);
144144
for (auto val : pulseCalCallSequenceOp.res())
145145
convertedPulseSequenceOpReturnValues.push_back(val);
146+
147+
// add starting timepoint for delayOp
148+
if (auto delayOp = dyn_cast<mlir::quir::DelayOp>(quirOp)) {
149+
uint64_t durValue = 0;
150+
if (delayOp.time().isa<BlockArgument>()) {
151+
uint argNum = delayOp.time().dyn_cast<BlockArgument>().getArgNumber();
152+
auto durOpConstantOp = callCircuitOp.getOperand(argNum)
153+
.getDefiningOp<mlir::quir::ConstantOp>();
154+
auto durOp = quir::getDuration(durOpConstantOp).get();
155+
durValue = static_cast<uint>(durOp.getDuration().convertToDouble());
156+
assert(durOp.getType().dyn_cast<DurationType>().getUnits() ==
157+
TimeUnits::dt &&
158+
"this pass only accepts durations with dt unit");
159+
} else {
160+
auto durOp = quir::getDuration(delayOp).get();
161+
durValue = static_cast<uint>(durOp.getDuration().convertToDouble());
162+
assert(durOp.getType().dyn_cast<DurationType>().getUnits() ==
163+
TimeUnits::dt &&
164+
"this pass only accepts durations with dt unit");
165+
}
166+
PulseOpSchedulingInterface::setDuration(pulseCalCallSequenceOp,
167+
durValue);
168+
}
146169
} else
147170
assert(((isa<quir::ConstantOp>(quirOp) or isa<quir::ReturnOp>(quirOp) or
148171
isa<quir::CircuitOp>(quirOp))) &&

lib/Dialect/Pulse/IR/PulseInterfaces.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ interfaces_impl::getDuration(Operation *op, Operation *callSequenceOp) {
6969
"Operation does not have a pulse.duration attribute.");
7070
}
7171

72+
llvm::Expected<mlir::ArrayAttr> interfaces_impl::getPorts(mlir::Operation *op) {
73+
if (op->hasAttrOfType<ArrayAttr>("pulse.argPorts"))
74+
return op->getAttrOfType<ArrayAttr>("pulse.argPorts");
75+
return llvm::createStringError(
76+
llvm::inconvertibleErrorCode(),
77+
"Operation does not have a pulse.argPorts attribute.");
78+
}
79+
7280
void interfaces_impl::setDuration(Operation *op, uint64_t duration) {
7381
mlir::OpBuilder builder(op);
7482
op->setAttr("pulse.duration", builder.getI64IntegerAttr(duration));

lib/Dialect/Pulse/IR/PulseOps.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,23 @@ CallSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
184184
//
185185
//===----------------------------------------------------------------------===//
186186

187+
llvm::Expected<uint64_t>
188+
SequenceOp::getDuration(mlir::Operation *callSequenceOp = nullptr) {
189+
// first, check if the sequence has duration attribute. If not, also check if
190+
// the call sequence has duration attribute; e.g., for sequences that receives
191+
// delay arguments, duration of the sequence can vary depending on the
192+
// argument, so we look at the duration of call sequence as well
193+
if ((*this)->hasAttr("pulse.duration"))
194+
return static_cast<uint64_t>(
195+
(*this)->getAttrOfType<IntegerAttr>("pulse.duration").getInt());
196+
if (callSequenceOp->hasAttr("pulse.duration"))
197+
return static_cast<uint64_t>(
198+
callSequenceOp->getAttrOfType<IntegerAttr>("pulse.duration").getInt());
199+
return llvm::createStringError(
200+
llvm::inconvertibleErrorCode(),
201+
"Operation does not have a pulse.duration attribute.");
202+
}
203+
187204
static ParseResult parseSequenceOp(OpAsmParser &parser,
188205
OperationState &result) {
189206
auto buildSequenceType =

lib/Dialect/Pulse/Transforms/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void registerPulsePasses() {
4343
PassRegistration<MergeDelayPass>();
4444
PassRegistration<RemoveUnusedArgumentsPass>();
4545
PassRegistration<SchedulePortPass>();
46+
PassRegistration<quantumCircuitPulseSchedulingPass>();
4647
PassRegistration<ClassicalOnlyDetectionPass>();
4748
}
4849

0 commit comments

Comments
 (0)