diff --git a/CHANGELOG.md b/CHANGELOG.md index 98bb0296f..fa7dc54f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added -- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**]) +- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**]) ### Changed @@ -308,6 +308,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1437]: https://github.com/munich-quantum-toolkit/core/pull/1437 +[#1436]: https://github.com/munich-quantum-toolkit/core/pull/1436 [#1430]: https://github.com/munich-quantum-toolkit/core/pull/1430 [#1428]: https://github.com/munich-quantum-toolkit/core/pull/1428 [#1415]: https://github.com/munich-quantum-toolkit/core/pull/1415 diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 2d8917c33..b201e8cea 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -988,21 +989,21 @@ class QCOProgramBuilder final : public OpBuilder { * @par Example: * ```c++ * {controls_out, targets_out} = - * builder.ctrl(q0_in, q1_in, [&](ValueRange targets) { - * auto q1_res = builder.x(targets[0]); - * return {q1_res}; + * builder.ctrl(q0_in, q1_in, + * [&](ValueRange targets) -> llvm::SmallVector { + * return {builder.x(targets[0])}; * }); * ``` * ```mlir - * %controls_out, %targets_out = qco.ctrl(%q0_in) %q1_in { - * %q1_res = qco.x %q1_in : !qco.qubit -> !qco.qubit + * %controls_out, %targets_out = qco.ctrl(%q0_in) targets(%t = %q1_in) { + * %q1_res = qco.x %t : !qco.qubit -> !qco.qubit * qco.yield %q1_res * } : ({!qco.qubit}, {!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit}) * ``` */ std::pair ctrl(ValueRange controls, ValueRange targets, - const std::function& body); + llvm::function_ref(ValueRange)> body); //===--------------------------------------------------------------------===// // Deallocation diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index a85943941..aef50a319 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1025,10 +1025,10 @@ def YieldOp : QCOOp<"yield", traits = [Terminator]> { Example: ```mlir - %ctrl_q_out, %tgt_q_out = qco.ctrl(%ctrl_q_in) %tgt_q_in { - %tgt_q_res = qco.h %tgt_q_in : !qco.qubit -> !qco.qubit - qco.yield %tgt_q_res : !qco.qubit - } : ({!qco.qubit}, {!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit}) + %res_ctrl, %res_tgt:2 = qco.ctrl(%ctrl) targets(%a0 = %q0, %a1 = %q1) { + %a0_1, %a1_1 = qco.swap %a0, %a1 : !qco.qubit, !qco.qubit -> !qco.qubit, !qco.qubit + qco.yield %a0_1, %a1_1 + } : ({!qco.qubit}, {!qco.qubit, !qco.qubit}) -> ({!qco.qubit}, {!qco.qubit, !qco.qubit}) ``` }]; @@ -1057,10 +1057,10 @@ def CtrlOp : QCOOp<"ctrl", traits = Example: ```mlir - %ctrl_q_out, %tgt_q_out = qco.ctrl(%ctrl_q_in) %tgt_q_in { - %tgt_q_res = qco.h %tgt_q_in : !qco.qubit -> !qco.qubit - qco.yield %tgt_q_res : !qco.qubit - } : ({!qco.qubit}, {!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit}) + %res_ctrl, %res_tgt:2 = qco.ctrl(%ctrl) targets(%a0 = %q0, %a1 = %q1) { + %a0_1, %a1_1 = qco.swap %a0, %a1 : !qco.qubit, !qco.qubit -> !qco.qubit, !qco.qubit + qco.yield %a0_1, %a1_1 + } : ({!qco.qubit}, {!qco.qubit, !qco.qubit}) -> ({!qco.qubit}, {!qco.qubit, !qco.qubit}) ``` }]; @@ -1069,8 +1069,10 @@ def CtrlOp : QCOOp<"ctrl", traits = let results = (outs Variadic:$controls_out, Variadic:$targets_out); let regions = (region SizedRegion<1>:$region); let assemblyFormat = [{ - `(` $controls_in `)` $targets_in - $region attr-dict `:` + `(` $controls_in `)` + `targets` + custom($region, $targets_in) + attr-dict `:` `(` `{` type($controls_in) `}` ( `,` `{` type($targets_in)^ `}` )? `)` `->` `(` `{` type($controls_out) `}` ( `,` `{` type($targets_out)^ `}` )? `)` @@ -1099,7 +1101,7 @@ def CtrlOp : QCOOp<"ctrl", traits = build($_builder, $_state, controls.getTypes(), targets.getTypes(), controls, targets); }]>, OpBuilder<(ins "ValueRange":$controls, "ValueRange":$targets, "UnitaryOpInterface":$bodyUnitary)>, - OpBuilder<(ins "ValueRange":$controls, "ValueRange":$targets, "const std::function&":$bodyBuilder)> + OpBuilder<(ins "ValueRange":$controls, "ValueRange":$targets, "llvm::function_ref(ValueRange)>":$bodyBuilder)> ]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 5685258f6..501432fc6 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -774,12 +774,37 @@ struct ConvertQCOCtrlOp final : OpConversionPattern { const auto& qcControls = adaptor.getControlsIn(); // Create qc.ctrl operation - auto qcoOp = rewriter.create(op.getLoc(), qcControls); + auto qcOp = qc::CtrlOp::create(rewriter, op.getLoc(), qcControls); // Clone body region from QCO to QC - auto& dstRegion = qcoOp.getRegion(); + auto& dstRegion = qcOp.getRegion(); rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); + auto& entryBlock = dstRegion.front(); + const auto numArgs = entryBlock.getNumArguments(); + if (adaptor.getTargetsIn().size() != numArgs) { + return op.emitOpError() << "qco.ctrl: entry block args (" << numArgs + << ") must match number of target operands (" + << adaptor.getTargetsIn().size() << ")"; + } + + // Remove all block arguments in the cloned region + rewriter.modifyOpInPlace(qcOp, [&] { + // 1. Replace uses (Must be done BEFORE erasing) + // We iterate 0..N using indices since the block args are still stable + // here. + for (auto i = 0UL; i < numArgs; ++i) { + entryBlock.getArgument(i).replaceAllUsesWith(adaptor.getTargetsIn()[i]); + } + + // 2. Erase all block arguments + // Now that they have no uses, we can safely wipe them. + // We use a bulk erase for efficiency (start index 0, count N). + if (numArgs > 0) { + entryBlock.eraseArguments(0, numArgs); + } + }); + // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 9b1fc5b7b..19fb7df80 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1106,7 +1106,7 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { // Create qco.ctrl auto qcoOp = - rewriter.create(op.getLoc(), qcoControls, qcoTargets); + qco::CtrlOp::create(rewriter, op.getLoc(), qcoControls, qcoTargets); // Update the state map if this is a top-level CtrlOp // Nested CtrlOps are managed via the targetsIn and targetsOut maps @@ -1124,12 +1124,27 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { // Update modifier information state.inCtrlOp++; - state.targetsIn.try_emplace(state.inCtrlOp, qcoTargets); // Clone body region from QC to QCO auto& dstRegion = qcoOp.getRegion(); rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); + // Create block arguments for target qubits and store them in + // `state.targetsIn`. + auto& entryBlock = dstRegion.front(); + assert(entryBlock.getNumArguments() == 0 && + "QC ctrl region unexpectedly has entry block arguments"); + SmallVector qcoTargetAliases; + qcoTargetAliases.reserve(numTargets); + const auto qubitType = qco::QubitType::get(qcoOp.getContext()); + const auto opLoc = op.getLoc(); + rewriter.modifyOpInPlace(qcoOp, [&] { + for (auto i = 0UL; i < numTargets; i++) { + qcoTargetAliases.emplace_back(entryBlock.addArgument(qubitType, opLoc)); + } + }); + state.targetsIn[state.inCtrlOp] = std::move(qcoTargetAliases); + rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 46d7fbf34..07a0f41a2 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -14,8 +14,8 @@ #include #include -#include #include +#include #include #include #include @@ -209,20 +209,24 @@ Value QCOProgramBuilder::reset(Value qubit) { const std::variant&(PARAM), Value control) { \ checkFinalized(); \ const auto controlsOut = \ - ctrl(control, {}, [&](ValueRange /*targets*/) -> ValueRange { \ - OP_CLASS::create(*this, loc, PARAM); \ - return {}; \ - }).first; \ + ctrl(control, {}, \ + [&](ValueRange /*targets*/) -> llvm::SmallVector { \ + OP_NAME(PARAM); \ + return {}; \ + }) \ + .first; \ return controlsOut[0]; \ } \ ValueRange QCOProgramBuilder::mc##OP_NAME( \ const std::variant&(PARAM), ValueRange controls) { \ checkFinalized(); \ const auto controlsOut = \ - ctrl(controls, {}, [&](ValueRange /*targets*/) -> ValueRange { \ - OP_CLASS::create(*this, loc, PARAM); \ - return {}; \ - }).first; \ + ctrl(controls, {}, \ + [&](ValueRange /*targets*/) -> llvm::SmallVector { \ + OP_NAME(PARAM); \ + return {}; \ + }) \ + .first; \ return controlsOut; \ } @@ -243,10 +247,9 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) std::pair QCOProgramBuilder::c##OP_NAME(Value control, \ Value target) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = \ - ctrl(control, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0]); \ - return op->getResults(); \ + const auto [controlsOut, targetsOut] = ctrl( \ + control, target, [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(targets[0])}; \ }); \ return {controlsOut[0], targetsOut[0]}; \ } \ @@ -254,10 +257,10 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) ValueRange controls, Value target) { \ checkFinalized(); \ const auto [controlsOut, targetsOut] = \ - ctrl(controls, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0]); \ - return op->getResults(); \ - }); \ + ctrl(controls, target, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(targets[0])}; \ + }); \ return {controlsOut, targetsOut[0]}; \ } @@ -290,10 +293,9 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) const std::variant&(PARAM), Value control, \ Value target) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = \ - ctrl(control, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0], PARAM); \ - return op->getResults(); \ + const auto [controlsOut, targetsOut] = ctrl( \ + control, target, [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(PARAM, targets[0])}; \ }); \ return {controlsOut[0], targetsOut[0]}; \ } \ @@ -302,10 +304,10 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) Value target) { \ checkFinalized(); \ const auto [controlsOut, targetsOut] = \ - ctrl(controls, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0], PARAM); \ - return op->getResults(); \ - }); \ + ctrl(controls, target, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(PARAM, targets[0])}; \ + }); \ return {controlsOut, targetsOut[0]}; \ } @@ -333,11 +335,9 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, phi) const std::variant&(PARAM2), Value control, \ Value target) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = \ - ctrl(control, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = \ - OP_CLASS::create(*this, loc, targets[0], PARAM1, PARAM2); \ - return op->getResults(); \ + const auto [controlsOut, targetsOut] = ctrl( \ + control, target, [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(PARAM1, PARAM2, targets[0])}; \ }); \ return {controlsOut[0], targetsOut[0]}; \ } \ @@ -347,11 +347,10 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, phi) Value target) { \ checkFinalized(); \ const auto [controlsOut, targetsOut] = \ - ctrl(controls, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = \ - OP_CLASS::create(*this, loc, targets[0], PARAM1, PARAM2); \ - return op->getResults(); \ - }); \ + ctrl(controls, target, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(PARAM1, PARAM2, targets[0])}; \ + }); \ return {controlsOut, targetsOut[0]}; \ } @@ -380,11 +379,9 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) const std::variant&(PARAM3), Value control, \ Value target) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = \ - ctrl(control, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0], PARAM1, \ - PARAM2, PARAM3); \ - return op->getResults(); \ + const auto [controlsOut, targetsOut] = ctrl( \ + control, target, [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(PARAM1, PARAM2, PARAM3, targets[0])}; \ }); \ return {controlsOut[0], targetsOut[0]}; \ } \ @@ -395,11 +392,10 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) Value target) { \ checkFinalized(); \ const auto [controlsOut, targetsOut] = \ - ctrl(controls, target, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0], PARAM1, \ - PARAM2, PARAM3); \ - return op->getResults(); \ - }); \ + ctrl(controls, target, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + return {OP_NAME(PARAM1, PARAM2, PARAM3, targets[0])}; \ + }); \ return {controlsOut, targetsOut[0]}; \ } @@ -423,24 +419,24 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) std::pair> QCOProgramBuilder::c##OP_NAME( \ Value control, Value qubit0, Value qubit1) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = ctrl( \ - control, {qubit0, qubit1}, [&](ValueRange targets) -> ValueRange { \ - const auto op = \ - OP_CLASS::create(*this, loc, targets[0], targets[1]); \ - return op->getResults(); \ - }); \ + const auto [controlsOut, targetsOut] = \ + ctrl(control, {qubit0, qubit1}, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + auto [q0, q1] = OP_NAME(targets[0], targets[1]); \ + return {q0, q1}; \ + }); \ return {controlsOut[0], {targetsOut[0], targetsOut[1]}}; \ } \ std::pair> \ QCOProgramBuilder::mc##OP_NAME(ValueRange controls, Value qubit0, \ Value qubit1) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = ctrl( \ - controls, {qubit0, qubit1}, [&](ValueRange targets) -> ValueRange { \ - const auto op = \ - OP_CLASS::create(*this, loc, targets[0], targets[1]); \ - return op->getResults(); \ - }); \ + const auto [controlsOut, targetsOut] = \ + ctrl(controls, {qubit0, qubit1}, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + auto [q0, q1] = OP_NAME(targets[0], targets[1]); \ + return {q0, q1}; \ + }); \ return {controlsOut, {targetsOut[0], targetsOut[1]}}; \ } @@ -468,12 +464,12 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) const std::variant&(PARAM), Value control, Value qubit0, \ Value qubit1) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = ctrl( \ - control, {qubit0, qubit1}, [&](ValueRange targets) -> ValueRange { \ - const auto op = \ - OP_CLASS::create(*this, loc, targets[0], targets[1], PARAM); \ - return op->getResults(); \ - }); \ + const auto [controlsOut, targetsOut] = \ + ctrl(control, {qubit0, qubit1}, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + auto [q0, q1] = OP_NAME(PARAM, targets[0], targets[1]); \ + return {q0, q1}; \ + }); \ return {controlsOut[0], {targetsOut[0], targetsOut[1]}}; \ } \ std::pair> \ @@ -481,12 +477,12 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) const std::variant&(PARAM), ValueRange controls, \ Value qubit0, Value qubit1) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = ctrl( \ - controls, {qubit0, qubit1}, [&](ValueRange targets) -> ValueRange { \ - const auto op = \ - OP_CLASS::create(*this, loc, targets[0], targets[1], PARAM); \ - return op->getResults(); \ - }); \ + const auto [controlsOut, targetsOut] = \ + ctrl(controls, {qubit0, qubit1}, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + auto [q0, q1] = OP_NAME(PARAM, targets[0], targets[1]); \ + return {q0, q1}; \ + }); \ return {controlsOut, {targetsOut[0], targetsOut[1]}}; \ } @@ -517,12 +513,13 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) const std::variant&(PARAM2), Value control, Value qubit0, \ Value qubit1) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = ctrl( \ - control, {qubit0, qubit1}, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0], targets[1], \ - PARAM1, PARAM2); \ - return op->getResults(); \ - }); \ + const auto [controlsOut, targetsOut] = \ + ctrl(control, {qubit0, qubit1}, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + auto [q0, q1] = \ + OP_NAME(PARAM1, PARAM2, targets[0], targets[1]); \ + return {q0, q1}; \ + }); \ return {controlsOut[0], {targetsOut[0], targetsOut[1]}}; \ } \ std::pair> \ @@ -531,12 +528,13 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) const std::variant&(PARAM2), ValueRange controls, \ Value qubit0, Value qubit1) { \ checkFinalized(); \ - const auto [controlsOut, targetsOut] = ctrl( \ - controls, {qubit0, qubit1}, [&](ValueRange targets) -> ValueRange { \ - const auto op = OP_CLASS::create(*this, loc, targets[0], targets[1], \ - PARAM1, PARAM2); \ - return op->getResults(); \ - }); \ + const auto [controlsOut, targetsOut] = \ + ctrl(controls, {qubit0, qubit1}, \ + [&](ValueRange targets) -> llvm::SmallVector { \ + auto [q0, q1] = \ + OP_NAME(PARAM1, PARAM2, targets[0], targets[1]); \ + return {q0, q1}; \ + }); \ return {controlsOut, {targetsOut[0], targetsOut[1]}}; \ } @@ -562,12 +560,27 @@ ValueRange QCOProgramBuilder::barrier(ValueRange qubits) { // Modifiers //===----------------------------------------------------------------------===// -std::pair -QCOProgramBuilder::ctrl(ValueRange controls, ValueRange targets, - const std::function& body) { +std::pair QCOProgramBuilder::ctrl( + ValueRange controls, ValueRange targets, + llvm::function_ref(ValueRange)> body) { checkFinalized(); - auto ctrlOp = CtrlOp::create(*this, loc, controls, targets, body); + auto ctrlOp = CtrlOp::create(*this, loc, controls, targets); + auto& block = ctrlOp.getBodyRegion().emplaceBlock(); + const auto qubitType = QubitType::get(getContext()); + for (const auto target : targets) { + const auto arg = block.addArgument(qubitType, loc); + updateQubitTracking(target, arg); + } + const InsertionGuard guard(*this); + setInsertionPointToStart(&block); + const auto innerTargetsOut = body(block.getArguments()); + YieldOp::create(*this, loc, innerTargetsOut); + + if (innerTargetsOut.size() != targets.size()) { + llvm::reportFatalUsageError( + "Ctrl body must return exactly one output qubit per target"); + } // Update tracking const auto& controlsOut = ctrlOp.getControlsOut(); @@ -575,7 +588,8 @@ QCOProgramBuilder::ctrl(ValueRange controls, ValueRange targets, updateQubitTracking(control, controlOut); } const auto& targetsOut = ctrlOp.getTargetsOut(); - for (const auto& [target, targetOut] : llvm::zip(targets, targetsOut)) { + for (const auto& [target, targetOut] : + llvm::zip(innerTargetsOut, targetsOut)) { updateQubitTracking(target, targetOut); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index be7095098..5d49b3a44 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -11,12 +11,14 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include -#include #include +#include +#include #include #include #include #include +#include #include #include #include @@ -49,11 +51,8 @@ struct MergeNestedCtrl final : OpRewritePattern { } // Merge controls - SmallVector newControls(op.getControlsIn()); - for (const auto control : bodyCtrlOp.getControlsIn()) { - newControls.push_back(control); - } - + const auto newControls = llvm::to_vector( + llvm::concat(op.getControlsIn(), bodyCtrlOp.getControlsIn())); rewriter.replaceOpWithNewOp(op, newControls, op.getTargetsIn(), bodyCtrlOp.getBodyUnitary()); @@ -103,18 +102,15 @@ struct CtrlInlineGPhase final : OpRewritePattern { return failure(); } - SmallVector newControls(op.getControlsIn()); - const auto newTarget = newControls.back(); - newControls.pop_back(); - auto ctrlOp = CtrlOp::create( - rewriter, op.getLoc(), newControls, newTarget, [&](ValueRange targets) { + const auto controls = op.getControlsIn(); + rewriter.replaceOpWithNewOp( + op, controls.drop_back(), controls.back(), + [&](ValueRange targets) -> llvm::SmallVector { auto pOp = POp::create(rewriter, op.getLoc(), targets[0], gPhaseOp.getTheta()); - return pOp.getQubitOut(); + return {pOp.getQubitOut()}; }); - rewriter.replaceOp(op, ctrlOp.getResults()); - return success(); } }; @@ -129,21 +125,15 @@ struct CtrlInlineId final : OpRewritePattern { PatternRewriter& rewriter) const override { // Require at least one positive control // Trivial case is handled by RemoveTrivialCtrl - if (op.getNumControls() == 0) { - return failure(); - } - - if (!llvm::isa(op.getBodyUnitary().getOperation())) { + if (op.getNumControls() == 0 || + !llvm::isa(op.getBodyUnitary().getOperation())) { return failure(); } - auto idOp = rewriter.create(op.getLoc(), op.getTargetsIn().front()); + auto idOp = IdOp::create(rewriter, op.getLoc(), op.getTargetsIn().front()); - SmallVector newOperands; - newOperands.reserve(op.getNumControls() + 1); - newOperands.append(op.getControlsIn().begin(), op.getControlsIn().end()); - newOperands.push_back(idOp.getQubitOut()); - rewriter.replaceOp(op, newOperands); + rewriter.replaceOp(op, llvm::to_vector(llvm::concat( + op.getControlsIn(), idOp->getResults()))); return success(); } @@ -251,24 +241,36 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, controls, targets); auto& block = odsState.regions.front()->emplaceBlock(); - // Move the unitary op into the block + // Create block arguments and map targets to them + IRMapping mapping; + const auto qubitType = QubitType::get(odsBuilder.getContext()); + for (const auto target : targets) { + mapping.map(target, block.addArgument(qubitType, odsState.location)); + } + + // Clone the operation using the mapping const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - auto* op = odsBuilder.clone(*bodyUnitary.getOperation()); - odsBuilder.create(odsState.location, op->getResults()); + auto* op = odsBuilder.clone(*bodyUnitary.getOperation(), mapping); + YieldOp::create(odsBuilder, odsState.location, op->getResults()); } -void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, - ValueRange controls, ValueRange targets, - const std::function& bodyBuilder) { +void CtrlOp::build( + OpBuilder& odsBuilder, OperationState& odsState, ValueRange controls, + ValueRange targets, + llvm::function_ref(ValueRange)> bodyBuilder) { build(odsBuilder, odsState, controls, targets); auto& block = odsState.regions.front()->emplaceBlock(); - // Move the unitary op into the block + const auto qubitType = QubitType::get(odsBuilder.getContext()); + for (size_t i = 0; i < targets.size(); ++i) { + block.addArgument(qubitType, odsState.location); + } + const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - auto targetsOut = bodyBuilder(targets); - odsBuilder.create(odsState.location, targetsOut); + YieldOp::create(odsBuilder, odsState.location, + bodyBuilder(block.getArguments())); } LogicalResult CtrlOp::verify() { @@ -276,6 +278,18 @@ LogicalResult CtrlOp::verify() { if (block.getOperations().size() != 2) { return emitOpError("body region must have exactly two operations"); } + const auto numTargets = getNumTargets(); + if (block.getArguments().size() != numTargets) { + return emitOpError( + "number of block arguments must match the number of targets"); + } + const auto qubitType = QubitType::get(getContext()); + for (size_t i = 0; i < numTargets; ++i) { + if (block.getArgument(i).getType() != qubitType) { + return emitOpError("block argument type at index ") + << i << " does not match target type"; + } + } if (!llvm::isa(block.front())) { return emitOpError( "first operation in body region must be a unitary operation"); @@ -284,10 +298,10 @@ LogicalResult CtrlOp::verify() { return emitOpError( "second operation in body region must be a yield operation"); } - if (block.back().getNumOperands() != getNumTargets()) { + if (const auto numYieldOperands = block.back().getNumOperands(); + numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") - << getNumTargets() << " values, but found " - << block.back().getNumOperands(); + << numTargets << " values, but found " << numYieldOperands; } SmallPtrSet uniqueQubitsIn; @@ -296,13 +310,34 @@ LogicalResult CtrlOp::verify() { return emitOpError("duplicate control qubit found"); } } + for (const auto& target : getTargetsIn()) { + if (!uniqueQubitsIn.insert(target).second) { + return emitOpError("duplicate target qubit found"); + } + } + auto bodyUnitary = getBodyUnitary(); + if (bodyUnitary.getNumQubits() != numTargets) { + return emitOpError("body unitary must operate on exactly ") + << numTargets << " target qubits, but found " + << bodyUnitary.getNumQubits(); + } const auto numQubits = bodyUnitary.getNumQubits(); for (size_t i = 0; i < numQubits; i++) { - if (!uniqueQubitsIn.insert(bodyUnitary.getInputQubit(i)).second) { - return emitOpError("duplicate qubit found"); + if (bodyUnitary.getInputQubit(i) != block.getArgument(i)) { + return emitOpError("body unitary must use target alias block argument ") + << i << " (and not the original target operand)"; + } + } + + // Also require yield to forward the unitary's outputs in-order. + for (size_t i = 0; i < numTargets; ++i) { + if (block.back().getOperand(i) != bodyUnitary.getOutputQubit(i)) { + return emitOpError("yield operand ") + << i << " must be the body unitary output qubit " << i; } } + SmallPtrSet uniqueQubitsOut; for (const auto& control : getControlsOut()) { if (!uniqueQubitsOut.insert(control).second) { diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index 54e58d67e..111fa4a68 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -10,6 +10,14 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated +#include +#include +#include +#include +#include +#include +#include + // The following headers are needed for some template instantiations. // IWYU pragma: begin_keep #include @@ -21,6 +29,90 @@ using namespace mlir; using namespace mlir::qco; +//===----------------------------------------------------------------------===// +// Custom Parsers +//===----------------------------------------------------------------------===// + +static ParseResult +parseTargetAliasing(OpAsmParser& parser, Region& region, + SmallVectorImpl& operands) { + // 1. Parse the opening parenthesis + if (parser.parseLParen()) { + return failure(); + } + + // Temporary storage for block arguments we are about to create + SmallVector blockArgs; + + // 2. Prepare to parse the list + if (failed(parser.parseOptionalRParen())) { + do { + OpAsmParser::Argument newArg; // The "new" variable name + OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable + + // Parse "%new" + if (parser.parseArgument(newArg)) { + return failure(); + } + + // Parse "=" + if (parser.parseEqual()) { + return failure(); + } + + // Parse "%old" + if (parser.parseOperand(oldOperand)) { + return failure(); + } + operands.push_back(oldOperand); + + // Hard-code QubitType since targets in qco.ctrl are always qubits. + // This avoids double-binding type($targets_in) in the assembly format + // while keeping the parser simple and the assembly format clean. + newArg.type = QubitType::get(parser.getBuilder().getContext()); + blockArgs.push_back(newArg); + + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen()) { + return failure(); + } + } + + // 4. Parse the Region + // We explicitly pass the blockArgs we just parsed so they become the entry + // block! + if (parser.parseRegion(region, blockArgs)) { + return failure(); + } + + return success(); +} + +static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, + Region& region, OperandRange targetsIn) { + printer << "("; + if (region.empty()) { + printer << ") "; + printer.printRegion(region, false); + return; + } + Block& entryBlock = region.front(); + + const auto numTargets = targetsIn.size(); + for (unsigned i = 0; i < numTargets; ++i) { + if (i > 0) { + printer << ", "; + } + printer.printOperand(entryBlock.getArgument(i)); + printer << " = "; + printer.printOperand(targetsIn[i]); + } + printer << ") "; + + printer.printRegion(region, false); +} + //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// @@ -32,11 +124,13 @@ void QCODialect::initialize() { addTypes< #define GET_TYPEDEF_LIST #include "mlir/Dialect/QCO/IR/QCOOpsTypes.cpp.inc" + >(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/QCO/IR/QCOOps.cpp.inc" + >(); } diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index 50f95e4cf..bf3838d9e 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -7,7 +7,9 @@ # Licensed under the MIT License add_subdirectory(pipeline) +add_subdirectory(Dialect) add_custom_target(mqt-core-mlir-unittests) -add_dependencies(mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test) +add_dependencies(mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test + mqt-core-mlir-dialect-qco-ir-modifiers-test) diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt new file mode 100644 index 000000000..6ced278d0 --- /dev/null +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(QCO) diff --git a/mlir/unittests/Dialect/QCO/CMakeLists.txt b/mlir/unittests/Dialect/QCO/CMakeLists.txt new file mode 100644 index 000000000..b181a84fe --- /dev/null +++ b/mlir/unittests/Dialect/QCO/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(IR) diff --git a/mlir/unittests/Dialect/QCO/IR/CMakeLists.txt b/mlir/unittests/Dialect/QCO/IR/CMakeLists.txt new file mode 100644 index 000000000..0a87c6bc8 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/IR/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(Modifiers) diff --git a/mlir/unittests/Dialect/QCO/IR/Modifiers/CMakeLists.txt b/mlir/unittests/Dialect/QCO/IR/Modifiers/CMakeLists.txt new file mode 100644 index 000000000..657f49c02 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/IR/Modifiers/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_executable(mqt-core-mlir-dialect-qco-ir-modifiers-test test_qco_ctrl.cpp) + +target_link_libraries( + mqt-core-mlir-dialect-qco-ir-modifiers-test + PRIVATE GTest::gtest_main + MQT::CoreIR + MLIRQCOProgramBuilder + MLIRFuncDialect + MLIRIR + MLIRParser + MLIRSupport + LLVMSupport) + +gtest_discover_tests(mqt-core-mlir-dialect-qco-ir-modifiers-test) diff --git a/mlir/unittests/Dialect/QCO/IR/Modifiers/test_qco_ctrl.cpp b/mlir/unittests/Dialect/QCO/IR/Modifiers/test_qco_ctrl.cpp new file mode 100644 index 000000000..91ed0d868 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/IR/Modifiers/test_qco_ctrl.cpp @@ -0,0 +1,223 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::qco; + +class QCOCtrlOpTest : public ::testing::Test { +protected: + MLIRContext context; + QCOProgramBuilder builder; + OwningOpRef module; + + QCOCtrlOpTest() : builder(&context) {} + + void SetUp() override { + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + // Setup Builder + builder.initialize(); + } + + OwningOpRef testParse(const StringRef ctrlOpAssembly) { + // Wrap the op in a function to provide operands + const std::string source = + (Twine("func.func @test(%q0: !qco.qubit, %q1: !qco.qubit) {\n") + + ctrlOpAssembly + "\n" + " return\n" + "}") + .str(); + const ScopedDiagnosticHandler diagHandler(&context); + return parseSourceString(source, &context); + }; +}; + +TEST_F(QCOCtrlOpTest, LambdaBuilder) { + // Allocate qubits to use as operands + const auto q = builder.allocQubitRegister(3); + + // Create CtrlOp using the lambda builder + builder.ctrl(q[0], {q[1], q[2]}, + [&](ValueRange innerTargets) -> SmallVector { + // Create the inner operation + auto [q0, q1] = builder.swap(innerTargets[0], innerTargets[1]); + return {q0, q1}; + }); + auto ctrlOp = cast(builder.getBlock()->getOperations().back()); + module = builder.finalize(); + + // Verify the operation structure + EXPECT_EQ(ctrlOp.getNumControls(), 1); + EXPECT_EQ(ctrlOp.getNumTargets(), 2); + EXPECT_EQ(ctrlOp.getResults().size(), 3); // 1 control out + 2 targets out + + // Verify operation + ASSERT_TRUE(mlir::verify(ctrlOp).succeeded()); +} + +TEST_F(QCOCtrlOpTest, UnitaryOpBuilder) { + // Allocate qubits + const auto q = builder.allocQubitRegister(2); + + // Create a template unitary operation (X gate) + auto xOp = XOp::create(builder, builder.getUnknownLoc(), q[1]); + + // Create CtrlOp using the UnitaryOpInterface builder + auto ctrlOp = CtrlOp::create(builder, builder.getUnknownLoc(), q[0], q[1], + cast(xOp.getOperation())); + + // Erase the template op so it doesn't consume q[1] in the main block + xOp.erase(); + + // Verify structure + EXPECT_EQ(ctrlOp.getNumControls(), 1); + EXPECT_EQ(ctrlOp.getNumTargets(), 1); + EXPECT_EQ(ctrlOp.getResults().size(), 2); // 1 control out + 1 target out + + // Verify operation + EXPECT_TRUE(mlir::verify(ctrlOp).succeeded()); +} + +TEST_F(QCOCtrlOpTest, VerifierBodySize) { + const auto q = builder.allocQubitRegister(2); + + // Create valid CtrlOp + builder.ctrl(q[0], q[1], [&](ValueRange innerTargets) -> SmallVector { + return {builder.x(innerTargets[0])}; + }); + auto ctrlOp = cast(builder.getBlock()->getOperations().back()); + + // Insert an extra operation into the body + auto& region = ctrlOp.getRegion(); + auto& block = region.front(); + + { + const OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(&block.back()); // Before Yield + // We can insert another XOp + XOp::create(builder, builder.getUnknownLoc(), block.getArgument(0)); + } + module = builder.finalize(); + + // Should fail because body must have exactly 2 operations + EXPECT_TRUE(mlir::verify(ctrlOp).failed()); +} + +TEST_F(QCOCtrlOpTest, VerifierBlockArgsCount) { + const auto q = builder.allocQubitRegister(2); + + // Create valid CtrlOp + builder.ctrl(q[0], q[1], [&](ValueRange innerTargets) -> SmallVector { + return {builder.x(innerTargets[0])}; + }); + auto ctrlOp = cast(builder.getBlock()->getOperations().back()); + module = builder.finalize(); + + // Add an extra argument to the block + auto& region = ctrlOp.getRegion(); + auto& block = region.front(); + const auto qType = QubitType::get(&context); + block.addArgument(qType, builder.getUnknownLoc()); + + // Should fail because number of block args must match number of targets (1) + EXPECT_TRUE(mlir::verify(ctrlOp).failed()); +} + +TEST_F(QCOCtrlOpTest, VerifierInputTypes) { + const auto q = builder.allocQubitRegister(2); + + // Create valid CtrlOp + builder.ctrl(q[0], q[1], [&](ValueRange innerTargets) -> SmallVector { + return {builder.x(innerTargets[0])}; + }); + auto ctrlOp = cast(builder.getBlock()->getOperations().back()); + module = builder.finalize(); + + // Change the block argument type to a non-qubit + auto& region = ctrlOp.getRegion(); + auto& block = region.front(); + block.getArgument(0).setType(builder.getI1Type()); + + EXPECT_TRUE(mlir::verify(ctrlOp).failed()); +} + +TEST_F(QCOCtrlOpTest, VerifierBodyFirstOp) { + const auto q = builder.allocQubitRegister(2); + + // Create valid CtrlOp + builder.ctrl(q[0], q[1], [&](ValueRange innerTargets) -> SmallVector { + return {builder.reset(innerTargets[0])}; + }); + auto ctrlOp = cast(builder.getBlock()->getOperations().back()); + module = builder.finalize(); + + // Should fail because body must use a unitary as first operation + EXPECT_TRUE(mlir::verify(ctrlOp).failed()); +} + +TEST_F(QCOCtrlOpTest, ParserErrors) { + // 1. Missing opening parenthesis for targets + EXPECT_EQ( + testParse( + "qco.ctrl(%q0) targets %a = %q1) { qco.yield %a } : ({!qco.qubit}, " + "{!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit})") + .get(), + nullptr); + + // 2. Missing argument name + EXPECT_EQ( + testParse( + "qco.ctrl(%q0) targets ( = %q1) { qco.yield %q1 } : ({!qco.qubit}, " + "{!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit})") + .get(), + nullptr); + + // 3. Missing equals sign + EXPECT_EQ( + testParse( + "qco.ctrl(%q0) targets (%a %q1) { qco.yield %a } : ({!qco.qubit}, " + "{!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit})") + .get(), + nullptr); + + // 4. Missing operand (old value) + EXPECT_EQ( + testParse( + "qco.ctrl(%q0) targets (%a = ) { qco.yield %a } : ({!qco.qubit}, " + "{!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit})") + .get(), + nullptr); + + // 5. Missing closing parenthesis + EXPECT_EQ( + testParse( + "qco.ctrl(%q0) targets (%a = %q1 { qco.yield %a } : ({!qco.qubit}, " + "{!qco.qubit}) -> ({!qco.qubit}, {!qco.qubit})") + .get(), + nullptr); +}