Skip to content

Commit a203d55

Browse files
taminobdenialhaag
andauthored
🐛 Fix CtrlOp::getBodyUnitary() for operations with parameters (#1464)
## Description While debugging test cases in #1426, I finally figured out why `CtrlOp::getBodyUnitary()` sometimes returned an invalid `UnitaryOpInterface`. Whenever an operation has at least one parameter, the first operation in the modifier's body will be a `arith.constant` and *not* the unitary operation like it is currently assumed in the function. This PR resolves this by iterating over all operations in the body and only returning when the first `UnitaryOpInterface` operation has been found or there is none. It affects all `crx`, `crzz`, `cu`, ... operations. Required for #1426 ## Checklist: <!--- This checklist serves as a reminder of a couple of things that ensure your pull request will be merged swiftly. --> - [x] The pull request only contains commits that are focused and relevant to this change. - [x] I have added appropriate tests that cover the new/changed functionality. - [ ] I have updated the documentation to reflect these changes. - [ ] I have added entries to the changelog for any noteworthy additions, changes, fixes, or removals. - [ ] I have added migration instructions to the upgrade guide (if needed). - [x] The changes follow the project's style guidelines and introduce no new warnings. - [ ] The changes are fully tested and pass the CI checks. - [x] I have reviewed my own code changes. --------- Co-authored-by: Daniel Haag <121057143+denialhaag@users.noreply.github.com>
1 parent 6928519 commit a203d55

File tree

4 files changed

+69
-19
lines changed

4 files changed

+69
-19
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel
1111

1212
### Added
1313

14-
- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1465]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**])
14+
- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**])
1515

1616
### Changed
1717

@@ -317,6 +317,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool
317317

318318
[#1466]: https://github.com/munich-quantum-toolkit/core/pull/1466
319319
[#1465]: https://github.com/munich-quantum-toolkit/core/pull/1465
320+
[#1464]: https://github.com/munich-quantum-toolkit/core/pull/1464
320321
[#1458]: https://github.com/munich-quantum-toolkit/core/pull/1458
321322
[#1453]: https://github.com/munich-quantum-toolkit/core/pull/1453
322323
[#1447]: https://github.com/munich-quantum-toolkit/core/pull/1447

mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/QC/Builder/QCProgramBuilder.h"
1212

1313
#include "mlir/Dialect/QC/IR/QCDialect.h"
14+
#include "mlir/Dialect/Utils/Utils.h"
1415

1516
#include <cstdint>
1617
#include <functional>
@@ -30,6 +31,8 @@
3031
#include <utility>
3132
#include <variant>
3233

34+
using namespace mlir::utils;
35+
3336
namespace mlir::qc {
3437

3538
QCProgramBuilder::QCProgramBuilder(MLIRContext* context)
@@ -167,7 +170,8 @@ QCProgramBuilder& QCProgramBuilder::reset(Value qubit) {
167170
QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \
168171
const std::variant<double, Value>&(PARAM), ValueRange controls) { \
169172
checkFinalized(); \
170-
CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, PARAM); }); \
173+
auto param = variantToValue(*this, getLoc(), PARAM); \
174+
CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, param); }); \
171175
return *this; \
172176
}
173177

@@ -228,8 +232,9 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg)
228232
const std::variant<double, Value>&(PARAM), ValueRange controls, \
229233
Value target) { \
230234
checkFinalized(); \
235+
auto param = variantToValue(*this, getLoc(), PARAM); \
231236
CtrlOp::create(*this, controls, \
232-
[&] { OP_CLASS::create(*this, target, PARAM); }); \
237+
[&] { OP_CLASS::create(*this, target, param); }); \
233238
return *this; \
234239
}
235240

@@ -262,8 +267,10 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, theta)
262267
const std::variant<double, Value>&(PARAM2), ValueRange controls, \
263268
Value target) { \
264269
checkFinalized(); \
270+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
271+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
265272
CtrlOp::create(*this, controls, \
266-
[&] { OP_CLASS::create(*this, target, PARAM1, PARAM2); }); \
273+
[&] { OP_CLASS::create(*this, target, param1, param2); }); \
267274
return *this; \
268275
}
269276

@@ -298,8 +305,11 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda)
298305
const std::variant<double, Value>&(PARAM3), ValueRange controls, \
299306
Value target) { \
300307
checkFinalized(); \
308+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
309+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
310+
auto param3 = variantToValue(*this, getLoc(), PARAM3); \
301311
CtrlOp::create(*this, controls, [&] { \
302-
OP_CLASS::create(*this, target, PARAM1, PARAM2, PARAM3); \
312+
OP_CLASS::create(*this, target, param1, param2, param3); \
303313
}); \
304314
return *this; \
305315
}
@@ -355,8 +365,9 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr)
355365
const std::variant<double, Value>&(PARAM), ValueRange controls, \
356366
Value qubit0, Value qubit1) { \
357367
checkFinalized(); \
368+
auto param = variantToValue(*this, getLoc(), PARAM); \
358369
CtrlOp::create(*this, controls, \
359-
[&] { OP_CLASS::create(*this, qubit0, qubit1, PARAM); }); \
370+
[&] { OP_CLASS::create(*this, qubit0, qubit1, param); }); \
360371
return *this; \
361372
}
362373

@@ -390,8 +401,10 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta)
390401
const std::variant<double, Value>&(PARAM2), ValueRange controls, \
391402
Value qubit0, Value qubit1) { \
392403
checkFinalized(); \
404+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
405+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
393406
CtrlOp::create(*this, controls, [&] { \
394-
OP_CLASS::create(*this, qubit0, qubit1, PARAM1, PARAM2); \
407+
OP_CLASS::create(*this, qubit0, qubit1, param1, param2); \
395408
}); \
396409
return *this; \
397410
}

mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h"
1212

1313
#include "mlir/Dialect/QCO/IR/QCODialect.h"
14+
#include "mlir/Dialect/Utils/Utils.h"
1415

1516
#include <cstddef>
1617
#include <cstdint>
@@ -32,6 +33,8 @@
3233
#include <utility>
3334
#include <variant>
3435

36+
using namespace mlir::utils;
37+
3538
namespace mlir::qco {
3639

3740
QCOProgramBuilder::QCOProgramBuilder(MLIRContext* context)
@@ -210,10 +213,11 @@ Value QCOProgramBuilder::reset(Value qubit) {
210213
Value QCOProgramBuilder::c##OP_NAME( \
211214
const std::variant<double, Value>&(PARAM), Value control) { \
212215
checkFinalized(); \
216+
auto param = variantToValue(*this, getLoc(), PARAM); \
213217
const auto controlsOut = \
214218
ctrl(control, {}, \
215219
[&](ValueRange /*targets*/) -> llvm::SmallVector<Value> { \
216-
OP_NAME(PARAM); \
220+
OP_NAME(param); \
217221
return {}; \
218222
}) \
219223
.first; \
@@ -222,10 +226,11 @@ Value QCOProgramBuilder::reset(Value qubit) {
222226
ValueRange QCOProgramBuilder::mc##OP_NAME( \
223227
const std::variant<double, Value>&(PARAM), ValueRange controls) { \
224228
checkFinalized(); \
229+
auto param = variantToValue(*this, getLoc(), PARAM); \
225230
const auto controlsOut = \
226231
ctrl(controls, {}, \
227232
[&](ValueRange /*targets*/) -> llvm::SmallVector<Value> { \
228-
OP_NAME(PARAM); \
233+
OP_NAME(param); \
229234
return {}; \
230235
}) \
231236
.first; \
@@ -295,20 +300,22 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg)
295300
const std::variant<double, Value>&(PARAM), Value control, \
296301
Value target) { \
297302
checkFinalized(); \
303+
auto param = variantToValue(*this, getLoc(), PARAM); \
298304
const auto [controlsOut, targetsOut] = ctrl( \
299305
control, target, [&](ValueRange targets) -> llvm::SmallVector<Value> { \
300-
return {OP_NAME(PARAM, targets[0])}; \
306+
return {OP_NAME(param, targets[0])}; \
301307
}); \
302308
return {controlsOut[0], targetsOut[0]}; \
303309
} \
304310
std::pair<ValueRange, Value> QCOProgramBuilder::mc##OP_NAME( \
305311
const std::variant<double, Value>&(PARAM), ValueRange controls, \
306312
Value target) { \
307313
checkFinalized(); \
314+
auto param = variantToValue(*this, getLoc(), PARAM); \
308315
const auto [controlsOut, targetsOut] = \
309316
ctrl(controls, target, \
310317
[&](ValueRange targets) -> llvm::SmallVector<Value> { \
311-
return {OP_NAME(PARAM, targets[0])}; \
318+
return {OP_NAME(param, targets[0])}; \
312319
}); \
313320
return {controlsOut, targetsOut[0]}; \
314321
}
@@ -337,9 +344,11 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, phi)
337344
const std::variant<double, Value>&(PARAM2), Value control, \
338345
Value target) { \
339346
checkFinalized(); \
347+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
348+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
340349
const auto [controlsOut, targetsOut] = ctrl( \
341350
control, target, [&](ValueRange targets) -> llvm::SmallVector<Value> { \
342-
return {OP_NAME(PARAM1, PARAM2, targets[0])}; \
351+
return {OP_NAME(param1, param2, targets[0])}; \
343352
}); \
344353
return {controlsOut[0], targetsOut[0]}; \
345354
} \
@@ -348,10 +357,12 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, phi)
348357
const std::variant<double, Value>&(PARAM2), ValueRange controls, \
349358
Value target) { \
350359
checkFinalized(); \
360+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
361+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
351362
const auto [controlsOut, targetsOut] = \
352363
ctrl(controls, target, \
353364
[&](ValueRange targets) -> llvm::SmallVector<Value> { \
354-
return {OP_NAME(PARAM1, PARAM2, targets[0])}; \
365+
return {OP_NAME(param1, param2, targets[0])}; \
355366
}); \
356367
return {controlsOut, targetsOut[0]}; \
357368
}
@@ -381,9 +392,12 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda)
381392
const std::variant<double, Value>&(PARAM3), Value control, \
382393
Value target) { \
383394
checkFinalized(); \
395+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
396+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
397+
auto param3 = variantToValue(*this, getLoc(), PARAM3); \
384398
const auto [controlsOut, targetsOut] = ctrl( \
385399
control, target, [&](ValueRange targets) -> llvm::SmallVector<Value> { \
386-
return {OP_NAME(PARAM1, PARAM2, PARAM3, targets[0])}; \
400+
return {OP_NAME(param1, param2, param3, targets[0])}; \
387401
}); \
388402
return {controlsOut[0], targetsOut[0]}; \
389403
} \
@@ -393,10 +407,13 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda)
393407
const std::variant<double, Value>&(PARAM3), ValueRange controls, \
394408
Value target) { \
395409
checkFinalized(); \
410+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
411+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
412+
auto param3 = variantToValue(*this, getLoc(), PARAM3); \
396413
const auto [controlsOut, targetsOut] = \
397414
ctrl(controls, target, \
398415
[&](ValueRange targets) -> llvm::SmallVector<Value> { \
399-
return {OP_NAME(PARAM1, PARAM2, PARAM3, targets[0])}; \
416+
return {OP_NAME(param1, param2, param3, targets[0])}; \
400417
}); \
401418
return {controlsOut, targetsOut[0]}; \
402419
}
@@ -466,10 +483,11 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr)
466483
const std::variant<double, Value>&(PARAM), Value control, Value qubit0, \
467484
Value qubit1) { \
468485
checkFinalized(); \
486+
auto param = variantToValue(*this, getLoc(), PARAM); \
469487
const auto [controlsOut, targetsOut] = \
470488
ctrl(control, {qubit0, qubit1}, \
471489
[&](ValueRange targets) -> llvm::SmallVector<Value> { \
472-
auto [q0, q1] = OP_NAME(PARAM, targets[0], targets[1]); \
490+
auto [q0, q1] = OP_NAME(param, targets[0], targets[1]); \
473491
return {q0, q1}; \
474492
}); \
475493
return {controlsOut[0], {targetsOut[0], targetsOut[1]}}; \
@@ -479,10 +497,11 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr)
479497
const std::variant<double, Value>&(PARAM), ValueRange controls, \
480498
Value qubit0, Value qubit1) { \
481499
checkFinalized(); \
500+
auto param = variantToValue(*this, getLoc(), PARAM); \
482501
const auto [controlsOut, targetsOut] = \
483502
ctrl(controls, {qubit0, qubit1}, \
484503
[&](ValueRange targets) -> llvm::SmallVector<Value> { \
485-
auto [q0, q1] = OP_NAME(PARAM, targets[0], targets[1]); \
504+
auto [q0, q1] = OP_NAME(param, targets[0], targets[1]); \
486505
return {q0, q1}; \
487506
}); \
488507
return {controlsOut, {targetsOut[0], targetsOut[1]}}; \
@@ -515,11 +534,13 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta)
515534
const std::variant<double, Value>&(PARAM2), Value control, Value qubit0, \
516535
Value qubit1) { \
517536
checkFinalized(); \
537+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
538+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
518539
const auto [controlsOut, targetsOut] = \
519540
ctrl(control, {qubit0, qubit1}, \
520541
[&](ValueRange targets) -> llvm::SmallVector<Value> { \
521542
auto [q0, q1] = \
522-
OP_NAME(PARAM1, PARAM2, targets[0], targets[1]); \
543+
OP_NAME(param1, param2, targets[0], targets[1]); \
523544
return {q0, q1}; \
524545
}); \
525546
return {controlsOut[0], {targetsOut[0], targetsOut[1]}}; \
@@ -530,11 +551,13 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta)
530551
const std::variant<double, Value>&(PARAM2), ValueRange controls, \
531552
Value qubit0, Value qubit1) { \
532553
checkFinalized(); \
554+
auto param1 = variantToValue(*this, getLoc(), PARAM1); \
555+
auto param2 = variantToValue(*this, getLoc(), PARAM2); \
533556
const auto [controlsOut, targetsOut] = \
534557
ctrl(controls, {qubit0, qubit1}, \
535558
[&](ValueRange targets) -> llvm::SmallVector<Value> { \
536559
auto [q0, q1] = \
537-
OP_NAME(PARAM1, PARAM2, targets[0], targets[1]); \
560+
OP_NAME(param1, param2, targets[0], targets[1]); \
538561
return {q0, q1}; \
539562
}); \
540563
return {controlsOut, {targetsOut[0], targetsOut[1]}}; \

mlir/unittests/Dialect/QCO/IR/Modifiers/test_qco_ctrl.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,16 @@ TEST_F(QCOCtrlOpTest, ParserErrors) {
221221
.get(),
222222
nullptr);
223223
}
224+
225+
TEST_F(QCOCtrlOpTest, bodyUnitaryWithParameter) {
226+
auto reg = builder.allocQubitRegister(2);
227+
builder.crx(1.0, reg[0], reg[1]);
228+
auto ctrlOp = cast<CtrlOp>(builder.getBlock()->getOperations().back());
229+
module = builder.finalize();
230+
231+
auto bodyUnitary = ctrlOp.getBodyUnitary();
232+
// Test if a valid unitary operation is returned
233+
ASSERT_TRUE(bodyUnitary);
234+
// Ensure it contains the correct operation type
235+
EXPECT_EQ(bodyUnitary.getBaseSymbol(), "rx");
236+
}

0 commit comments

Comments
 (0)