Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

<h3>Bug fixes 🐛</h3>

* Fixes the translation of plxpr control flow for edge cases where the `consts` were being
reordered.

<h3>Internal changes ⚙️</h3>

<h3>Documentation 📝</h3>
Expand Down
27 changes: 14 additions & 13 deletions frontend/catalyst/from_plxpr/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@

# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
all_consts = all_consts + [*branch_consts]

evaluator = partial(copy(self).eval, plxpr_branch, branch_consts)
new_jaxpr = jax.make_jaxpr(evaluator)(*args)
all_consts = all_consts + new_jaxpr.consts

converted_jaxpr_branches.append(new_jaxpr.jaxpr)

Expand Down Expand Up @@ -99,15 +99,15 @@

# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
all_consts = all_consts + [*branch_consts]

converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)

f = partial(_calling_convention, self, closed_jaxpr)
converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr
converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg)

converted_jaxpr_branches.append(converted_jaxpr_branch)
all_consts += converted_jaxpr_branch.consts
converted_jaxpr_branches.append(converted_jaxpr_branch.jaxpr)

predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]]

Expand Down Expand Up @@ -151,14 +151,14 @@
converter = copy(self)
evaluator = partial(converter.eval, jaxpr_body_fn, consts)

converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args).jaxpr
converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
converted_jaxpr_branch = jax.make_jaxpr(evaluator)(start, *args)
converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ())

Check notice on line 155 in frontend/catalyst/from_plxpr/control_flow.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/control_flow.py#L155

Line too long (106/100) (line-too-long)

# Config additional for loop settings
apply_reverse_transform = isinstance(step, int) and step < 0

return for_p.bind(
*consts,
*converted_jaxpr_branch.consts,
start,
stop,
step,
Expand Down Expand Up @@ -202,12 +202,13 @@
jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)

f = partial(_calling_convention, self, jaxpr)
converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg).jaxpr
converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg)

converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ())

Check notice on line 207 in frontend/catalyst/from_plxpr/control_flow.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/control_flow.py#L207

Line too long (106/100) (line-too-long)

# Build Catalyst compatible input values
for_loop_invals = [*consts, start, stop, step, *start_plus_args_plus_qreg]
new_consts = converted_jaxpr_branch.consts
for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg]

# Config additional for loop settings
apply_reverse_transform = isinstance(step, int) and step < 0
Expand Down Expand Up @@ -258,14 +259,14 @@
convert_constvars_jaxpr(new_cond_jaxpr.jaxpr), ()
)
# Build Catalyst compatible input values
while_loop_invals = [*consts_cond, *consts_body, *args]
while_loop_invals = [*new_cond_jaxpr.consts, *new_body_jaxpr.consts, *args]

return while_p.bind(
*while_loop_invals,
cond_jaxpr=converted_cond_closed_jaxpr_branch,
body_jaxpr=converted_body_closed_jaxpr_branch,
cond_nconsts=len(consts_cond),
body_nconsts=len(consts_body),
cond_nconsts=len(new_cond_jaxpr.consts),
body_nconsts=len(new_body_jaxpr.consts),
nimplicit=0,
preserve_dimensions=True,
)
Expand Down
46 changes: 46 additions & 0 deletions frontend/test/pytest/from_plxpr/test_capture_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,52 @@
expected = 1.0 + jnp.cos(0) + jnp.cos(1) + jnp.cos(2)
assert qml.math.allclose(res, expected)

def test_for_loop_consts(self):
"""This tests for kinda a weird edge case bug where the consts where getting
reordered when translating the inner jaxpr."""

qml.capture.enable()

@qml.qjit
@qml.qnode(qml.device('lightning.qubit', wires=3))
def circuit(x, n):
@qml.for_loop(3)
def outer(i):

Check notice on line 1669 in frontend/test/pytest/from_plxpr/test_capture_integration.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/from_plxpr/test_capture_integration.py#L1669

Unused argument 'i' (unused-argument)

@qml.for_loop(n)
def inner(j):
qml.RY(x, wires=j)

inner()

outer()

# Expected output: |100...>
return [qml.expval(qml.PauliZ(i)) for i in range(3)]
res1, res2, res3 = circuit(0.2, 2)

Check notice on line 1682 in frontend/test/pytest/from_plxpr/test_capture_integration.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/from_plxpr/test_capture_integration.py#L1682

Trailing whitespace (trailing-whitespace)
assert qml.math.allclose(res1, jnp.cos(0.2 * 3))
assert qml.math.allclose(res2, jnp.cos(0.2 * 3))
assert qml.math.allclose(res3, 1)

def test_for_loop_consts_outside_qnode(self):
"""Similar test as above for weird edge case, but not using a qnode."""

qml.capture.enable()

@qml.qjit
def f(x, n):
@qml.for_loop(3)
def outer(i, a):

Check notice on line 1695 in frontend/test/pytest/from_plxpr/test_capture_integration.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/from_plxpr/test_capture_integration.py#L1695

Unused argument 'i' (unused-argument)

@qml.for_loop(n)
def inner(j, b):

Check notice on line 1698 in frontend/test/pytest/from_plxpr/test_capture_integration.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/from_plxpr/test_capture_integration.py#L1698

Unused argument 'j' (unused-argument)
return b + x
return inner(a)

return outer(0.0)
res = f(0.2, 2)
assert qml.math.allclose(res, 0.2 * 2 * 3)

def test_adjoint_transform_integration():
"""Test that adjoint transforms can be used with capture enabled."""
Expand Down
Loading