Skip to content

Commit 3f4a0a5

Browse files
quir to pulse pass (#142)
A pass to convert quir circuits to pulse sequences --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent afc02e6 commit 3f4a0a5

File tree

6 files changed

+779
-69
lines changed

6 files changed

+779
-69
lines changed

include/Conversion/QUIRToPulse/QUIRToPulse.h

Lines changed: 136 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,150 @@
1313
// that they have been altered from the originals.
1414
//
1515
//===----------------------------------------------------------------------===//
16-
//
17-
// This file declares the pass for converting QUIR to Pulse dialect
18-
//
16+
///
17+
/// This file declares the pass for converting QUIR circuits to Pulse sequences
18+
///
1919
//===----------------------------------------------------------------------===//
2020

21-
#ifndef PULSE_CONVERSION_QUIRTOPULSE_H
22-
#define PULSE_CONVERSION_QUIRTOPULSE_H
21+
#ifndef QUIRTOPULSE_CONVERSION_H
22+
#define QUIRTOPULSE_CONVERSION_H
23+
24+
#include "Dialect/OQ3/IR/OQ3Ops.h"
25+
#include "Dialect/Pulse/IR/PulseOps.h"
26+
#include "Dialect/QCS/IR/QCSOps.h"
2327

24-
#include "mlir/IR/Builders.h"
25-
#include "mlir/IR/Location.h"
2628
#include "mlir/IR/MLIRContext.h"
27-
#include "mlir/IR/Value.h"
2829
#include "mlir/Pass/Pass.h"
29-
#include "mlir/Transforms/DialectConversion.h"
30-
#include <functional>
31-
#include <memory>
30+
31+
#include <queue>
3232

3333
namespace mlir::pulse {
3434

35-
class QUIRTypeConverter : public TypeConverter {
36-
public:
37-
using TypeConverter::TypeConverter;
38-
QUIRTypeConverter();
39-
};
35+
struct QUIRToPulsePass
36+
: public PassWrapper<QUIRToPulsePass, OperationPass<ModuleOp>> {
37+
std::string WAVEFORM_CONTAINER = "";
38+
39+
// this pass can optionally receive a path to a file containing pulse waveform
40+
// container operations, which will contain pulse waveform operations that
41+
// will be passed as argument to pulse calibration sequences.
42+
QUIRToPulsePass() = default;
43+
QUIRToPulsePass(const QUIRToPulsePass &pass) : PassWrapper(pass) {}
44+
QUIRToPulsePass(std::string inWfrContainer) {
45+
WAVEFORM_CONTAINER = std::move(inWfrContainer);
46+
}
47+
48+
void runOnOperation() override;
49+
50+
llvm::StringRef getArgument() const override;
51+
llvm::StringRef getDescription() const override;
52+
53+
// optionally, one can override the path to pulse waveform container file with
54+
// this option; e.g., to write a LIT test one can invoke this pass with
55+
// --quir-to-pulse=waveform-container=<path-to-waveform-container-file>
56+
Option<std::string> waveformContainer{
57+
*this, "waveform-container",
58+
llvm::cl::desc("an MLIR file containing waveform container operations"),
59+
llvm::cl::value_desc("filename"), llvm::cl::init("")};
4060

61+
mlir::Operation *mainFuncFirstOp;
62+
63+
// convert quir circuit to pulse sequence
64+
void convertCircuitToSequence(mlir::quir::CallCircuitOp callCircuitOp,
65+
FuncOp &mainFunc, ModuleOp moduleOp);
66+
// helper datastructure for converting quir circuit to pulse sequence; these
67+
// will be reset every time convertCircuitToSequence is called and will be
68+
// used by several functions that are called within that function
69+
uint convertedSequenceOpArgIndex;
70+
std::map<uint, uint> circuitArgToConvertedSequenceArgMap;
71+
SmallVector<Value> convertedPulseSequenceOpArgs;
72+
std::vector<mlir::Attribute> convertedPulseCallSequenceOpOperandNames;
73+
74+
// process the args of the circuit op, and add corresponding args to the
75+
// converted pulse sequence op
76+
void processCircuitArgs(mlir::quir::CallCircuitOp callCircuitOp,
77+
mlir::quir::CircuitOp circuitOp,
78+
SequenceOp convertedPulseSequenceOp, FuncOp &mainFunc,
79+
mlir::OpBuilder &builder);
80+
81+
// process the args of the pulse cal sequence op corresponding to quirOp
82+
void processPulseCalArgs(mlir::Operation *quirOp,
83+
SequenceOp pulseCalSequenceOp,
84+
SmallVector<Value> &pulseCalSeqArgs,
85+
SequenceOp convertedPulseSequenceOp,
86+
FuncOp &mainFunc, mlir::OpBuilder &builder);
87+
void getQUIROpClassicalOperands(mlir::Operation *quirOp,
88+
std::queue<Value> &angleOperands,
89+
std::queue<Value> &durationOperands);
90+
void processMixFrameOpArg(std::string const &mixFrameName,
91+
std::string const &portName,
92+
SequenceOp convertedPulseSequenceOp,
93+
SmallVector<Value> &quirOpPulseCalSeqArgs,
94+
Value argumentValue, FuncOp &mainFunc,
95+
mlir::OpBuilder &builder);
96+
void processPortOpArg(std::string const &portName,
97+
SequenceOp convertedPulseSequenceOp,
98+
SmallVector<Value> &quirOpPulseCalSeqArgs,
99+
Value argumentValue, FuncOp &mainFunc,
100+
mlir::OpBuilder &builder);
101+
void processWfrOpArg(std::string const &wfrName,
102+
SequenceOp convertedPulseSequenceOp,
103+
SmallVector<Value> &quirOpPulseCalSeqArgs,
104+
Value argumentValue, FuncOp &mainFunc,
105+
mlir::OpBuilder &builder);
106+
void processAngleArg(Value nextAngleOperand,
107+
SequenceOp convertedPulseSequenceOp,
108+
SmallVector<Value> &quirOpPulseCalSeqArgs,
109+
mlir::OpBuilder &builder);
110+
void processDurationArg(Value frontDurOperand,
111+
SequenceOp convertedPulseSequenceOp,
112+
SmallVector<Value> &quirOpPulseCalSeqArgs,
113+
mlir::OpBuilder &builder);
114+
115+
// convert angle to F64
116+
mlir::Value convertAngleToF64(Operation *angleOp, mlir::OpBuilder &builder);
117+
// convert duration to I64
118+
mlir::Value convertDurationToI64(mlir::quir::CallCircuitOp callCircuitOp,
119+
Operation *durOp, uint &cnt,
120+
mlir::OpBuilder &builder, FuncOp &mainFunc);
121+
// map of the hashed location of quir angle/duration ops to their converted
122+
// pulse ops
123+
std::map<std::string, mlir::Value> classicalQUIROpLocToConvertedPulseOpMap;
124+
125+
// port name to Port_CreateOp map
126+
std::map<std::string, mlir::pulse::Port_CreateOp> openedPorts;
127+
// mixframe name to MixFrameOp map
128+
std::map<std::string, mlir::pulse::MixFrameOp> openedMixFrames;
129+
// waveform name to Waveform_CreateOp map
130+
std::map<std::string, mlir::pulse::Waveform_CreateOp> openedWfrs;
131+
// add a port to IR if it's not already added and return the Port_CreateOp
132+
mlir::pulse::Port_CreateOp addPortOpToIR(std::string const &portName,
133+
FuncOp &mainFunc,
134+
mlir::OpBuilder &builder);
135+
// add a mixframe to IR if it's not already added and return the MixFrameOp
136+
mlir::pulse::MixFrameOp addMixFrameOpToIR(std::string const &mixFrameName,
137+
std::string const &portName,
138+
FuncOp &mainFunc,
139+
mlir::OpBuilder &builder);
140+
// add a waveform to IR if it's not already added and return the
141+
// Waveform_CreateOp
142+
mlir::pulse::Waveform_CreateOp addWfrOpToIR(std::string const &wfrName,
143+
FuncOp &mainFunc,
144+
mlir::OpBuilder &builder);
145+
146+
void addCircuitToEraseList(mlir::Operation *op);
147+
void addCallCircuitToEraseList(mlir::Operation *op);
148+
void addCircuitOperandToEraseList(mlir::Operation *op);
149+
std::vector<mlir::Operation *> quirCircuitEraseList;
150+
std::vector<mlir::Operation *> quirCallCircuitEraseList;
151+
std::vector<mlir::Operation *> quirCircuitOperandEraseList;
152+
153+
// parse the waveform containers and add them to pulseNameToWaveformMap
154+
void parsePulseWaveformContainerOps(std::string &waveformContainerPath);
155+
std::map<std::string, Waveform_CreateOp> pulseNameToWaveformMap;
156+
157+
static mlir::quir::CircuitOp
158+
getCircuitOp(mlir::quir::CallCircuitOp callCircuitOp);
159+
};
41160
} // namespace mlir::pulse
42161

43-
#endif // PULSE_CONVERSION_QUIRTOPULSE_H
162+
#endif // QUIRTOPULSE_CONVERSION_H

lib/Conversion/QUIRToPulse/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
add_mlir_conversion_library(QUIRToPulse
1414

1515
LoadPulseCals.cpp
16-
17-
detail/TypeConversion.cpp
16+
QUIRToPulse.cpp
1817

1918
ADDITIONAL_HEADER_DIRS
2019
${PROJECT_SOURCE_DIR}/include/Conversion/QUIRToPulse/

lib/Conversion/QUIRToPulse/LoadPulseCals.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ void LoadPulseCalsPass::runOnOperation() {
4646

4747
// parse the default pulse calibrations
4848
if (!DEFAULT_PULSE_CALS.empty()) {
49-
LLVM_DEBUG(llvm::errs() << "parsing default pulse calibrations.\n");
49+
LLVM_DEBUG(llvm::dbgs() << "parsing default pulse calibrations.\n");
5050
if (auto err = parsePulseCalsModuleOp(DEFAULT_PULSE_CALS,
5151
defaultPulseCalsModule)) {
52-
llvm::errs() << err;
52+
llvm::dbgs() << err;
5353
return signalPassFailure();
5454
}
5555
// add sequence Ops to pulseCalsNameToSequenceMap
@@ -58,15 +58,15 @@ void LoadPulseCalsPass::runOnOperation() {
5858
pulseCalsNameToSequenceMap[sequenceName] = sequenceOp;
5959
});
6060
} else
61-
LLVM_DEBUG(llvm::errs()
61+
LLVM_DEBUG(llvm::dbgs()
6262
<< "default pulse calibrations path is not specified.\n");
6363

6464
// parse the additional pulse calibrations
6565
if (!ADDITIONAL_PULSE_CALS.empty()) {
66-
LLVM_DEBUG(llvm::errs() << "parsing additional pulse calibrations.\n");
66+
LLVM_DEBUG(llvm::dbgs() << "parsing additional pulse calibrations.\n");
6767
if (auto err = parsePulseCalsModuleOp(ADDITIONAL_PULSE_CALS,
6868
additionalPulseCalsModule)) {
69-
llvm::errs() << err;
69+
llvm::dbgs() << err;
7070
return signalPassFailure();
7171
}
7272
// add sequence Ops to pulseCalsNameToSequenceMap
@@ -75,11 +75,11 @@ void LoadPulseCalsPass::runOnOperation() {
7575
pulseCalsNameToSequenceMap[sequenceName] = sequenceOp;
7676
});
7777
} else
78-
LLVM_DEBUG(llvm::errs()
78+
LLVM_DEBUG(llvm::dbgs()
7979
<< "additional pulse calibrations path is not specified.\n");
8080

8181
// parse the user specified pulse calibrations
82-
LLVM_DEBUG(llvm::errs() << "parsing user specified pulse calibrations.\n");
82+
LLVM_DEBUG(llvm::dbgs() << "parsing user specified pulse calibrations.\n");
8383
moduleOp->walk([&](mlir::pulse::SequenceOp sequenceOp) {
8484
auto sequenceName = sequenceOp.sym_name().str();
8585
pulseCalsNameToSequenceMap[sequenceName] = sequenceOp;
@@ -113,7 +113,7 @@ void LoadPulseCalsPass::loadPulseCals(CallCircuitOp callCircuitOp,
113113
else if (auto castOp = dyn_cast<mlir::quir::ResetQubitOp>(op))
114114
loadPulseCals(castOp, callCircuitOp, funcOp);
115115
else {
116-
LLVM_DEBUG(llvm::errs() << "no pulse cal loading needed for " << op);
116+
LLVM_DEBUG(llvm::dbgs() << "no pulse cal loading needed for " << op);
117117
assert((!op->hasTrait<mlir::quir::UnitaryOp>() and
118118
!op->hasTrait<mlir::quir::CPTPOp>()) &&
119119
"unkown operation");
@@ -200,7 +200,7 @@ void LoadPulseCalsPass::loadPulseCals(MeasureOp measureOp,
200200
// check if there exists pulse calibrations for individual qubits, and if
201201
// yes, merge them and add the merged pulse sequence to the module
202202
std::vector<SequenceOp> sequenceOps;
203-
for (auto &qubit : qubits) {
203+
for (const auto &qubit : qubits) {
204204
std::string individualGateMangledName = getMangledName(gateName, qubit);
205205
assert(pulseCalsNameToSequenceMap.find(individualGateMangledName) !=
206206
pulseCalsNameToSequenceMap.end() &&
@@ -236,7 +236,7 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::BarrierOp barrierOp,
236236
// check if there exists pulse calibrations for individual qubits, and if
237237
// yes, merge them and add the merged pulse sequence to the module
238238
std::vector<SequenceOp> sequenceOps;
239-
for (auto &qubit : qubits) {
239+
for (const auto &qubit : qubits) {
240240
std::string individualGateMangledName = getMangledName(gateName, qubit);
241241
assert(pulseCalsNameToSequenceMap.find(individualGateMangledName) !=
242242
pulseCalsNameToSequenceMap.end() &&
@@ -272,7 +272,7 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::DelayOp delayOp,
272272
// check if there exists pulse calibrations for individual qubits, and if
273273
// yes, merge them and add the merged pulse sequence to the module
274274
std::vector<SequenceOp> sequenceOps;
275-
for (auto &qubit : qubits) {
275+
for (const auto &qubit : qubits) {
276276
std::string individualGateMangledName = getMangledName(gateName, qubit);
277277
assert(pulseCalsNameToSequenceMap.find(individualGateMangledName) !=
278278
pulseCalsNameToSequenceMap.end() &&
@@ -309,7 +309,7 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::ResetQubitOp resetOp,
309309
// check if there exists pulse calibrations for individual qubits, and if
310310
// yes, merge them and add the merged pulse sequence to the module
311311
std::vector<SequenceOp> sequenceOps;
312-
for (auto &qubit : qubits) {
312+
for (const auto &qubit : qubits) {
313313
std::string individualGateMangledName = getMangledName(gateName, qubit);
314314
assert(pulseCalsNameToSequenceMap.find(individualGateMangledName) !=
315315
pulseCalsNameToSequenceMap.end() &&
@@ -329,11 +329,11 @@ void LoadPulseCalsPass::addPulseCalToModule(
329329
pulseCalsAddedToIR.end()) {
330330
OpBuilder builder(funcOp.body());
331331
auto *clonedPulseCalOp = builder.clone(*sequenceOp);
332-
auto clonedPulseCalSequenceOp = dyn_cast<SequenceOp>(clonedPulseCalOp);
332+
auto clonedPulseCalSequenceOp = static_cast<SequenceOp>(clonedPulseCalOp);
333333
clonedPulseCalSequenceOp->moveBefore(funcOp);
334334
pulseCalsAddedToIR.insert(sequenceOp.sym_name().str());
335335
} else
336-
LLVM_DEBUG(llvm::errs() << "pulse cal " << sequenceOp.sym_name().str()
336+
LLVM_DEBUG(llvm::dbgs() << "pulse cal " << sequenceOp.sym_name().str()
337337
<< " is already added to IR.\n");
338338
}
339339

@@ -529,7 +529,7 @@ bool LoadPulseCalsPass::doAllSequenceOpsHaveSameDuration(
529529
std::vector<mlir::pulse::SequenceOp> &sequenceOps) {
530530
bool prevSequenceEncountered = false;
531531
uint prevSequencePulseDuration = 0;
532-
for (auto &sequenceOp : sequenceOps) {
532+
for (const auto &sequenceOp : sequenceOps) {
533533
if (!sequenceOp->hasAttrOfType<IntegerAttr>("pulse.duration"))
534534
return false;
535535

@@ -550,7 +550,7 @@ bool LoadPulseCalsPass::mergeAttributes(
550550
const std::string &attrName, std::vector<mlir::Attribute> &attrVector) {
551551

552552
bool allSequenceOpsHasAttr = true;
553-
for (auto &sequenceOp : sequenceOps) {
553+
for (const auto &sequenceOp : sequenceOps) {
554554
if (sequenceOp->hasAttr(attrName)) {
555555
auto pulseArgs = sequenceOp->getAttrOfType<ArrayAttr>(attrName);
556556
for (auto arg : pulseArgs)

0 commit comments

Comments
 (0)