diff --git a/.github/workflows/rc_sync.yaml b/.github/workflows/rc_sync.yaml index 9f8c8868c5..860bee3d5d 100644 --- a/.github/workflows/rc_sync.yaml +++ b/.github/workflows/rc_sync.yaml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 - + # Import Numpy/Pybind11 because it is imported in setup.py - name: Install Numpy and Pybind11 run: | diff --git a/doc/releases/changelog-0.13.0.md b/doc/releases/changelog-0.13.0.md index a66f6bdee2..66dd138ac2 100644 --- a/doc/releases/changelog-0.13.0.md +++ b/doc/releases/changelog-0.13.0.md @@ -9,10 +9,12 @@ Similar to PennyLane's behaviour, this experimental feature will fall back to the old system whenever the graph cannot find decomposition rules for all unsupported operators in the program, and a ``UserWarning`` is raised. - [(#2099)](https://github.com/PennyLaneAI/catalyst/pull/2099) - [(#2091)](https://github.com/PennyLaneAI/catalyst/pull/2091) - [(#2029)](https://github.com/PennyLaneAI/catalyst/pull/2029) [(#2001)](https://github.com/PennyLaneAI/catalyst/pull/2001) + [(#2029)](https://github.com/PennyLaneAI/catalyst/pull/2029) + [(#2068)](https://github.com/PennyLaneAI/catalyst/pull/2068) + [(#2091)](https://github.com/PennyLaneAI/catalyst/pull/2091) + [(#2099)](https://github.com/PennyLaneAI/catalyst/pull/2099) + * Catalyst now supports dynamic wire allocation with ``qml.allocate()`` and ``qml.deallocate()`` when program capture is enabled. diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index 3f67d2b94b..7d1224b9e9 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -1313,6 +1313,48 @@ def circuit(x: float, y: float, z: float): assert jnp.allclose(circuit(1.5, 2.5, 3.5), capture_result) + def test_transform_graph_decompose_workflow(self, backend): + """Test the integration for a circuit with a 'decompose' graph transform.""" + + # Capture enabled + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) + @qml.qnode(qml.device(backend, wires=2)) + def captured_circuit(x: float, y: float, z: float): + m = qml.measure(0) + + @qml.cond(m) + def cond_fn(): + qml.Rot(x, y, z, 0) + + cond_fn() + return qml.expval(qml.PauliZ(0)) + + capture_result = captured_circuit(1.5, 2.5, 3.5) + + qml.decomposition.disable_graph() + qml.capture.disable() + + # Capture disabled + @qjit + @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) + @qml.qnode(qml.device(backend, wires=2)) + def circuit(x: float, y: float, z: float): + m = catalyst.measure(0) + + @catalyst.cond(m) + def cond_fn(): + qml.Rot(x, y, z, 0) + + cond_fn() + return qml.expval(qml.PauliZ(0)) + + assert jnp.allclose(circuit(1.5, 2.5, 3.5), capture_result) + def test_transform_map_wires_workflow(self, backend): """Test the integration for a circuit with a 'map_wires' transform.""" diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 9dcc4ea1ad..7d39c1d9f1 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" +#include "Quantum/IR/QuantumInterfaces.h" #include "Quantum/IR/QuantumOps.h" #include "Quantum/Transforms/Patterns.h" @@ -40,6 +41,7 @@ namespace quantum { /// - A runtime Value (for dynamic indices computed at runtime) /// - An IntegerAttr (for compile-time constant indices) /// - Invalid/uninitialized (represented by std::monostate) +/// And a qreg value to represent the qreg that the index belongs to /// /// The struct uses std::variant to ensure only one type is active at a time, /// preventing invalid states. @@ -54,17 +56,21 @@ namespace quantum { /// Value idx = dynamicIdx.getValue(); // Get the Value /// } /// } -struct QubitIndex { +class QubitIndex { + private: // use monostate to represent the invalid index std::variant index; + Value qreg; - QubitIndex() : index(std::monostate()) {} - QubitIndex(Value val) : index(val) {} - QubitIndex(IntegerAttr attr) : index(attr) {} + public: + QubitIndex() : index(std::monostate()), qreg(nullptr) {} + QubitIndex(Value val, Value qreg) : index(val), qreg(qreg) {} + QubitIndex(IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {} bool isValue() const { return std::holds_alternative(index); } bool isAttr() const { return std::holds_alternative(index); } operator bool() const { return isValue() || isAttr(); } + Value getReg() const { return qreg; } Value getValue() const { return isValue() ? std::get(index) : nullptr; } IntegerAttr getAttr() const { return isAttr() ? std::get(index) : nullptr; } }; @@ -76,25 +82,16 @@ class OpSignatureAnalyzer { public: OpSignatureAnalyzer() = delete; OpSignatureAnalyzer(CustomOp op, bool enableQregMode) - : signature(OpSignature{ - .params = op.getParams(), - .inQubits = op.getInQubits(), - .inCtrlQubits = op.getInCtrlQubits(), - .inCtrlValues = op.getInCtrlValues(), - .outQubits = op.getOutQubits(), - .outCtrlQubits = op.getOutCtrlQubits(), - }) + : signature(OpSignature{.params = op.getParams(), + .inQubits = op.getNonCtrlQubitOperands(), + .inCtrlQubits = op.getCtrlQubitOperands(), + .inCtrlValues = op.getCtrlValueOperands(), + .outQubits = op.getNonCtrlQubitResults(), + .outCtrlQubits = op.getCtrlQubitResults()}) { if (!enableQregMode) return; - signature.sourceQreg = getSourceQreg(signature.inQubits.front()); - if (!signature.sourceQreg) { - op.emitError("Cannot get source qreg"); - isValid = false; - return; - } - // input wire indices for (Value qubit : signature.inQubits) { const QubitIndex index = getExtractIndex(qubit); @@ -117,6 +114,9 @@ class OpSignatureAnalyzer { signature.inCtrlWireIndices.emplace_back(index); } + assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 && + "inWireIndices or inCtrlWireIndices should not be empty"); + // Output qubit indices are the same as input qubit indices signature.outQubitIndices = signature.inWireIndices; signature.outCtrlQubitIndices = signature.inCtrlWireIndices; @@ -124,6 +124,25 @@ class OpSignatureAnalyzer { operator bool() const { return isValid; } + Value getUpdatedQreg(PatternRewriter &rewriter, Location loc) + { + // FIXME: This will cause an issue when the decomposition function has cross-qreg + // inputs and outputs. Now, we just assume has only one qreg input, the global one exists. + // raise an error if the qreg is not the same + Value qreg = signature.inWireIndices[0].getReg(); + + bool sameQreg = true; + for (const auto &index : signature.inWireIndices) { + sameQreg &= index.getReg() == qreg; + } + for (const auto &index : signature.inCtrlWireIndices) { + sameQreg &= index.getReg() == qreg; + } + + assert(sameQreg && "The qreg of the input wires should be the same"); + return qreg; + } + // Prepare the operands for calling the decomposition function // There are two cases: // 1. The first input is a qreg, which means the decomposition function is a qreg mode function @@ -144,15 +163,8 @@ class OpSignatureAnalyzer { int operandIdx = 0; if (isa(funcInputs[0])) { - Value updatedQreg = signature.sourceQreg; - for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { - const QubitIndex &index = signature.inWireIndices[i]; - updatedQreg = - rewriter.create(loc, updatedQreg.getType(), updatedQreg, - index.getValue(), index.getAttr(), qubit); - } + operands[operandIdx++] = getUpdatedQreg(rewriter, loc); - operands[operandIdx++] = updatedQreg; if (!signature.params.empty()) { auto [startIdx, endIdx] = findParamTypeRange(funcInputs, signature.params.size(), operandIdx); @@ -163,16 +175,12 @@ class OpSignatureAnalyzer { } } - if (!signature.inWireIndices.empty()) { - operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices, - funcInputs[operandIdx], rewriter, loc); - operandIdx++; - } - - if (!signature.inCtrlWireIndices.empty()) { - operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices, - funcInputs[operandIdx], rewriter, loc); - operandIdx++; + for (const auto &indices : {signature.inWireIndices, signature.inCtrlWireIndices}) { + if (!indices.empty()) { + operands[operandIdx] = + fromTensorOrAsIs(indices, funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } } } else { @@ -218,18 +226,16 @@ class OpSignatureAnalyzer { SmallVector newResults; rewriter.setInsertionPointAfter(callOp); - for (const QubitIndex &index : signature.outQubitIndices) { - auto extractOp = rewriter.create( - callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), - index.getAttr()); - newResults.emplace_back(extractOp.getResult()); - } - for (const QubitIndex &index : signature.outCtrlQubitIndices) { - auto extractOp = rewriter.create( - callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), - index.getAttr()); - newResults.emplace_back(extractOp.getResult()); + + for (const auto &indices : {signature.outQubitIndices, signature.outCtrlQubitIndices}) { + for (const auto &index : indices) { + auto extractOp = rewriter.create( + callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), + index.getAttr()); + newResults.emplace_back(extractOp.getResult()); + } } + return newResults; } @@ -245,7 +251,6 @@ class OpSignatureAnalyzer { ValueRange outCtrlQubits; // Qreg mode specific information - Value sourceQreg = nullptr; SmallVector inWireIndices; SmallVector inCtrlWireIndices; SmallVector outQubitIndices; @@ -333,39 +338,21 @@ class OpSignatureAnalyzer { return values.front(); } - Value getSourceQreg(Value qubit) - { - while (qubit) { - if (auto extractOp = qubit.getDefiningOp()) { - return extractOp.getQreg(); - } - - if (auto customOp = dyn_cast_or_null(qubit.getDefiningOp())) { - if (customOp.getQubitOperands().empty()) { - break; - } - qubit = customOp.getQubitOperands()[0]; - } - } - - return nullptr; - } - QubitIndex getExtractIndex(Value qubit) { while (qubit) { if (auto extractOp = qubit.getDefiningOp()) { if (Value idx = extractOp.getIdx()) { - return QubitIndex(idx); + return QubitIndex(idx, extractOp.getQreg()); } if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { - return QubitIndex(idxAttr); + return QubitIndex(idxAttr, extractOp.getQreg()); } } - if (auto customOp = dyn_cast_or_null(qubit.getDefiningOp())) { - auto qubitOperands = customOp.getQubitOperands(); - auto qubitResults = customOp.getQubitResults(); + if (auto gate = dyn_cast_or_null(qubit.getDefiningOp())) { + auto qubitOperands = gate.getQubitOperands(); + auto qubitResults = gate.getQubitResults(); auto it = llvm::find_if(qubitResults, [&](Value result) { return result == qubit; }); @@ -377,6 +364,10 @@ class OpSignatureAnalyzer { } } } + else if (auto measureOp = dyn_cast_or_null(qubit.getDefiningOp())) { + qubit = measureOp.getInQubit(); + continue; + } break; } @@ -394,7 +385,8 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern { DecomposeLoweringRewritePattern(MLIRContext *context, const llvm::StringMap ®istry, const llvm::StringSet &gateSet) - : OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet) + : OpRewritePattern(context), decompositionRegistry(registry), + targetGateSet(gateSet) { } @@ -421,11 +413,12 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern { assert(decompFunc.getFunctionType().getNumResults() >= 1 && "Decomposition function must have at least one result"); + rewriter.setInsertionPointAfter(op); + auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); auto analyzer = OpSignatureAnalyzer(op, enableQreg); assert(analyzer && "Analyzer should be valid"); - rewriter.setInsertionPointAfter(op); auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 91bfbe7778..2a5b24eb54 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -508,3 +508,49 @@ module @cnot_alternative_decomposition { return %out_qubits_2#0, %out_qubits_4 : !quantum.bit, !quantum.bit } } + +// ----- + +module @mcm_example { + func.func public @test_mcm_hadamard() -> tensor<2xf64> { + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %mres, %out_qubit = quantum.measure %1 : i1, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubit : !quantum.reg, !quantum.bit + + // CHECK: [[RZ_QUBIT:%.+]] = quantum.custom "RZ"([[CST_0:%.+]]) + // CHECK: [[RY_QUBIT:%.+]] = quantum.custom "RY"([[CST_1:%.+]]) [[RZ_QUBIT]] : !quantum.bit + // CHECK: [[REG_1:%.+]] = quantum.insert [[REG:%.+]][[[EXTRACTED:%.+]]], [[RY_QUBIT]] : !quantum.reg, !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %3 = quantum.extract %2[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %3 : !quantum.bit + %4 = quantum.insert %2[ 0], %out_qubits : !quantum.reg, !quantum.bit + + %5 = quantum.compbasis qreg %4 : !quantum.obs + %6 = quantum.probs %5 : tensor<2xf64> + quantum.dealloc %4 : !quantum.reg + return %6 : tensor<2xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func public @rz_ry + func.func public @rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %0 = stablehlo.slice %arg1 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor + %extracted = tensor.extract %1[] : tensor + %2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "RZ"(%cst_0) %2 : !quantum.bit + %3 = stablehlo.slice %arg1 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %4 = stablehlo.reshape %3 : (tensor<1xi64>) -> tensor + %extracted_1 = tensor.extract %1[] : tensor + %5 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit + %extracted_2 = tensor.extract %4[] : tensor + %6 = quantum.extract %5[%extracted_2] : !quantum.reg -> !quantum.bit + %out_qubits_3 = quantum.custom "RY"(%cst) %6 : !quantum.bit + %extracted_4 = tensor.extract %4[] : tensor + %7 = quantum.insert %5[%extracted_4], %out_qubits_3 : !quantum.reg, !quantum.bit + return %7 : !quantum.reg + } +}