diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4bdf910e6..1d3e52359 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -10,6 +10,10 @@
Bug fixes 🐛
+* Fixes the translation of plxpr control flow for edge cases where the `consts` were being
+ reordered.
+ [(#2128)](https://github.com/PennyLaneAI/catalyst/pull/2128)
+
Internal changes ⚙️
Documentation 📝
diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py
index 45a255db3..663a55436 100644
--- a/frontend/catalyst/from_plxpr/control_flow.py
+++ b/frontend/catalyst/from_plxpr/control_flow.py
@@ -66,10 +66,10 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice
# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
- all_consts = all_consts + [*branch_consts]
evaluator = partial(copy(self).eval, plxpr_branch, branch_consts)
new_jaxpr = jax.make_jaxpr(evaluator)(*args)
+ all_consts = all_consts + new_jaxpr.consts
converted_jaxpr_branches.append(new_jaxpr.jaxpr)
@@ -99,15 +99,15 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
- all_consts = all_consts + [*branch_consts]
converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)
f = partial(_calling_convention, self, closed_jaxpr)
- converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr
+ converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg)
- converted_jaxpr_branches.append(converted_jaxpr_branch)
+ all_consts += converted_jaxpr_branch.consts
+ converted_jaxpr_branches.append(converted_jaxpr_branch.jaxpr)
predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]]
@@ -151,14 +151,16 @@ def workflow_for_loop(
converter = copy(self)
evaluator = partial(converter.eval, jaxpr_body_fn, consts)
- converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args).jaxpr
- converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
+ converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args)
+ converted_closed_jaxpr_branch = ClosedJaxpr(
+ convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
+ )
# Config additional for loop settings
apply_reverse_transform = isinstance(step, int) and step < 0
return for_p.bind(
- *consts,
+ *converted_jaxpr_branch.consts,
start,
stop,
step,
@@ -202,12 +204,15 @@ def handle_for_loop(
jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)
f = partial(_calling_convention, self, jaxpr)
- converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg).jaxpr
+ converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg)
- converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
+ converted_closed_jaxpr_branch = ClosedJaxpr(
+ convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
+ )
# Build Catalyst compatible input values
- for_loop_invals = [*consts, start, stop, step, *start_plus_args_plus_qreg]
+ new_consts = converted_jaxpr_branch.consts
+ for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg]
# Config additional for loop settings
apply_reverse_transform = isinstance(step, int) and step < 0
@@ -258,14 +263,14 @@ def workflow_while_loop(
convert_constvars_jaxpr(new_cond_jaxpr.jaxpr), ()
)
# Build Catalyst compatible input values
- while_loop_invals = [*consts_cond, *consts_body, *args]
+ while_loop_invals = [*new_cond_jaxpr.consts, *new_body_jaxpr.consts, *args]
return while_p.bind(
*while_loop_invals,
cond_jaxpr=converted_cond_closed_jaxpr_branch,
body_jaxpr=converted_body_closed_jaxpr_branch,
- cond_nconsts=len(consts_cond),
- body_nconsts=len(consts_body),
+ cond_nconsts=len(new_cond_jaxpr.consts),
+ body_nconsts=len(new_body_jaxpr.consts),
nimplicit=0,
preserve_dimensions=True,
)
diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py
index 7d1224b9e..f82a1791a 100644
--- a/frontend/test/pytest/from_plxpr/test_capture_integration.py
+++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py
@@ -1656,6 +1656,58 @@ def g(i, y):
expected = 1.0 + jnp.cos(0) + jnp.cos(1) + jnp.cos(2)
assert qml.math.allclose(res, expected)
+ # pylint: disable=unused-argument
+ def test_for_loop_consts(self):
+ """This tests for kinda a weird edge case bug where the consts where getting
+ reordered when translating the inner jaxpr."""
+
+ qml.capture.enable()
+
+ @qml.qjit
+ @qml.qnode(qml.device("lightning.qubit", wires=3))
+ def circuit(x, n):
+ @qml.for_loop(3)
+ def outer(i):
+
+ @qml.for_loop(n)
+ def inner(j):
+ qml.RY(x, wires=j)
+
+ inner()
+
+ outer()
+
+ # Expected output: |100...>
+ return [qml.expval(qml.PauliZ(i)) for i in range(3)]
+
+ res1, res2, res3 = circuit(0.2, 2)
+
+ assert qml.math.allclose(res1, jnp.cos(0.2 * 3))
+ assert qml.math.allclose(res2, jnp.cos(0.2 * 3))
+ assert qml.math.allclose(res3, 1)
+
+ # pylint: disable=unused-argument
+ def test_for_loop_consts_outside_qnode(self):
+ """Similar test as above for weird edge case, but not using a qnode."""
+
+ qml.capture.enable()
+
+ @qml.qjit
+ def f(x, n):
+ @qml.for_loop(3)
+ def outer(i, a):
+
+ @qml.for_loop(n)
+ def inner(j, b):
+ return b + x
+
+ return inner(a)
+
+ return outer(0.0)
+
+ res = f(0.2, 2)
+ assert qml.math.allclose(res, 0.2 * 2 * 3)
+
def test_adjoint_transform_integration():
"""Test that adjoint transforms can be used with capture enabled."""