diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 4bdf910e6..1d3e52359 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,6 +10,10 @@

Bug fixes 🐛

+* Fixes the translation of plxpr control flow for edge cases where the `consts` were being + reordered. + [(#2128)](https://github.com/PennyLaneAI/catalyst/pull/2128) +

Internal changes ⚙️

Documentation 📝

diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py index 45a255db3..663a55436 100644 --- a/frontend/catalyst/from_plxpr/control_flow.py +++ b/frontend/catalyst/from_plxpr/control_flow.py @@ -66,10 +66,10 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice # Store all branches consts in a flat list branch_consts = plxpr_invals[const_slice] - all_consts = all_consts + [*branch_consts] evaluator = partial(copy(self).eval, plxpr_branch, branch_consts) new_jaxpr = jax.make_jaxpr(evaluator)(*args) + all_consts = all_consts + new_jaxpr.consts converted_jaxpr_branches.append(new_jaxpr.jaxpr) @@ -99,15 +99,15 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): # Store all branches consts in a flat list branch_consts = plxpr_invals[const_slice] - all_consts = all_consts + [*branch_consts] converted_jaxpr_branch = None closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts) f = partial(_calling_convention, self, closed_jaxpr) - converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr + converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg) - converted_jaxpr_branches.append(converted_jaxpr_branch) + all_consts += converted_jaxpr_branch.consts + converted_jaxpr_branches.append(converted_jaxpr_branch.jaxpr) predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]] @@ -151,14 +151,16 @@ def workflow_for_loop( converter = copy(self) evaluator = partial(converter.eval, jaxpr_body_fn, consts) - converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args).jaxpr - converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ()) + converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args) + converted_closed_jaxpr_branch = ClosedJaxpr( + convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), () + ) # Config additional for loop settings apply_reverse_transform = isinstance(step, int) and step < 0 return for_p.bind( - *consts, + *converted_jaxpr_branch.consts, start, stop, step, @@ -202,12 +204,15 @@ def handle_for_loop( jaxpr = ClosedJaxpr(jaxpr_body_fn, consts) f = partial(_calling_convention, self, jaxpr) - converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg).jaxpr + converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg) - converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ()) + converted_closed_jaxpr_branch = ClosedJaxpr( + convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), () + ) # Build Catalyst compatible input values - for_loop_invals = [*consts, start, stop, step, *start_plus_args_plus_qreg] + new_consts = converted_jaxpr_branch.consts + for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg] # Config additional for loop settings apply_reverse_transform = isinstance(step, int) and step < 0 @@ -258,14 +263,14 @@ def workflow_while_loop( convert_constvars_jaxpr(new_cond_jaxpr.jaxpr), () ) # Build Catalyst compatible input values - while_loop_invals = [*consts_cond, *consts_body, *args] + while_loop_invals = [*new_cond_jaxpr.consts, *new_body_jaxpr.consts, *args] return while_p.bind( *while_loop_invals, cond_jaxpr=converted_cond_closed_jaxpr_branch, body_jaxpr=converted_body_closed_jaxpr_branch, - cond_nconsts=len(consts_cond), - body_nconsts=len(consts_body), + cond_nconsts=len(new_cond_jaxpr.consts), + body_nconsts=len(new_body_jaxpr.consts), nimplicit=0, preserve_dimensions=True, ) diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index 7d1224b9e..f82a1791a 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -1656,6 +1656,58 @@ def g(i, y): expected = 1.0 + jnp.cos(0) + jnp.cos(1) + jnp.cos(2) assert qml.math.allclose(res, expected) + # pylint: disable=unused-argument + def test_for_loop_consts(self): + """This tests for kinda a weird edge case bug where the consts where getting + reordered when translating the inner jaxpr.""" + + qml.capture.enable() + + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(x, n): + @qml.for_loop(3) + def outer(i): + + @qml.for_loop(n) + def inner(j): + qml.RY(x, wires=j) + + inner() + + outer() + + # Expected output: |100...> + return [qml.expval(qml.PauliZ(i)) for i in range(3)] + + res1, res2, res3 = circuit(0.2, 2) + + assert qml.math.allclose(res1, jnp.cos(0.2 * 3)) + assert qml.math.allclose(res2, jnp.cos(0.2 * 3)) + assert qml.math.allclose(res3, 1) + + # pylint: disable=unused-argument + def test_for_loop_consts_outside_qnode(self): + """Similar test as above for weird edge case, but not using a qnode.""" + + qml.capture.enable() + + @qml.qjit + def f(x, n): + @qml.for_loop(3) + def outer(i, a): + + @qml.for_loop(n) + def inner(j, b): + return b + x + + return inner(a) + + return outer(0.0) + + res = f(0.2, 2) + assert qml.math.allclose(res, 0.2 * 2 * 3) + def test_adjoint_transform_integration(): """Test that adjoint transforms can be used with capture enabled."""