From 17c39cfeba36763887be3b644fc514c41eb9cb40 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 21 Nov 2025 17:49:38 -0500 Subject: [PATCH] switch to using experimental_capture qjit kwarg --- frontend/catalyst/from_plxpr/from_plxpr.py | 13 +- frontend/catalyst/jit.py | 5 +- frontend/catalyst/pipelines.py | 4 + frontend/test/pytest.ini | 3 + frontend/test/pytest/conftest.py | 6 +- .../from_plxpr/test_capture_integration.py | 394 ++++-------------- .../from_plxpr/test_decompose_transform.py | 36 +- .../test/pytest/from_plxpr/test_from_plxpr.py | 2 +- .../test_from_plxpr_qubit_handler.py | 2 + frontend/test/pytest/test_adjoint.py | 98 ++--- frontend/test/pytest/test_autograph.py | 25 +- .../pytest/test_dynamic_qubit_allocation.py | 87 ++-- frontend/test/pytest/test_jit_behaviour.py | 5 +- frontend/test/pytest/test_switch.py | 40 +- 14 files changed, 239 insertions(+), 481 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index d819d55a5e..4140b46769 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -25,7 +25,7 @@ import pennylane as qml from jax.extend.core import ClosedJaxpr, Jaxpr from jax.extend.linear_util import wrap_init -from pennylane.capture import PlxprInterpreter, qnode_prim +from pennylane.capture import PlxprInterpreter, qnode_prim, enable, disable, enabled from pennylane.capture.expand_transforms import ExpandTransformsInterpreter from pennylane.capture.primitives import jacobian_prim as pl_jac_prim from pennylane.capture.primitives import transform_prim @@ -414,9 +414,14 @@ def trace_from_pennylane( # https://github.com/jax-ml/jax/blob/636691bba40b936b8b64a4792c1d2158296e9dd4/jax/_src/linear_util.py#L231 # Therefore we need to coordinate them manually fn.static_argnums = static_argnums - - plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs) - jaxpr = from_plxpr(plxpr)(*plxpr.in_avals) + capture_on = enabled() + enable() + try: + plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs) + jaxpr = from_plxpr(plxpr)(*plxpr.in_avals) + finally: + if not capture_on: + disable() return jaxpr, out_type, out_treedef, sig diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 81f873917c..33c80e468e 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -77,6 +77,7 @@ def qjit( *, autograph=False, autograph_include=(), + experimental_capture=False, async_qnodes=False, target="binary", keep_intermediate=False, @@ -698,7 +699,7 @@ def jit_compile(self, args, **kwargs): def pre_compilation(self): """Perform pre-processing tasks on the Python function, such as AST transformations.""" if self.compile_options.autograph: - if qml.capture.enabled(): + if self.compile_options.experimental_capture: if self.compile_options.autograph_include: raise NotImplementedError( "capture autograph does not yet support autograph_include." @@ -731,7 +732,7 @@ def capture(self, args, **kwargs): dbg = debug_info("qjit_capture", self.user_function, args, kwargs) - if qml.capture.enabled(): + if self.compile_options.experimental_capture: with Patcher( ( jax._src.interpreters.partial_eval, # pylint: disable=protected-access diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index e5bbdc7287..4eae18d9a1 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -112,8 +112,12 @@ class CompileOptions: Default is ``None``. pass_plugins (Optional[Iterable[Path]]): List of paths to pass plugins. dialect_plugins (Optional[Iterable[Path]]): List of paths to dialect plugins. + experimental_capture (bool): If set to ``True``, + use PennyLane's experimental program capture capabilities + to capture the function for compilation. """ + experimental_capture : bool = False verbose: Optional[bool] = False logfile: Optional[TextIOWrapper] = sys.stderr target: Optional[str] = "binary" diff --git a/frontend/test/pytest.ini b/frontend/test/pytest.ini index 55eed915cf..73b2c427a3 100644 --- a/frontend/test/pytest.ini +++ b/frontend/test/pytest.ini @@ -1,2 +1,5 @@ [pytest] +markers = + capture_only: marks tests for capture only + old_frontend: tests for capture off xfail_strict=true \ No newline at end of file diff --git a/frontend/test/pytest/conftest.py b/frontend/test/pytest/conftest.py index 90ee23cc22..ee0c0d3940 100644 --- a/frontend/test/pytest/conftest.py +++ b/frontend/test/pytest/conftest.py @@ -62,15 +62,17 @@ def use_capture(): @pytest.fixture(scope="function") def use_capture_dgraph(): """Enable capture and graph-decomposition before and disable them both after the test.""" - qml.capture.enable() qml.decomposition.enable_graph() try: yield finally: qml.decomposition.disable_graph() - qml.capture.disable() +@pytest.fixture(params=[True, False], scope="function") +def experimental_capture(request): + yield request.param + @pytest.fixture(params=["capture", "no_capture"], scope="function") def use_both_frontend(request): """Runs the test once with capture enabled and once with it disabled.""" diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index 61ae0e5417..9492ac24e2 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -25,15 +25,13 @@ from catalyst import qjit from catalyst.from_plxpr import register_transform -pytestmark = pytest.mark.usefixtures("disable_capture") +pytestmark = pytest.mark.capture_only def circuit_aot_builder(dev): """Test AOT builder.""" - qml.capture.enable() - - @catalyst.qjit + @catalyst.qjit(experimental_capture=True) @qml.qnode(device=dev) def catalyst_circuit_aot(x: float): qml.Hadamard(wires=0) @@ -43,7 +41,6 @@ def catalyst_circuit_aot(x: float): qml.Hadamard(wires=1) return qml.expval(qml.PauliY(wires=0)) - qml.capture.disable() return catalyst_circuit_aot @@ -135,9 +132,7 @@ def test_simple_circuit(self, backend, theta, capture): dev = qml.device(backend, wires=2) - qml.capture.enable() - - @catalyst.qjit + @catalyst.qjit(experimental_capture=capture) @qml.qnode(device=dev) def captured_circuit(x): qml.Hadamard(wires=0) @@ -149,11 +144,6 @@ def captured_circuit(x): capture_result = captured_circuit(theta) - qml.capture.disable() - - if capture: - qml.capture.enable() - @qml.qnode(device=dev) def circuit(x): qml.Hadamard(wires=0) @@ -163,13 +153,6 @@ def circuit(x): qml.Hadamard(wires=1) return qml.expval(qml.PauliY(wires=0)) - if capture: - assert qml.capture.enabled() - else: - assert not qml.capture.enabled() - - qml.capture.disable() - assert jnp.allclose(capture_result, circuit(theta)) @pytest.mark.parametrize("theta", (jnp.pi, 0.1, 0.0)) @@ -178,10 +161,7 @@ def test_simple_workflow(self, backend, theta): dev = qml.device(backend, wires=2) # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(device=dev) def captured_circuit(x): qml.Hadamard(wires=0) @@ -193,11 +173,7 @@ def captured_circuit(x): capture_result = captured_circuit(theta**2) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(device=dev) def circuit(x): qml.Hadamard(wires=0) @@ -224,11 +200,8 @@ def test_basis_state(self, backend, n_wires, basis_state): """Test the integration for a circuit with BasisState.""" dev = qml.device(backend, wires=n_wires) - # Capture enabled - - qml.capture.enable() - @qjit + @qjit(experimental_capture=True) @qml.qnode(dev) def captured_circuit(_basis_state): qml.BasisState(_basis_state, wires=list(range(n_wires))) @@ -236,11 +209,7 @@ def captured_circuit(_basis_state): capture_result = captured_circuit(basis_state) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(dev) def circuit(_basis_state): qml.BasisState(_basis_state, wires=list(range(n_wires))) @@ -264,11 +233,7 @@ def test_state_prep(self, backend, n_wires, init_state): """Test the integration for a circuit with StatePrep.""" dev = qml.device(backend, wires=n_wires) - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(dev) def captured_circuit(init_state): qml.StatePrep(init_state, wires=list(range(n_wires))) @@ -276,11 +241,7 @@ def captured_circuit(init_state): capture_result = captured_circuit(init_state) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(dev) def circuit(init_state): qml.StatePrep(init_state, wires=list(range(n_wires))) @@ -293,11 +254,7 @@ def test_adjoint(self, backend, theta, val): """Test the integration for a circuit with adjoint.""" device = qml.device(backend, wires=2) - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(device) def captured_circuit(theta, val): qml.adjoint(qml.RY)(jnp.pi, val) @@ -306,8 +263,6 @@ def captured_circuit(theta, val): capture_result = captured_circuit(theta, val) - qml.capture.disable() - # Capture disabled @qml.qnode(device) @@ -323,11 +278,7 @@ def test_ctrl(self, backend, theta): """Test the integration for a circuit with control.""" device = qml.device(backend, wires=3) - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(device) def captured_circuit(theta): qml.ctrl(qml.RX(theta, wires=0), control=[1], control_values=[False]) @@ -336,10 +287,6 @@ def captured_circuit(theta): capture_result = captured_circuit(theta) - qml.capture.disable() - - # Capture disabled - @qml.qnode(device) def circuit(theta): qml.ctrl(qml.RX(theta, wires=0), control=[1], control_values=[False]) @@ -362,9 +309,8 @@ def test_measure(self, backend, reset, op, expected): """ device = qml.device(backend, wires=1) - qml.capture.enable() - @qjit + @qjit(experimental_capture=True) @qml.qnode(device) def captured_circuit(): op(wires=0) @@ -373,8 +319,6 @@ def captured_circuit(): capture_result = captured_circuit() - qml.capture.disable() - assert jnp.allclose(capture_result, expected) def test_measure_postselect(self, backend): @@ -384,9 +328,7 @@ def test_measure_postselect(self, backend): """ device = qml.device(backend, wires=1) - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(device) def captured_circuit(): qml.H(wires=0) @@ -395,7 +337,6 @@ def captured_circuit(): capture_result = captured_circuit() - qml.capture.disable() expected_result = -1 @@ -405,11 +346,7 @@ def captured_circuit(): def test_forloop(self, backend, theta): """Test the integration for a circuit with a for loop.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=4)) def captured_circuit(x): @@ -424,8 +361,6 @@ def loop(i): capture_result = captured_circuit(theta) - qml.capture.disable() - # Capture disabled @qml.qnode(qml.device(backend, wires=4)) @@ -442,11 +377,7 @@ def circuit(x): def test_forloop_workflow(self, backend): """Test the integration for a circuit with a for loop primitive.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(n, x): @@ -462,10 +393,6 @@ def loop_rx(_, x): capture_result = captured_circuit(10, 0.3) - qml.capture.disable() - - # Capture disabled - @qjit @qml.qnode(qml.device(backend, wires=1)) def circuit(n, x): @@ -485,11 +412,8 @@ def loop_rx(_, x): def test_nested_loops(self, backend): """Test the integration for a circuit with a nested for loop primitive.""" - # Capture enabled - - qml.capture.enable() - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=4)) def captured_circuit(n): # Input state: equal superposition @@ -516,11 +440,7 @@ def inner(j): capture_result = captured_circuit(4) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=4)) def circuit(n): # Input state: equal superposition @@ -553,11 +473,8 @@ def inner(j): def test_while_loop_workflow(self, backend): """Test the integration for a circuit with a while_loop primitive.""" - # Capture enabled - - qml.capture.enable() - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def capturted_circuit(x: float): @@ -579,11 +496,7 @@ def loop_rx(a): capture_result_1_iteration = capturted_circuit(9) capture_result_0_iterations = capturted_circuit(11) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -608,12 +521,7 @@ def loop_rx(a): def test_while_loop_workflow_closure(self, backend): """Test the integration for a circuit with a while_loop primitive using a closure variable.""" - - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float, step: float): @@ -633,11 +541,7 @@ def loop_rx(a): capture_result = captured_circuit(0, 2) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float, step: float): @@ -660,11 +564,7 @@ def loop_rx(a): def test_while_loop_workflow_nested(self, backend): """Test the integration for a circuit with a nested while_loop primitive.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float, y: float): @@ -691,11 +591,7 @@ def inner_loop(b): capture_result = captured_circuit(0, 0) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float, y: float): @@ -725,11 +621,7 @@ def inner_loop(b): def test_cond_workflow_if_else(self, backend): """Test the integration for a circuit with a cond primitive with true and false branches.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -746,11 +638,7 @@ def ansatz_false(): capture_result = captured_circuit(0.1) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -770,11 +658,7 @@ def ansatz_false(): def test_cond_workflow_if(self, backend): """Test the integration for a circuit with a cond primitive with a true branch only.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -788,11 +672,7 @@ def ansatz_true(): capture_result = captured_circuit(1.5) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -810,11 +690,7 @@ def test_cond_workflow_with_custom_primitive(self, backend): """Test the integration for a circuit with a cond primitive containing a custom primitive.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -833,11 +709,7 @@ def ansatz_false(): capture_result = captured_circuit(0.1) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -860,11 +732,7 @@ def test_cond_workflow_with_abstract_measurement(self, backend): """Test the integration for a circuit with a cond primitive containing an abstract measurement.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -883,11 +751,8 @@ def ansatz_false(): capture_result = captured_circuit(0.1) - qml.capture.disable() - - # Capture disabled - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -910,11 +775,7 @@ def test_cond_workflow_with_simple_primitive(self, backend): """Test the integration for a circuit with a cond primitive containing an simple primitive.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -933,11 +794,7 @@ def ansatz_false(): capture_result = captured_circuit(0.1) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -959,11 +816,7 @@ def ansatz_false(): def test_cond_workflow_nested(self, backend): """Test the integration for a circuit with a nested cond primitive.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float, y: float): @@ -987,11 +840,8 @@ def branch_false(): capture_result = captured_circuit(0.1, 1.5) - qml.capture.disable() - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float, y: float): @@ -1019,11 +869,7 @@ def test_cond_workflow_operator(self, backend): """Test the integration for a circuit with a cond primitive returning an Operator.""" - # Capture enabled - - qml.capture.enable() - - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -1033,11 +879,7 @@ def captured_circuit(x: float): capture_result = captured_circuit(0.1) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -1047,7 +889,6 @@ def circuit(x: float): assert jnp.allclose(circuit(0.1), capture_result) - @pytest.mark.usefixtures("use_capture") def test_pass_with_options(self, backend): """Test the integration for a circuit with a pass that takes in options.""" @@ -1058,7 +899,7 @@ def my_pass(_tape, my_option=None, my_other_option=None): # pylint: disable=unu register_transform(my_pass, "my-pass", False) - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @partial(my_pass, my_option="my_option_value", my_other_option=False) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(): @@ -1074,11 +915,7 @@ def captured_circuit(): def test_transform_cancel_inverses_workflow(self, backend): """Test the integration for a circuit with a 'cancel_inverses' transform.""" - # Capture enabled - - qml.capture.enable() - - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @qml.transforms.cancel_inverses @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -1090,11 +927,7 @@ def captured_circuit(x: float): capture_result = captured_circuit(0.1) assert 'transform.apply_registered_pass "cancel-inverses"' in captured_circuit.mlir - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.transforms.cancel_inverses @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -1108,11 +941,7 @@ def circuit(x: float): def test_transform_merge_rotations_workflow(self, backend): """Test the integration for a circuit with a 'merge_rotations' transform.""" - # Capture enabled - - qml.capture.enable() - - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @qml.transforms.merge_rotations @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -1124,11 +953,7 @@ def captured_circuit(x: float): capture_result = captured_circuit(0.1) assert 'transform.apply_registered_pass "merge-rotations"' in captured_circuit.mlir - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.transforms.merge_rotations @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -1143,9 +968,6 @@ def test_chained_catalyst_transforms_workflow(self, backend): """Test the integration for a circuit with a combination of 'merge_rotations' and 'cancel_inverses' transforms.""" - # Capture enabled - - qml.capture.enable() @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(x: float): @@ -1158,6 +980,7 @@ def captured_circuit(x: float): captured_inverses_rotations = qjit( qml.transforms.cancel_inverses(qml.transforms.merge_rotations(captured_circuit)), target="mlir", + experimental_capture=True ) captured_inverses_rotations_result = captured_inverses_rotations(0.1) assert has_catalyst_transforms(captured_inverses_rotations.mlir) @@ -1165,13 +988,11 @@ def captured_circuit(x: float): captured_rotations_inverses = qjit( qml.transforms.merge_rotations(qml.transforms.cancel_inverses(captured_circuit)), target="mlir", + experimental_capture=True ) captured_rotations_inverses_result = captured_rotations_inverses(0.1) assert has_catalyst_transforms(captured_rotations_inverses.mlir) - qml.capture.disable() - - # Capture disabled @qml.qnode(qml.device(backend, wires=1)) def circuit(x: float): @@ -1183,9 +1004,9 @@ def circuit(x: float): inverses_rotations_result = qjit( qml.transforms.cancel_inverses(qml.transforms.merge_rotations(circuit)) - )(0.1) + , experimental_capture=False)(0.1) rotations_inverses_result = qjit( - qml.transforms.merge_rotations(qml.transforms.cancel_inverses(circuit)) + qml.transforms.merge_rotations(qml.transforms.cancel_inverses(circuit)), experimental_capture=False )(0.1) assert ( @@ -1200,11 +1021,7 @@ def test_transform_unitary_to_rot_workflow(self, backend): U = qml.Rot(1.0, 2.0, 3.0, wires=0) - # Capture enabled - - qml.capture.enable() - - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @qml.transforms.unitary_to_rot @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(U: ShapedArray([2, 2], float)): @@ -1214,11 +1031,7 @@ def captured_circuit(U: ShapedArray([2, 2], float)): capture_result = captured_circuit(U.matrix()) assert is_unitary_rotated(captured_circuit.mlir) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.transforms.unitary_to_rot @qml.qnode(qml.device(backend, wires=1)) def circuit(U: ShapedArray([2, 2], float)): @@ -1233,10 +1046,6 @@ def test_mixed_transforms_workflow(self, backend): U = qml.Rot(1.0, 2.0, 3.0, wires=0) - # Capture enabled - - qml.capture.enable() - @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(U: ShapedArray([2, 2], float)): qml.QubitUnitary(U, 0) @@ -1250,6 +1059,7 @@ def captured_circuit(U: ShapedArray([2, 2], float)): captured_inverses_unitary = qjit( qml.transforms.cancel_inverses(qml.transforms.unitary_to_rot(captured_circuit)), target="mlir", + experimental_capture=True ) captured_inverses_unitary_result = captured_inverses_unitary(U.matrix()) @@ -1265,6 +1075,7 @@ def captured_circuit(U: ShapedArray([2, 2], float)): captured_unitary_inverses = qjit( qml.transforms.unitary_to_rot(qml.transforms.cancel_inverses(captured_circuit)), target="mlir", + experimental_capture=True, ) captured_unitary_inverses_result = captured_unitary_inverses(U.matrix()) @@ -1275,9 +1086,6 @@ def captured_circuit(U: ShapedArray([2, 2], float)): assert 'quantum.custom "Hadamard"' not in capture_mlir assert is_unitary_rotated(capture_mlir) - qml.capture.disable() - - # Capture disabled @qml.qnode(qml.device(backend, wires=1)) def circuit(U: ShapedArray([2, 2], float)): @@ -1287,10 +1095,12 @@ def circuit(U: ShapedArray([2, 2], float)): return qml.expval(qml.PauliZ(0)) inverses_unitary_result = qjit( - qml.transforms.cancel_inverses(qml.transforms.unitary_to_rot(captured_circuit)) + qml.transforms.cancel_inverses(qml.transforms.unitary_to_rot(captured_circuit)), + experimental_capture=False )(U.matrix()) unitary_inverses_result = qjit( - qml.transforms.unitary_to_rot(qml.transforms.cancel_inverses(captured_circuit)) + qml.transforms.unitary_to_rot(qml.transforms.cancel_inverses(captured_circuit)), + experimental_capture=False )(U.matrix()) assert ( @@ -1303,11 +1113,8 @@ def circuit(U: ShapedArray([2, 2], float)): def test_transform_decompose_workflow(self, backend): """Test the integration for a circuit with a 'decompose' transform.""" - # Capture enabled - - qml.capture.enable() - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) @qml.qnode(qml.device(backend, wires=2)) def captured_circuit(x: float, y: float, z: float): @@ -1317,11 +1124,7 @@ def captured_circuit(x: float, y: float, z: float): capture_result = captured_circuit(1.5, 2.5, 3.5) assert is_rot_decomposed(captured_circuit.mlir) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=True) @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) @qml.qnode(qml.device(backend, wires=2)) def circuit(x: float, y: float, z: float): @@ -1335,10 +1138,9 @@ def test_transform_graph_decompose_workflow(self, backend): # Capture enabled - qml.capture.enable() qml.decomposition.enable_graph() - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) @qml.qnode(qml.device(backend, wires=2)) def captured_circuit(x: float, y: float, z: float): @@ -1354,10 +1156,9 @@ def cond_fn(): capture_result = captured_circuit(1.5, 2.5, 3.5) qml.decomposition.disable_graph() - qml.capture.disable() # Capture disabled - @qjit + @qjit(experimental_capture=False) @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) @qml.qnode(qml.device(backend, wires=2)) def circuit(x: float, y: float, z: float): @@ -1375,11 +1176,7 @@ def cond_fn(): def test_transform_single_qubit_fusion_workflow(self, backend): """Test the integration for a circuit with a 'single_qubit_fusion' transform.""" - # Capture enabled - - qml.capture.enable() - - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @qml.transforms.single_qubit_fusion @qml.qnode(qml.device(backend, wires=1)) def captured_circuit(): @@ -1394,11 +1191,7 @@ def captured_circuit(): assert is_single_qubit_fusion_applied(captured_circuit.mlir) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.transforms.single_qubit_fusion @qml.qnode(qml.device(backend, wires=1)) def circuit(): @@ -1414,11 +1207,7 @@ def circuit(): def test_transform_commute_controlled_workflow(self, backend): """Test the integration for a circuit with a 'commute_controlled' transform.""" - # Capture enabled - - qml.capture.enable() - - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @partial(qml.transforms.commute_controlled, direction="left") @qml.qnode(qml.device(backend, wires=3)) def captured_circuit(): @@ -1440,11 +1229,8 @@ def captured_circuit(): capture_mlir, 'quantum.custom "PauliX"', 'quantum.custom "CRX"' ) - qml.capture.disable() - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @partial(qml.transforms.commute_controlled, direction="left") @qml.qnode(qml.device(backend, wires=3)) def circuit(): @@ -1461,11 +1247,7 @@ def circuit(): def test_transform_merge_amplitude_embedding_workflow(self, backend): """Test the integration for a circuit with a 'merge_amplitude_embedding' transform.""" - # Capture enabled - - qml.capture.enable() - - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @qml.transforms.merge_amplitude_embedding @qml.qnode(qml.device(backend, wires=2)) def captured_circuit(): @@ -1476,11 +1258,7 @@ def captured_circuit(): capture_result = captured_circuit() assert is_amplitude_embedding_merged_and_decomposed(captured_circuit.mlir) - qml.capture.disable() - - # Capture disabled - - @qjit + @qjit(experimental_capture=False) @qml.transforms.merge_amplitude_embedding @qml.qnode(qml.device(backend, wires=2)) def circuit(): @@ -1493,16 +1271,12 @@ def circuit(): def test_shots_usage(self, backend): """Test the integration for a circuit using shots explicitly.""" - # Capture enabled - - qml.capture.enable() - # TODO: try set_shots after capture work is completed with pytest.warns( qml.exceptions.PennyLaneDeprecationWarning, match="shots on device is deprecated" ): - @qjit(target="mlir") + @qjit(target="mlir", experimental_capture=True) @qml.qnode(qml.device(backend, wires=2, shots=10)) def captured_circuit(): @qml.for_loop(0, 2, 1) @@ -1517,9 +1291,7 @@ def loop_0(i): capture_result = captured_circuit() assert "shots(%" in captured_circuit.mlir - qml.capture.disable() - - @qjit + @qjit(experimental_capture=False) @qml.set_shots(10) @qml.qnode(qml.device(backend, wires=2)) def circuit(): @@ -1537,10 +1309,8 @@ def loop_0(i): def test_static_variable_qnode(self, backend): """Test the integration for a circuit with a static variable.""" - qml.capture.enable() - # Basic test - @qjit(static_argnums=(0,)) + @qjit(static_argnums=(0,), experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit_1(x, y): qml.RX(x, wires=0) @@ -1554,7 +1324,7 @@ def captured_circuit_1(x, y): assert "%cst = arith.constant 2.0" not in captured_circuit_1_mlir # Test that qjit static_argnums takes precedence over the one on the qnode - @qjit(static_argnums=1) + @qjit(static_argnums=1, experimental_capture=True) @qml.qnode(qml.device(backend, wires=1), static_argnums=0) # should be ignored def captured_circuit_2(x, y): qml.RX(x, wires=0) @@ -1570,7 +1340,7 @@ def captured_circuit_2(x, y): assert jnp.allclose(result_1, result_2) # Test under a non qnode workflow function - @qjit(static_argnums=(0,)) + @qjit(static_argnums=(0,), experimental_capture=True) def workflow(x, y): @qml.qnode(qml.device(backend, wires=1)) def c(): @@ -1585,8 +1355,6 @@ def c(): assert "%cst = arith.constant 1.5" in captured_circuit_3_mlir assert 'quantum.custom "RX"(%cst)' in captured_circuit_3_mlir - qml.capture.disable() - class TestControlFlow: """Integration tests for control flow.""" @@ -1595,7 +1363,6 @@ class TestControlFlow: def test_for_loop_outside_qnode(self, reverse): """Test that a for loop outside qnode can be executed.""" - qml.capture.enable() if reverse: start, stop, step = 6, 0, -2 # 6, 4, 2 @@ -1607,7 +1374,7 @@ def c(x): qml.RX(x, 0) return qml.expval(qml.Z(0)) - @qml.qjit + @qml.qjit(experimental_capture=True) def f(i0): @qml.for_loop(start, stop, step) def g(i, x): @@ -1620,14 +1387,13 @@ def g(i, x): def test_while_loop(self): """Test that a outside a qnode can be executed.""" - qml.capture.enable() @qml.qnode(qml.device("lightning.qubit", wires=1)) def circuit(x): qml.RX(x, 0) return qml.expval(qml.Z(0)) - @qml.qjit + @qml.qjit(experimental_capture=True) def f(x): const = jnp.array([0, 1, 2]) @@ -1648,9 +1414,7 @@ def test_for_loop_consts(self): """This tests for kinda a weird edge case bug where the consts where getting reordered when translating the inner jaxpr.""" - qml.capture.enable() - - @qml.qjit + @qml.qjit(experimental_capture=True) @qml.qnode(qml.device("lightning.qubit", wires=3)) def circuit(x, n): @qml.for_loop(3) @@ -1677,9 +1441,7 @@ def inner(j): def test_for_loop_consts_outside_qnode(self): """Similar test as above for weird edge case, but not using a qnode.""" - qml.capture.enable() - - @qml.qjit + @qml.qjit(experimental_capture=True) def f(x, n): @qml.for_loop(3) def outer(i, a): @@ -1699,13 +1461,11 @@ def inner(j, b): def test_adjoint_transform_integration(): """Test that adjoint transforms can be used with capture enabled.""" - qml.capture.enable() - def f(x): qml.IsingXX(2 * x, wires=(0, 1)) qml.H(0) - @qml.qjit + @qml.qjit(experimental_capture=True) @qml.qnode(qml.device("lightning.qubit", wires=3)) def c(x): qml.adjoint(f)(x) @@ -1727,7 +1487,7 @@ def f(x, y): qml.RY(3 * y, wires=3) qml.RX(2 * x, wires=3) - @qml.qjit + @qml.qjit(experimental_capture=True) @qml.qnode(qml.device("lightning.qubit", wires=4)) def c(x, y): qml.X(1) @@ -1747,8 +1507,6 @@ def c(x, y): def test_different_static_argnums(): """Test that the same qnode can be called different times with different static argnums.""" - qml.capture.enable() - @qml.qnode(qml.device("lightning.qubit", wires=1), static_argnums=1) def c(x, pauli): if pauli == "X": @@ -1759,7 +1517,7 @@ def c(x, pauli): qml.RZ(x, 0) return qml.state() - @qml.qjit + @qml.qjit(experimental_capture=True) def w(x): return c(x, "X"), c(x, "Y"), c(x, "Z") diff --git a/frontend/test/pytest/from_plxpr/test_decompose_transform.py b/frontend/test/pytest/from_plxpr/test_decompose_transform.py index 811eaa9b1b..aed5714c20 100644 --- a/frontend/test/pytest/from_plxpr/test_decompose_transform.py +++ b/frontend/test/pytest/from_plxpr/test_decompose_transform.py @@ -24,6 +24,8 @@ from pennylane.wires import WiresLike +pytestmark = pytest.mark.capture_only + class TestGraphDecomposition: """Test the new graph-based decomposition integration with from_plxpr.""" @@ -31,7 +33,7 @@ class TestGraphDecomposition: def test_with_multiple_decomps_transforms(self): """Test that a circuit with multiple decompositions and transforms can be converted.""" - @qml.qjit(target="mlir") + @qml.qjit(target="mlir", experimental_capture=True) @partial( qml.transforms.decompose, gate_set={"RX", "RY"}, @@ -54,7 +56,7 @@ def circuit(x): def test_fallback_warnings(self): """Test the fallback to legacy decomposition system with warnings.""" - @qml.qjit + @qml.qjit(experimental_capture=True) @partial(qml.transforms.decompose, gate_set={qml.GlobalPhase}) @qml.qnode(qml.device("lightning.qubit", wires=2)) def circuit(x): @@ -81,7 +83,7 @@ def circuit(): return qml.expval(qml.X(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -110,7 +112,7 @@ def circuit(): qml.CNOT(wires=[0, 1]) return qml.state() - qjited_circuit = qml.qjit(circuit) + qjited_circuit = qml.qjit(circuit, experimental_capture=True) expected = np.array([1, 0, 0, 1]) / np.sqrt(2) assert qml.math.allclose(qjited_circuit(), expected) @@ -165,7 +167,7 @@ def circuit(): return qml.expval(qml.Z(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -199,7 +201,7 @@ def circuit(x, y): y = 0.3 without_qjit = circuit(x, y) - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit(x, y)) expected_resources = qml.specs(circuit, level="device")(x, y)["resources"].gate_types @@ -221,7 +223,7 @@ def circuit(x, y, z): z = 0.2 without_qjit = circuit(x, y, z) - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit(x, y, z)) @@ -249,7 +251,7 @@ def circuit(): OSError, match="undefined symbol", # ___catalyst__qis__RotXZX ): - qml.qjit(circuit)() + qml.qjit(circuit, experimental_capture=True)() @pytest.mark.usefixtures("use_capture_dgraph") def test_ftqc_rotxzx(self): @@ -266,7 +268,7 @@ def circuit(): return qml.expval(qml.X(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -293,7 +295,7 @@ def circuit(): return qml.expval(qml.X(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -317,7 +319,7 @@ def circuit(): return qml.expval(qml.Z(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -342,7 +344,7 @@ def circuit(): return qml.expval(qml.Z(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) expected_resources = qml.specs(circuit, level="device")()["resources"].gate_types @@ -366,7 +368,7 @@ def circuit(): return qml.expval(qml.Z(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -390,7 +392,7 @@ def circuit(): return qml.expval(qml.Z(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -412,7 +414,7 @@ def circuit(): return qml.expval(qml.Z(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -438,7 +440,7 @@ def circuit(): return qml.expval(qml.PauliX(0)) without_qjit = circuit() - with_qjit = qml.qjit(circuit) + with_qjit = qml.qjit(circuit, experimental_capture=True) assert qml.math.allclose(without_qjit, with_qjit()) @@ -472,7 +474,7 @@ def circuit(): without_qjit = qml.transforms.decompose(circuit, gate_set={"RZ", "CNOT"}) with_qjit = qml.qjit( qml.transforms.decompose(circuit, gate_set={"RZ", "CNOT"}), autograph=True - ) + , experimental_capture=True) assert qml.math.allclose(without_qjit(), with_qjit()) diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py index 9090469f55..37af0faa94 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py @@ -34,7 +34,7 @@ while_p, ) -pytestmark = pytest.mark.usefixtures("disable_capture") +pytestmark = pytest.mark.capture_only def catalyst_execute_jaxpr(jaxpr): diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py index 4d57bdf15b..3f25b1f1cb 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py @@ -47,6 +47,8 @@ from catalyst.jax_primitives import AbstractQbit, AbstractQreg, qalloc_p, qextract_p from catalyst.utils.exceptions import CompileError +pytestmark = pytest.mark.capture_only + @pytest.fixture(autouse=True) def launch_empty_jaxpr_interpreter(): diff --git a/frontend/test/pytest/test_adjoint.py b/frontend/test/pytest/test_adjoint.py index dd1835c0d2..e573e925ff 100644 --- a/frontend/test/pytest/test_adjoint.py +++ b/frontend/test/pytest/test_adjoint.py @@ -26,7 +26,7 @@ from pennylane.ops.op_math.adjoint import Adjoint, AdjointOperation import catalyst -from catalyst import debug, measure, qjit +from catalyst import debug, measure # pylint: disable=too-many-lines,missing-class-docstring,missing-function-docstring,too-many-public-methods @@ -34,14 +34,14 @@ class TestCatalyst: """Integration tests for Catalyst adjoint functionality.""" - def verify_catalyst_adjoint_against_pennylane(self, quantum_func, device, *args): + def verify_catalyst_adjoint_against_pennylane(self, quantum_func, device, *args, experimental_capture=False): """ A helper function for verifying Catalyst's native adjoint against the behaviour of PennyLane's adjoint function. This is specialized to verifying the behaviour of a single function that has its adjoint computed. """ - @qjit + @qjit(experimental_capture=experimental_capture) @qml.qnode(device) def catalyst_workflow(*args): adjoint(quantum_func)(*args) @@ -52,18 +52,11 @@ def pennylane_workflow(*args): qml.adjoint(quantum_func)(*args) return qml.state() - capture_enabled = qml.capture.enabled() - qml.capture.disable() - try: - pl_res = pennylane_workflow(*args) - finally: - if capture_enabled: - qml.capture.enable() + pl_res = pennylane_workflow(*args) assert_allclose(catalyst_workflow(*args), pl_res) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_func(self, backend): + def test_adjoint_func(self, backend, experimental_capture): """Ensures that catalyst.adjoint accepts simple Python functions as argument. Makes sure that simple quantum gates are adjointed correctly.""" @@ -74,7 +67,7 @@ def func(): device = qml.device(backend, wires=2) - @qjit + @qjit(experimental_capture=experimental_capture) @qml.qnode(device) def C_workflow(): qml.PauliX(wires=0) @@ -93,13 +86,12 @@ def PL_workflow(): desired = PL_workflow() assert_allclose(actual, desired) - @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("theta, val", [(jnp.pi, 0), (-100.0, 1)]) - def test_adjoint_op(self, theta, val, backend): + def test_adjoint_op(self, theta, val, backend, experimental_capture): """Ensures that catalyst.adjoint accepts single PennyLane operators classes as argument.""" device = qml.device(backend, wires=2) - @qjit + @qjit(experimental_capture=experimental_capture) @qml.qnode(device) def C_workflow(theta, val): adjoint(qml.RY)(jnp.pi, val) @@ -116,14 +108,13 @@ def PL_workflow(theta, val): desired = PL_workflow(theta, val) assert_allclose(actual, desired) - @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("theta, val", [(np.pi, 0), (-100.0, 2)]) - def test_adjoint_bound_op(self, theta, val, backend): + def test_adjoint_bound_op(self, theta, val, backend, experimental_capture): """Ensures that catalyst.adjoint accepts single PennyLane operators objects as argument.""" device = qml.device(backend, wires=3) - @qjit + @qjit(experimental_capture=experimental_capture) @qml.qnode(device) def C_workflow(theta, val): adjoint(qml.RX(jnp.pi, val)) @@ -142,9 +133,8 @@ def PL_workflow(theta, val): desired = PL_workflow(theta, val) assert_allclose(actual, desired, atol=1e-6, rtol=1e-6) - @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("w, p", [(0, 0.5), (0, -100.0), (1, 123.22)]) - def test_adjoint_param_fun(self, w, p, backend): + def test_adjoint_param_fun(self, w, p, backend,experimental_capture): """Ensures that catalyst.adjoint accepts parameterized Python functions as arguments.""" def func(w, theta1, theta2, theta3=1): @@ -154,7 +144,7 @@ def func(w, theta1, theta2, theta3=1): device = qml.device(backend, wires=2) - @qjit + @qjit(experimental_capture=experimental_capture) @qml.qnode(device) def C_workflow(w, theta): qml.PauliX(wires=0) @@ -173,8 +163,7 @@ def PL_workflow(w, theta): desired = PL_workflow(w, p) assert_allclose(actual, desired) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_nested_fun(self, backend): + def test_adjoint_nested_fun(self, backend, experimental_capture): """Ensures that catalyst.adjoint allows arbitrary nesting.""" def func(A, I): @@ -184,7 +173,7 @@ def func(A, I): I = I + 1 A(partial(func, A=A, I=I))() - @qjit + @qjit(experimental_capture=experimental_capture) @qml.qnode(qml.device(backend, wires=2)) def C_workflow(): qml.RX(np.pi / 2, wires=0) @@ -201,8 +190,7 @@ def PL_workflow(): assert_allclose(C_workflow(), PL_workflow()) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_qubitunitary(self, backend): + def test_adjoint_qubitunitary(self, backend, experimental_capture): """Ensures that catalyst.adjoint supports QubitUnitary oprtations.""" def func(): @@ -218,10 +206,9 @@ def func(): wires=[0, 1], ) - self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2)) + self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2), experimental_capture=experimental_capture) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_qubitunitary_dynamic_variable_loop(self, backend): + def test_adjoint_qubitunitary_dynamic_variable_loop(self, backend, experimental_capture): """Ensures that catalyst.adjoint supports QubitUnitary oprtations.""" def func(gate): @@ -243,17 +230,16 @@ def loop_body(_i, s): ] ) - self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2), _input) + self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2), _input, experimental_capture=experimental_capture) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_multirz(self, backend): + def test_adjoint_multirz(self, backend, experimental_capture): """Ensures that catalyst.adjoint supports MultiRZ operations.""" def func(): qml.PauliX(0) qml.MultiRZ(theta=np.pi / 2, wires=[0, 1]) - self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2)) + self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2), experimental_capture=experimental_capture) def test_adjoint_pcphase(self, backend): """Ensures that catalyst.adjoint supports PCPhase operations.""" @@ -293,8 +279,7 @@ def C_workflow(): C_workflow() - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_classical_loop(self, backend): + def test_adjoint_classical_loop(self, backend, experimental_capture): """Checks that catalyst.adjoint supports purely-classical Control-flows.""" def func(w=0): @@ -305,11 +290,10 @@ def loop(_i, s): qml.PauliX(wires=loop(w)) # pylint: disable=no-value-for-parameter qml.RX(np.pi / 2, wires=w) - self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=3), 0) + self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=3), 0, experimental_capture=experimental_capture) - @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("pred", [True, False]) - def test_adjoint_cond(self, backend, pred): + def test_adjoint_cond(self, backend, pred, experimental_capture): """Tests that the correct gates are applied in reverse in a conditional branch""" def func(pred, theta): @@ -320,10 +304,9 @@ def cond_fn(): cond_fn() dev = qml.device(backend, wires=1) - self.verify_catalyst_adjoint_against_pennylane(func, dev, pred, jnp.pi) + self.verify_catalyst_adjoint_against_pennylane(func, dev, pred, jnp.pi, experimental_capture=experimental_capture) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_while_loop(self, backend): + def test_adjoint_while_loop(self, backend, experimental_capture): """ Tests that the correct gates are applied in reverse in a while loop with a statically unknown number of iterations. @@ -341,10 +324,9 @@ def loop_body(carried): qml.RZ(final, wires=0) dev = qml.device(backend, wires=1) - self.verify_catalyst_adjoint_against_pennylane(func, dev, 10) + self.verify_catalyst_adjoint_against_pennylane(func, dev, 10, experimental_capture=experimental_capture) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_for_loop(self, backend): + def test_adjoint_for_loop(self, backend, experimental_capture): """Tests the correct application of gates (with dynamic wires)""" def func(ub): @@ -355,10 +337,9 @@ def loop_body(i): loop_body() # pylint: disable=no-value-for-parameter dev = qml.device(backend, wires=5) - self.verify_catalyst_adjoint_against_pennylane(func, dev, 4) + self.verify_catalyst_adjoint_against_pennylane(func, dev, 4,experimental_capture=experimental_capture) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_while_nested(self, backend): + def test_adjoint_while_nested(self, backend, experimental_capture): """Tests the correct handling of nested while loops.""" def func(limit, inner_iters): @@ -387,11 +368,10 @@ def cond_otherwise(): dev = qml.device(backend, wires=2) self.verify_catalyst_adjoint_against_pennylane( - func, dev, 10, jnp.array([2, 4, 3, 5, 1, 7, 4, 6, 9, 10]) + func, dev, 10, jnp.array([2, 4, 3, 5, 1, 7, 4, 6, 9, 10]), experimental_capture=experimental_capture ) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_nested_with_control_flow(self, backend): + def test_adjoint_nested_with_control_flow(self, backend, experimental_capture): """ Tests that nested adjoint ops produce correct results in the presence of nested control flow. @@ -431,7 +411,7 @@ def loop_inner(_): dev = qml.device(backend, wires=1) - @qjit + @qjit(experimental_capture=experimental_capture) @qml.qnode(dev) def catalyst_workflow(*args): adjoint(c_quantum_func)(*args) @@ -444,8 +424,7 @@ def pennylane_workflow(*args): assert_allclose(catalyst_workflow(jnp.pi), pennylane_workflow(jnp.pi)) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_for_nested(self, backend): + def test_adjoint_for_nested(self, backend, experimental_capture): """ Tests the adjoint op with nested and interspersed for/while loops that produce classical values in addition to quantum ones @@ -477,8 +456,9 @@ def while_loop_inner(counter): loop_outer() # pylint: disable=no-value-for-parameter dev = qml.device(backend, wires=1) - self.verify_catalyst_adjoint_against_pennylane(func, dev, jnp.pi) + self.verify_catalyst_adjoint_against_pennylane(func, dev, jnp.pi, experimental_capture=experimental_capture) + @pytest.mark.old_frontend def test_adjoint_wires(self, backend): """Test the wires property of Adjoint""" @@ -496,6 +476,7 @@ def func(theta): # Without the `wires` property, returns `[-1]` assert circuit(0.3) == qml.wires.Wires([0, 2]) + @pytest.mark.old_frontend def test_adjoint_wires_qubitunitary(self, backend): """Test the wires property of nested Adjoint with QubitUnitary""" @@ -521,6 +502,8 @@ def func(): # Without the `wires` property, returns `[-1]` assert circuit() == qml.wires.Wires([0, 1]) + + @pytest.mark.old_frontend @pytest.mark.xfail(reason="adjoint.wires is not supported with variable wires") def test_adjoint_var_wires(self, backend): """Test catalyst.adjoint.wires with variable wires.""" @@ -567,8 +550,7 @@ def cond_fn(): # It returns `-1` instead of `0` assert circuit() == qml.wires.Wires([0]) - @pytest.mark.usefixtures("use_both_frontend") - def test_adjoint_ctrl_ctrl_subroutine(self, backend): + def test_adjoint_ctrl_ctrl_subroutine(self, backend, experimental_capture): """https://github.com/PennyLaneAI/catalyst/issues/589""" def subsubroutine(): @@ -587,7 +569,7 @@ def circuit(): return qml.probs(wires=dev.wires) expected = circuit() - observed = qjit(circuit)() + observed = qjit(circuit, experimental_capture=experimental_capture)() assert_allclose(expected, observed) def test_adjoint_outside_qjit(self, backend): diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index edd3394bd6..9737a1ee44 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -705,8 +705,7 @@ def circuit(x): assert circuit(3) == False assert circuit(6) == True - @pytest.mark.usefixtures("use_both_frontend") - def test_branch_return_mismatch(self, backend): + def test_branch_return_mismatch(self, backend, experimental_capture): """Test that an exception is raised when the true branch returns a value without an else branch. """ @@ -725,7 +724,7 @@ def circuit(pred: bool): with pytest.raises( err_type, match="Some branches did not define a value for variable 'res'" ): - qjit(autograph=True)(qml.qnode(qml.device(backend, wires=1))(circuit)) + qjit(autograph=True, experimental_capture=experimental_capture)(qml.qnode(qml.device(backend, wires=1))(circuit)) def test_branch_no_multi_return_mismatch(self, backend): """Test that case when the return types of all branches do not match.""" @@ -816,11 +815,10 @@ def test_python_range_fallback(self): assert isinstance(c_range._py_range, range) assert c_range[2] == 2 - @pytest.mark.usefixtures("use_both_frontend") - def test_for_in_array(self): + def test_for_in_array(self, experimental_capture): """Test for loop over JAX array.""" - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=experimental_capture) @qml.qnode(qml.device("lightning.qubit", wires=1)) def f(params): for x in params: @@ -830,11 +828,10 @@ def f(params): result = f(jnp.array([0.0, 1 / 4 * jnp.pi, 2 / 4 * jnp.pi])) assert np.allclose(result, -jnp.sqrt(2) / 2) - @pytest.mark.usefixtures("use_both_frontend") - def test_for_in_array_unpack(self): + def test_for_in_array_unpack(self, experimental_capture): """Test for loop over a 2D JAX array unpacking the inner dimension.""" - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=experimental_capture) @qml.qnode(qml.device("lightning.qubit", wires=1)) def f(params): for x1, x2 in params: @@ -845,11 +842,10 @@ def f(params): result = f(jnp.array([[0.0, 1 / 4 * jnp.pi], [2 / 4 * jnp.pi, jnp.pi]])) assert np.allclose(result, jnp.sqrt(2) / 2) - @pytest.mark.usefixtures("use_both_frontend") - def test_for_in_numeric_list(self): + def test_for_in_numeric_list(self, experimental_capture): """Test for loop over a Python list that is convertible to an array.""" - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=experimental_capture) @qml.qnode(qml.device("lightning.qubit", wires=1)) def f(): params = [0.0, 1 / 4 * jnp.pi, 2 / 4 * jnp.pi] @@ -860,11 +856,10 @@ def f(): result = f() assert np.allclose(result, -jnp.sqrt(2) / 2) - @pytest.mark.usefixtures("use_both_frontend") - def test_for_in_numeric_list_of_list(self): + def test_for_in_numeric_list_of_list(self, experimental_capture): """Test for loop over a nested Python list that is convertible to an array.""" - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=experimental_capture) @qml.qnode(qml.device("lightning.qubit", wires=1)) def f(): params = [[0.0, 1 / 4 * jnp.pi], [2 / 4 * jnp.pi, jnp.pi]] diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index 4674c04a09..20926d8869 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -30,13 +30,13 @@ from catalyst.utils.exceptions import CompileError -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_basic_dynamic_wire_alloc_plain_API(backend): """ Test basic qml.allocate and qml.deallocate. """ - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=3)) def circuit(): qml.X(1) # |010> @@ -54,13 +54,13 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_basic_dynamic_wire_alloc_ctx_API(backend): """ Test basic qml.allocate with context manager API. """ - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=3)) def circuit(): qml.X(1) @@ -77,13 +77,13 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_measure(backend): """ Test qml.allocate with qml.Measure ops. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q: @@ -101,13 +101,13 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_measure_with_reset(backend): """ Test qml.allocate with qml.Measure ops with resetting. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q: @@ -131,14 +131,14 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only @pytest.mark.parametrize("ctrl_val, expected", [(False, [0, 1]), (True, [1, 0])]) def test_qml_ctrl(ctrl_val, expected, backend): """ Test qml.allocate with qml.ctrl ops. """ - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q: @@ -150,13 +150,13 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_QubitUnitary(backend): """ Test qml.allocate with qml.QubitUnitary ops. """ - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(2) as qs: @@ -169,13 +169,13 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_StatePrep(backend): """ Test qml.allocate with qml.StatePrep ops. """ - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q: @@ -188,13 +188,13 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_BasisState(backend): """ Test qml.allocate with qml.BasisState ops. """ - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q: @@ -207,14 +207,14 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only @pytest.mark.parametrize("cond, expected", [(True, [0, 0, 1, 0]), (False, [0, 1, 0, 0])]) def test_dynamic_wire_alloc_cond(cond, expected, backend): """ Test qml.allocate and qml.deallocate inside cond. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=2)) def circuit(c): if c: @@ -235,14 +235,14 @@ def circuit(c): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only @pytest.mark.parametrize("cond, expected", [(True, [0, 1, 0, 0]), (False, [1, 0, 0, 0])]) def test_dynamic_wire_alloc_cond_outside(cond, expected, backend): """ Test passing dynamically allocated wires into a cond. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=2)) def circuit(c): with qml.allocate(1) as q1: @@ -260,7 +260,7 @@ def circuit(c): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only @pytest.mark.parametrize( "num_iter, expected", [(3, [0, 0, 1, 0, 0, 0, 0, 0]), (4, [1, 0, 0, 0, 0, 0, 0, 0])] ) @@ -269,7 +269,7 @@ def test_dynamic_wire_alloc_forloop(num_iter, expected, backend): Test qml.allocate and qml.deallocate inside for loop. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=3)) def circuit(N): for _ in range(N): @@ -285,13 +285,13 @@ def circuit(N): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_dynamic_wire_alloc_forloop_outside(backend): """ Test passing dynamically allocated wires into a for loop. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q: @@ -307,13 +307,13 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_dynamic_wire_alloc_forloop_outside_multiple_regs(backend): """ Test using multiple dynamically allocated registers from inside for loop. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q1: @@ -330,7 +330,7 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only @pytest.mark.parametrize( "num_iter, expected", [(3, [0, 0, 1, 0, 0, 0, 0, 0]), (4, [1, 0, 0, 0, 0, 0, 0, 0])] ) @@ -339,7 +339,7 @@ def test_dynamic_wire_alloc_whileloop(num_iter, expected, backend): Test qml.allocate and qml.deallocate inside while loop. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=3)) def circuit(N): i = 0 @@ -357,14 +357,14 @@ def circuit(N): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only @pytest.mark.parametrize("num_iter, expected", [(3, [0, 1, 0, 0]), (4, [1, 0, 0, 0])]) def test_dynamic_wire_alloc_whileloop_outside(num_iter, expected, backend): """ Test passing dynamically allocated wires into a while loop. """ - @qjit(autograph=True) + @qjit(autograph=True, experimental_capture=True) @qml.qnode(qml.device(backend, wires=2)) def circuit(N): i = 0 @@ -383,7 +383,7 @@ def circuit(N): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only @pytest.mark.parametrize("flip_again, expected", [(True, [1, 0]), (False, [0, 1])]) def test_subroutine(flip_again, expected, backend): """ @@ -395,7 +395,7 @@ def flip(w): qml.X(w) qml.CNOT(wires=[w, 0]) - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q1: @@ -409,7 +409,7 @@ def circuit(): assert np.allclose(expected, observed) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_subroutine_multiple_args(backend): """ Test passing dynamically allocated wires into a subroutine with multiple arguments. @@ -421,7 +421,7 @@ def flip(w1, w2, theta): qml.X(w2) qml.ctrl(qml.RX, (w1, w2))(theta, wires=0) - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q1: @@ -434,6 +434,7 @@ def circuit(): assert np.allclose(expected, observed) +@pytest.mark.old_frontend def test_no_capture(backend): """ Test error message when used without capture. @@ -443,7 +444,7 @@ def test_no_capture(backend): match=re.escape("qml.allocate() is only supported with program capture enabled."), ): - @qjit + @qjit(experimental_capture=False) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as _: @@ -451,7 +452,7 @@ def circuit(): return qml.probs(wires=[0]) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_use_after_free(backend): """ Test error message when used after free. @@ -462,7 +463,7 @@ def test_use_after_free(backend): match="Deallocated qubits cannot be used, but used in Hadamard.", ): - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as q: @@ -471,7 +472,7 @@ def circuit(): return qml.probs(wires=[0]) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_terminal_MP_all_wires(backend): """ Test error message when used with terminal measurements on all wires. @@ -487,7 +488,7 @@ def test_terminal_MP_all_wires(backend): ), ): - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): with qml.allocate(1) as _: @@ -495,7 +496,7 @@ def circuit(): return qml.probs() -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_terminal_MP_dynamic_wires(backend): """ Test error message when used with terminal measurements on dynamic wires. @@ -511,14 +512,14 @@ def test_terminal_MP_dynamic_wires(backend): ), ): - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(): q = qml.allocate(1) return qml.probs(q) -@pytest.mark.usefixtures("use_capture") +@pytest.mark.capture_only def test_unsupported_adjoint(backend): """ Test that an error is raised when a dynamically allocated wire is passed into a adjoint. @@ -529,7 +530,7 @@ def test_unsupported_adjoint(backend): match="Dynamically allocated wires cannot be used in quantum adjoints yet.", ): - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=2)) def circuit(): with qml.allocate(1) as q: diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index 6d33ec22eb..76df5e52da 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -901,13 +901,14 @@ def g(x: float): assert g.mlir_opt assert "__catalyst__qis" in g.mlir_opt - @pytest.mark.usefixtures("use_capture", "requires_xdsl") + @pytest.mark.capture_only + @pytest.mark.usefixtures("requires_xdsl") def test_mlir_opt_using_xdsl_passes(self, backend): """Test mlir opt using xDSL passes.""" # pylint: disable-next=import-outside-toplevel from pennylane.compiler.python_compiler.transforms import iterative_cancel_inverses_pass - @qjit + @qjit(experimental_capture=True) @iterative_cancel_inverses_pass @qml.qnode(qml.device(backend, wires=1)) def f(): diff --git a/frontend/test/pytest/test_switch.py b/frontend/test/pytest/test_switch.py index a4927bb0f0..a1a2394b95 100644 --- a/frontend/test/pytest/test_switch.py +++ b/frontend/test/pytest/test_switch.py @@ -164,29 +164,33 @@ def my_branch(): with pytest.raises(TypeError, match=MISSING_ARGUMENT_MESSAGE): circuit_3(0) - @pytest.mark.usefixtures("use_capture") + @pytest.mark.capture_only def test_fails_capture(self): """Test that a switch raises an exception with program capture enabled.""" if not qml.capture.enabled(): pytest.skip("capture only test") - with pytest.raises(PlxprCaptureCFCompatibilityError) as exc_info: + qml.capture.enable() + try: + with pytest.raises(PlxprCaptureCFCompatibilityError) as exc_info: - def circuit(i): - @switch(i) - def my_switch(): - return 0 + def circuit(i): + @switch(i) + def my_switch(): + return 0 - @my_switch.branch(0) - def my_branch(): - return 1 + @my_switch.branch(0) + def my_branch(): + return 1 - return my_switch() + return my_switch() - circuit(0) + circuit(0) - error_msg = str(exc_info.value) - assert "not supported" in error_msg + error_msg = str(exc_info.value) + assert "not supported" in error_msg + finally: + qml.capture.disable() def test_missing_operation(self): """Test that operation access in an interpreted context raises an exception.""" @@ -387,15 +391,13 @@ def my_branch(): with pytest.raises(TypeError, match=MISSING_ARGUMENT_MESSAGE): circuit_3(0) - @pytest.mark.usefixtures("use_capture") + @pytest.mark.capture_only def test_fails_capture(self, backend): """Test that an exception is raised when program capture is enabled.""" - if not qml.capture.enabled(): - pytest.skip("capture only test") with pytest.raises(PlxprCaptureCFCompatibilityError) as exc_info: - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(i): @switch(i) @@ -638,7 +640,7 @@ def my_branch(): with pytest.raises(TypeError, match=MISSING_ARGUMENT_MESSAGE): circuit_2(0) - @pytest.mark.usefixtures("use_capture") + @pytest.mark.capture_only def test_fails_capture(self, backend): """Test that an exception is raised when program capture is enabled.""" if not qml.capture.enabled(): @@ -646,7 +648,7 @@ def test_fails_capture(self, backend): with pytest.raises(PlxprCaptureCFCompatibilityError) as exc_info: - @qjit + @qjit(experimental_capture=True) @qml.qnode(qml.device(backend, wires=1)) def circuit(i): @switch(i)