From da53c58c4190fa005ead5d4aa53b0dfdde806246 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 5 Dec 2025 11:38:05 -0500 Subject: [PATCH 1/7] fix for subroutine + for loop + qubit allocation --- frontend/catalyst/from_plxpr/qubit_handler.py | 14 ++++---- .../pytest/test_dynamic_qubit_allocation.py | 32 +++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/from_plxpr/qubit_handler.py b/frontend/catalyst/from_plxpr/qubit_handler.py index b7a54b76f2..f9999677be 100644 --- a/frontend/catalyst/from_plxpr/qubit_handler.py +++ b/frontend/catalyst/from_plxpr/qubit_handler.py @@ -436,21 +436,19 @@ def _get_dynamically_allocated_qregs(plxpr_invals, qubit_index_recorder, init_qr Get the potential dynamically allocated register values that are visible to a jaxpr. Note that dynamically allocated wires have their qreg tracer's id as the global wire index - so the sub jaxpr takes that id in as a "const", since it is closure from the target wire - of gates/measurements/... + so the sub jaxpr takes that id in as a "const" (if it is one; as opposed to tracers), + since it is closure from the target wire of gates/measurements/... + We need to remove that const, so we also let this util return these global indices. """ dynalloced_qregs = [] dynalloced_wire_global_indices = [] for inval in plxpr_invals: - if ( - isinstance(inval, int) - and qubit_index_recorder.contains(inval) - and qubit_index_recorder[inval] is not init_qreg - ): + if qubit_index_recorder.contains(inval) and qubit_index_recorder[inval] is not init_qreg: dyn_qreg = qubit_index_recorder[inval] dyn_qreg.insert_all_dangling_qubits() dynalloced_qregs.append(dyn_qreg) - dynalloced_wire_global_indices.append(inval) + if isinstance(inval, int): + dynalloced_wire_global_indices.append(inval) return dynalloced_qregs, dynalloced_wire_global_indices diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index 4674c04a09..50a75a3d97 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -434,6 +434,38 @@ def circuit(): assert np.allclose(expected, observed) +@pytest.mark.usefixtures("use_capture") +def test_subroutine_and_loop(backend): + """ + Test passing dynamically allocated wires into a subroutine with loops. + """ + + @subroutine + def flip(wire, theta): + """ + Apply three X gates to the input wire, effectively NOT-ing it. + """ + + @qml.for_loop(0, 3, 1) + def loop_rx(i, _theta): + qml.X(wire) + return jnp.sin(_theta) + + _ = loop_rx(theta) + + @qjit + @qml.qnode(qml.device(backend, wires=1)) + def circuit(): + with qml.allocate(1) as q1: + flip(q1[0], 0.0) + qml.CNOT(wires=[q1[0], 0]) + return qml.probs(wires=[0]) + + observed = circuit() + expected = [0, 1] + assert np.allclose(expected, observed) + + def test_no_capture(backend): """ Test error message when used without capture. From 681d28cc5c09a67f0b11d47b1c8c3ea1e6892f88 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 5 Dec 2025 12:05:16 -0500 Subject: [PATCH 2/7] test with many args --- .../pytest/test_dynamic_qubit_allocation.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index 50a75a3d97..7008e8c642 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -466,6 +466,46 @@ def circuit(): assert np.allclose(expected, observed) +@pytest.mark.usefixtures("use_capture") +def test_subroutine_and_loop_multiple_args(backend): + """ + Test passing dynamically allocated wires into a subroutine with loops and multiple arguments. + """ + + @subroutine + def flip(w1, w2, w3, theta): + @qml.for_loop(0, 2, 1) + def loop_rx(i, _theta): + qml.X(w1) + qml.Y(w2) + qml.Z(w3) + qml.ctrl(qml.RX, (w1, w2))(_theta, wires=0) + return jnp.sin(_theta) + + _ = loop_rx(theta) + + @qjit + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + with qml.allocate(2) as q1: + with qml.allocate(3) as q2: + flip(q1[0], q1[1], q2[2], 1.23) + + return qml.expval(qml.Z(0)) + + @qml.qnode(qml.device("default.qubit", wires=6)) + def ref_circuit(): + for i in range(2): + qml.X(0) + qml.Y(1) + qml.Z(2) + qml.ctrl(qml.RX, (0, 1))(1.23, wires=3) + + return qml.expval(qml.Z(3)) + + assert np.allclose(circuit(), ref_circuit()) + + def test_no_capture(backend): """ Test error message when used without capture. From 892e44eb8d7db0e339eb26ad6a604058604b4b80 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 5 Dec 2025 12:08:44 -0500 Subject: [PATCH 3/7] test with multiple args --- frontend/test/pytest/test_dynamic_qubit_allocation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index 7008e8c642..2747da57ad 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -480,28 +480,30 @@ def loop_rx(i, _theta): qml.Y(w2) qml.Z(w3) qml.ctrl(qml.RX, (w1, w2))(_theta, wires=0) + qml.ctrl(qml.RY, (w2, w3))(_theta, wires=1) return jnp.sin(_theta) _ = loop_rx(theta) @qjit - @qml.qnode(qml.device("lightning.qubit", wires=1)) + @qml.qnode(qml.device("lightning.qubit", wires=2)) def circuit(): with qml.allocate(2) as q1: with qml.allocate(3) as q2: flip(q1[0], q1[1], q2[2], 1.23) - return qml.expval(qml.Z(0)) + return qml.probs(wires=[0, 1]) - @qml.qnode(qml.device("default.qubit", wires=6)) + @qml.qnode(qml.device("default.qubit", wires=7)) def ref_circuit(): for i in range(2): qml.X(0) qml.Y(1) qml.Z(2) qml.ctrl(qml.RX, (0, 1))(1.23, wires=3) + qml.ctrl(qml.RY, (1, 2))(1.23, wires=4) - return qml.expval(qml.Z(3)) + return qml.probs(wires=[3, 4]) assert np.allclose(circuit(), ref_circuit()) From 622ddddf7ee289f58565566b9ea57f14862056fe Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 5 Dec 2025 12:13:22 -0500 Subject: [PATCH 4/7] codefactor --- frontend/test/pytest/test_dynamic_qubit_allocation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index 2747da57ad..97203a9c48 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -447,11 +447,11 @@ def flip(wire, theta): """ @qml.for_loop(0, 3, 1) - def loop_rx(i, _theta): + def loop(i, _theta): # pylint: disable=unused-argument qml.X(wire) return jnp.sin(_theta) - _ = loop_rx(theta) + _ = loop(theta) @qjit @qml.qnode(qml.device(backend, wires=1)) @@ -475,7 +475,7 @@ def test_subroutine_and_loop_multiple_args(backend): @subroutine def flip(w1, w2, w3, theta): @qml.for_loop(0, 2, 1) - def loop_rx(i, _theta): + def loop(i, _theta): # pylint: disable=unused-argument qml.X(w1) qml.Y(w2) qml.Z(w3) @@ -483,7 +483,7 @@ def loop_rx(i, _theta): qml.ctrl(qml.RY, (w2, w3))(_theta, wires=1) return jnp.sin(_theta) - _ = loop_rx(theta) + _ = loop(theta) @qjit @qml.qnode(qml.device("lightning.qubit", wires=2)) @@ -496,7 +496,7 @@ def circuit(): @qml.qnode(qml.device("default.qubit", wires=7)) def ref_circuit(): - for i in range(2): + for _ in range(2): qml.X(0) qml.Y(1) qml.Z(2) From 12f26ee669067fc26f2fd0514fbfd5afdb6991d0 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 5 Dec 2025 12:14:34 -0500 Subject: [PATCH 5/7] pylint again --- frontend/test/pytest/test_dynamic_qubit_allocation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index 97203a9c48..d5a845926e 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -486,7 +486,7 @@ def loop(i, _theta): # pylint: disable=unused-argument _ = loop(theta) @qjit - @qml.qnode(qml.device("lightning.qubit", wires=2)) + @qml.qnode(qml.device(backend, wires=2)) def circuit(): with qml.allocate(2) as q1: with qml.allocate(3) as q2: From 7cb17d099baae2aabcecb0c16c94bbf2d4d7be59 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 5 Dec 2025 12:23:23 -0500 Subject: [PATCH 6/7] skip invals that won't be wire indices --- frontend/catalyst/from_plxpr/qubit_handler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frontend/catalyst/from_plxpr/qubit_handler.py b/frontend/catalyst/from_plxpr/qubit_handler.py index f9999677be..e17180d494 100644 --- a/frontend/catalyst/from_plxpr/qubit_handler.py +++ b/frontend/catalyst/from_plxpr/qubit_handler.py @@ -444,6 +444,9 @@ def _get_dynamically_allocated_qregs(plxpr_invals, qubit_index_recorder, init_qr dynalloced_qregs = [] dynalloced_wire_global_indices = [] for inval in plxpr_invals: + if not type(inval) in [int, DynamicJaxprTracer]: + # don't care about invals that won't be wire indices + continue if qubit_index_recorder.contains(inval) and qubit_index_recorder[inval] is not init_qreg: dyn_qreg = qubit_index_recorder[inval] dyn_qreg.insert_all_dangling_qubits() From 582e55621271512578fcf34109a1b9f949f53f76 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 5 Dec 2025 12:59:20 -0500 Subject: [PATCH 7/7] changelog --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index cd9ee2d31a..2ac3b5a5bd 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -138,6 +138,7 @@ * Dynamically allocated wires can now be passed into control flow and subroutines. [(#2130)](https://github.com/PennyLaneAI/catalyst/pull/2130) + [(#2268)](https://github.com/PennyLaneAI/catalyst/pull/2268) * The `--adjoint-lowering` pass can now handle PPR operations. [(#2227)](https://github.com/PennyLaneAI/catalyst/pull/2227)