Skip to content

Commit 101c7a7

Browse files
authored
Support pytree inputs when program capture enabled (#2165)
**Context:** Change broken off from #2078, since it's basically it's own feature. **Description of the Change:** Just determine the inputs to `from_plxpr` from the plxpr itself. This means they will always match what's expected. **Benefits:** Always accurate with arbitrary nested arguments. **Possible Drawbacks:** Won't work with dynamically shaped inputs, but we don't support that yet anyway. **Related GitHub Issues:** [sc-102817]
1 parent 385db1b commit 101c7a7

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
:func:`~.qjit` with program capture enabled.
1313
[(#2154)](https://github.com/PennyLaneAI/catalyst/pull/2154)
1414

15+
* Pytree inputs can now be used when program capture is enabled.
16+
[(#2165)](https://github.com/PennyLaneAI/catalyst/pull/2165)
17+
1518
<h3>Breaking changes 💔</h3>
1619

1720
<h3>Deprecations 👋</h3>

frontend/catalyst/from_plxpr/from_plxpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def trace_from_pennylane(
399399
fn.static_argnums = static_argnums
400400

401401
plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)
402-
jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs)
402+
jaxpr = from_plxpr(plxpr)(*plxpr.in_avals)
403403

404404
return jaxpr, out_type, out_treedef, sig
405405

frontend/test/pytest/test_pytree_args.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@
2222
import pennylane as qml
2323
import pytest
2424
from jax._src.tree_util import tree_flatten
25+
from pennylane import adjoint, cond, for_loop, grad, qjit
2526

26-
from catalyst import adjoint, cond, for_loop, grad, measure, qjit
27+
from catalyst import measure
2728

2829

2930
class TestPyTreesReturnValues:
3031
"""Test QJIT workflows with different return value data-types."""
3132

33+
@pytest.mark.usefixtures("use_both_frontend")
3234
def test_return_value_float(self, backend):
3335
"""Test constant."""
3436

@@ -41,10 +43,13 @@ def circuit1(params):
4143
jitted_fn = qjit(circuit1)
4244

4345
params = [0.4, 0.8]
44-
expected = 0.64170937
46+
expected = jnp.cos(params[0]) * jnp.cos(params[1])
4547
result = jitted_fn(params)
4648
assert jnp.allclose(result, expected)
4749

50+
def test_return_value_mcm(self, backend):
51+
"""Test that a qnode can return a scalar mcm."""
52+
4853
@qml.qnode(qml.device(backend, wires=2))
4954
def circuit2():
5055
return measure(0)
@@ -166,6 +171,7 @@ def _f(x):
166171
assert result[0] == 4.0
167172
assert result[1] == 6.0
168173

174+
@pytest.mark.usefixtures("use_both_frontend")
169175
def test_return_value_hybrid(self, backend):
170176
"""Test tuples."""
171177

@@ -239,6 +245,9 @@ def else_fn():
239245
def test_return_value_dict(self, backend, tol_stochastic, mcm_method):
240246
"""Test dictionaries."""
241247

248+
if mcm_method == "one-shot" and qml.capture.enabled():
249+
pytest.xfail()
250+
242251
@qml.qnode(qml.device(backend, wires=2))
243252
def circuit1(params):
244253
qml.RX(params[0], wires=0)
@@ -438,6 +447,7 @@ def circuit1(params):
438447
}
439448
result = jitted_fn(params)
440449

450+
@pytest.mark.usefixtures("use_both_frontend")
441451
def test_args_workflow(self, backend):
442452
"""Test arguments with workflows."""
443453

@@ -527,6 +537,7 @@ def workflow2(params):
527537
assert np.allclose(result_flatten, result_flatten_expected)
528538
assert tree == tree_expected
529539

540+
@pytest.mark.usefixtures("use_both_frontend")
530541
@pytest.mark.parametrize("inp", [(np.array([0.2, 0.5])), (jnp.array([0.2, 0.5]))])
531542
def test_args_control_flow(self, backend, inp):
532543
"""Test arguments with control-flows operations."""
@@ -574,6 +585,7 @@ def circuit(dictionary):
574585
result = circuit({"wire": 1})
575586
assert jnp.allclose(result, True)
576587

588+
@pytest.mark.usefixtures("use_both_frontend")
577589
def test_dev_wires_have_pytree(self, backend):
578590
"""Device wires are pytree-compatible."""
579591

@@ -592,6 +604,7 @@ def test_function():
592604
test_function()
593605

594606

607+
@pytest.mark.usefixtures("use_both_frontend")
595608
class TestAuxiliaryData:
596609
"""Test PyTrees with Auxiliary data."""
597610

0 commit comments

Comments
 (0)