Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

<h3>Bug fixes 🐛</h3>

* Fixes the translation of plxpr control flow for edge cases where the `consts` were being
reordered.
[(#2128)](https://github.com/PennyLaneAI/catalyst/pull/2128)

<h3>Internal changes ⚙️</h3>

<h3>Documentation 📝</h3>
Expand Down
31 changes: 18 additions & 13 deletions frontend/catalyst/from_plxpr/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice

# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
all_consts = all_consts + [*branch_consts]

evaluator = partial(copy(self).eval, plxpr_branch, branch_consts)
new_jaxpr = jax.make_jaxpr(evaluator)(*args)
all_consts = all_consts + new_jaxpr.consts

converted_jaxpr_branches.append(new_jaxpr.jaxpr)

Expand Down Expand Up @@ -99,15 +99,15 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):

# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
all_consts = all_consts + [*branch_consts]

converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)

f = partial(_calling_convention, self, closed_jaxpr)
converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr
converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg)

converted_jaxpr_branches.append(converted_jaxpr_branch)
all_consts += converted_jaxpr_branch.consts
converted_jaxpr_branches.append(converted_jaxpr_branch.jaxpr)

predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]]

Expand Down Expand Up @@ -151,14 +151,16 @@ def workflow_for_loop(
converter = copy(self)
evaluator = partial(converter.eval, jaxpr_body_fn, consts)

converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args).jaxpr
converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args)
converted_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
)

# Config additional for loop settings
apply_reverse_transform = isinstance(step, int) and step < 0

return for_p.bind(
*consts,
*converted_jaxpr_branch.consts,
start,
stop,
step,
Expand Down Expand Up @@ -202,12 +204,15 @@ def handle_for_loop(
jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)

f = partial(_calling_convention, self, jaxpr)
converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg).jaxpr
converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg)

converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
converted_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
)

# Build Catalyst compatible input values
for_loop_invals = [*consts, start, stop, step, *start_plus_args_plus_qreg]
new_consts = converted_jaxpr_branch.consts
for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg]

# Config additional for loop settings
apply_reverse_transform = isinstance(step, int) and step < 0
Expand Down Expand Up @@ -258,14 +263,14 @@ def workflow_while_loop(
convert_constvars_jaxpr(new_cond_jaxpr.jaxpr), ()
)
# Build Catalyst compatible input values
while_loop_invals = [*consts_cond, *consts_body, *args]
while_loop_invals = [*new_cond_jaxpr.consts, *new_body_jaxpr.consts, *args]

return while_p.bind(
*while_loop_invals,
cond_jaxpr=converted_cond_closed_jaxpr_branch,
body_jaxpr=converted_body_closed_jaxpr_branch,
cond_nconsts=len(consts_cond),
body_nconsts=len(consts_body),
cond_nconsts=len(new_cond_jaxpr.consts),
body_nconsts=len(new_body_jaxpr.consts),
nimplicit=0,
preserve_dimensions=True,
)
Expand Down
52 changes: 52 additions & 0 deletions frontend/test/pytest/from_plxpr/test_capture_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,58 @@ def g(i, y):
expected = 1.0 + jnp.cos(0) + jnp.cos(1) + jnp.cos(2)
assert qml.math.allclose(res, expected)

# pylint: disable=unused-argument
def test_for_loop_consts(self):
"""This tests for kinda a weird edge case bug where the consts where getting
reordered when translating the inner jaxpr."""

qml.capture.enable()

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, n):
@qml.for_loop(3)
def outer(i):

@qml.for_loop(n)
def inner(j):
qml.RY(x, wires=j)

inner()

outer()

# Expected output: |100...>
return [qml.expval(qml.PauliZ(i)) for i in range(3)]

res1, res2, res3 = circuit(0.2, 2)

assert qml.math.allclose(res1, jnp.cos(0.2 * 3))
assert qml.math.allclose(res2, jnp.cos(0.2 * 3))
assert qml.math.allclose(res3, 1)

# pylint: disable=unused-argument
def test_for_loop_consts_outside_qnode(self):
"""Similar test as above for weird edge case, but not using a qnode."""

qml.capture.enable()

@qml.qjit
def f(x, n):
@qml.for_loop(3)
def outer(i, a):

@qml.for_loop(n)
def inner(j, b):
return b + x

return inner(a)

return outer(0.0)

res = f(0.2, 2)
assert qml.math.allclose(res, 0.2 * 2 * 3)


def test_adjoint_transform_integration():
"""Test that adjoint transforms can be used with capture enabled."""
Expand Down