diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index cd9ee2d31a..2ac3b5a5bd 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -138,6 +138,7 @@ * Dynamically allocated wires can now be passed into control flow and subroutines. [(#2130)](https://github.com/PennyLaneAI/catalyst/pull/2130) + [(#2268)](https://github.com/PennyLaneAI/catalyst/pull/2268) * The `--adjoint-lowering` pass can now handle PPR operations. [(#2227)](https://github.com/PennyLaneAI/catalyst/pull/2227) diff --git a/frontend/catalyst/from_plxpr/qubit_handler.py b/frontend/catalyst/from_plxpr/qubit_handler.py index b7a54b76f2..e17180d494 100644 --- a/frontend/catalyst/from_plxpr/qubit_handler.py +++ b/frontend/catalyst/from_plxpr/qubit_handler.py @@ -436,21 +436,22 @@ def _get_dynamically_allocated_qregs(plxpr_invals, qubit_index_recorder, init_qr Get the potential dynamically allocated register values that are visible to a jaxpr. Note that dynamically allocated wires have their qreg tracer's id as the global wire index - so the sub jaxpr takes that id in as a "const", since it is closure from the target wire - of gates/measurements/... + so the sub jaxpr takes that id in as a "const" (if it is one; as opposed to tracers), + since it is closure from the target wire of gates/measurements/... + We need to remove that const, so we also let this util return these global indices. """ dynalloced_qregs = [] dynalloced_wire_global_indices = [] for inval in plxpr_invals: - if ( - isinstance(inval, int) - and qubit_index_recorder.contains(inval) - and qubit_index_recorder[inval] is not init_qreg - ): + if not type(inval) in [int, DynamicJaxprTracer]: + # don't care about invals that won't be wire indices + continue + if qubit_index_recorder.contains(inval) and qubit_index_recorder[inval] is not init_qreg: dyn_qreg = qubit_index_recorder[inval] dyn_qreg.insert_all_dangling_qubits() dynalloced_qregs.append(dyn_qreg) - dynalloced_wire_global_indices.append(inval) + if isinstance(inval, int): + dynalloced_wire_global_indices.append(inval) return dynalloced_qregs, dynalloced_wire_global_indices diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index 4674c04a09..d5a845926e 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -434,6 +434,80 @@ def circuit(): assert np.allclose(expected, observed) +@pytest.mark.usefixtures("use_capture") +def test_subroutine_and_loop(backend): + """ + Test passing dynamically allocated wires into a subroutine with loops. + """ + + @subroutine + def flip(wire, theta): + """ + Apply three X gates to the input wire, effectively NOT-ing it. + """ + + @qml.for_loop(0, 3, 1) + def loop(i, _theta): # pylint: disable=unused-argument + qml.X(wire) + return jnp.sin(_theta) + + _ = loop(theta) + + @qjit + @qml.qnode(qml.device(backend, wires=1)) + def circuit(): + with qml.allocate(1) as q1: + flip(q1[0], 0.0) + qml.CNOT(wires=[q1[0], 0]) + return qml.probs(wires=[0]) + + observed = circuit() + expected = [0, 1] + assert np.allclose(expected, observed) + + +@pytest.mark.usefixtures("use_capture") +def test_subroutine_and_loop_multiple_args(backend): + """ + Test passing dynamically allocated wires into a subroutine with loops and multiple arguments. + """ + + @subroutine + def flip(w1, w2, w3, theta): + @qml.for_loop(0, 2, 1) + def loop(i, _theta): # pylint: disable=unused-argument + qml.X(w1) + qml.Y(w2) + qml.Z(w3) + qml.ctrl(qml.RX, (w1, w2))(_theta, wires=0) + qml.ctrl(qml.RY, (w2, w3))(_theta, wires=1) + return jnp.sin(_theta) + + _ = loop(theta) + + @qjit + @qml.qnode(qml.device(backend, wires=2)) + def circuit(): + with qml.allocate(2) as q1: + with qml.allocate(3) as q2: + flip(q1[0], q1[1], q2[2], 1.23) + + return qml.probs(wires=[0, 1]) + + @qml.qnode(qml.device("default.qubit", wires=7)) + def ref_circuit(): + for _ in range(2): + qml.X(0) + qml.Y(1) + qml.Z(2) + qml.ctrl(qml.RX, (0, 1))(1.23, wires=3) + qml.ctrl(qml.RY, (1, 2))(1.23, wires=4) + + return qml.probs(wires=[3, 4]) + + assert np.allclose(circuit(), ref_circuit()) + + def test_no_capture(backend): """ Test error message when used without capture.