2222import pennylane .numpy as pnp
2323import pytest
2424from numpy .testing import assert_allclose
25+ from pennylane import adjoint , cond , for_loop , qjit , while_loop
2526from pennylane .ops .op_math .adjoint import Adjoint , AdjointOperation
2627
27- from catalyst import adjoint , cond , debug , for_loop , measure , qjit , while_loop
28+ import catalyst
29+ from catalyst import debug , measure , qjit
2830
2931# pylint: disable=too-many-lines,missing-class-docstring,missing-function-docstring,too-many-public-methods
3032
@@ -50,8 +52,17 @@ def pennylane_workflow(*args):
5052 qml .adjoint (quantum_func )(* args )
5153 return qml .state ()
5254
53- assert_allclose (catalyst_workflow (* args ), pennylane_workflow (* args ))
55+ capture_enabled = qml .capture .enabled ()
56+ qml .capture .disable ()
57+ try :
58+ pl_res = pennylane_workflow (* args )
59+ finally :
60+ if capture_enabled :
61+ qml .capture .enable ()
5462
63+ assert_allclose (catalyst_workflow (* args ), pl_res )
64+
65+ @pytest .mark .usefixtures ("use_both_frontend" )
5566 def test_adjoint_func (self , backend ):
5667 """Ensures that catalyst.adjoint accepts simple Python functions as argument. Makes sure
5768 that simple quantum gates are adjointed correctly."""
@@ -82,6 +93,7 @@ def PL_workflow():
8293 desired = PL_workflow ()
8394 assert_allclose (actual , desired )
8495
96+ @pytest .mark .usefixtures ("use_both_frontend" )
8597 @pytest .mark .parametrize ("theta, val" , [(jnp .pi , 0 ), (- 100.0 , 1 )])
8698 def test_adjoint_op (self , theta , val , backend ):
8799 """Ensures that catalyst.adjoint accepts single PennyLane operators classes as argument."""
@@ -91,19 +103,20 @@ def test_adjoint_op(self, theta, val, backend):
91103 @qml .qnode (device )
92104 def C_workflow (theta , val ):
93105 adjoint (qml .RY )(jnp .pi , val )
94- adjoint (qml .RZ )(theta , wires = val )
106+ adjoint (qml .RZ )(theta , val )
95107 return qml .state ()
96108
97109 @qml .qnode (device )
98110 def PL_workflow (theta , val ):
99111 qml .adjoint (qml .RY )(jnp .pi , val )
100- qml .adjoint (qml .RZ )(theta , wires = val )
112+ qml .adjoint (qml .RZ )(theta , val )
101113 return qml .state ()
102114
103115 actual = C_workflow (theta , val )
104116 desired = PL_workflow (theta , val )
105117 assert_allclose (actual , desired )
106118
119+ @pytest .mark .usefixtures ("use_both_frontend" )
107120 @pytest .mark .parametrize ("theta, val" , [(np .pi , 0 ), (- 100.0 , 2 )])
108121 def test_adjoint_bound_op (self , theta , val , backend ):
109122 """Ensures that catalyst.adjoint accepts single PennyLane operators objects as argument."""
@@ -129,6 +142,7 @@ def PL_workflow(theta, val):
129142 desired = PL_workflow (theta , val )
130143 assert_allclose (actual , desired , atol = 1e-6 , rtol = 1e-6 )
131144
145+ @pytest .mark .usefixtures ("use_both_frontend" )
132146 @pytest .mark .parametrize ("w, p" , [(0 , 0.5 ), (0 , - 100.0 ), (1 , 123.22 )])
133147 def test_adjoint_param_fun (self , w , p , backend ):
134148 """Ensures that catalyst.adjoint accepts parameterized Python functions as arguments."""
@@ -144,21 +158,22 @@ def func(w, theta1, theta2, theta3=1):
144158 @qml .qnode (device )
145159 def C_workflow (w , theta ):
146160 qml .PauliX (wires = 0 )
147- adjoint (func )(w , theta , theta2 = theta )
161+ adjoint (func )(w , theta , theta )
148162 qml .PauliY (wires = 0 )
149163 return qml .state ()
150164
151165 @qml .qnode (device )
152166 def PL_workflow (w , theta ):
153167 qml .PauliX (wires = 0 )
154- qml .adjoint (func )(w , theta , theta2 = theta )
168+ qml .adjoint (func )(w , theta , theta )
155169 qml .PauliY (wires = 0 )
156170 return qml .state ()
157171
158172 actual = C_workflow (w , p )
159173 desired = PL_workflow (w , p )
160174 assert_allclose (actual , desired )
161175
176+ @pytest .mark .usefixtures ("use_both_frontend" )
162177 def test_adjoint_nested_fun (self , backend ):
163178 """Ensures that catalyst.adjoint allows arbitrary nesting."""
164179
@@ -186,6 +201,7 @@ def PL_workflow():
186201
187202 assert_allclose (C_workflow (), PL_workflow ())
188203
204+ @pytest .mark .usefixtures ("use_both_frontend" )
189205 def test_adjoint_qubitunitary (self , backend ):
190206 """Ensures that catalyst.adjoint supports QubitUnitary oprtations."""
191207
@@ -204,6 +220,7 @@ def func():
204220
205221 self .verify_catalyst_adjoint_against_pennylane (func , qml .device (backend , wires = 2 ))
206222
223+ @pytest .mark .usefixtures ("use_both_frontend" )
207224 def test_adjoint_qubitunitary_dynamic_variable_loop (self , backend ):
208225 """Ensures that catalyst.adjoint supports QubitUnitary oprtations."""
209226
@@ -228,6 +245,7 @@ def loop_body(_i, s):
228245
229246 self .verify_catalyst_adjoint_against_pennylane (func , qml .device (backend , wires = 2 ), _input )
230247
248+ @pytest .mark .usefixtures ("use_both_frontend" )
231249 def test_adjoint_multirz (self , backend ):
232250 """Ensures that catalyst.adjoint supports MultiRZ operations."""
233251
@@ -275,6 +293,7 @@ def C_workflow():
275293
276294 C_workflow ()
277295
296+ @pytest .mark .usefixtures ("use_both_frontend" )
278297 def test_adjoint_classical_loop (self , backend ):
279298 """Checks that catalyst.adjoint supports purely-classical Control-flows."""
280299
@@ -288,6 +307,7 @@ def loop(_i, s):
288307
289308 self .verify_catalyst_adjoint_against_pennylane (func , qml .device (backend , wires = 3 ), 0 )
290309
310+ @pytest .mark .usefixtures ("use_both_frontend" )
291311 @pytest .mark .parametrize ("pred" , [True , False ])
292312 def test_adjoint_cond (self , backend , pred ):
293313 """Tests that the correct gates are applied in reverse in a conditional branch"""
@@ -302,6 +322,7 @@ def cond_fn():
302322 dev = qml .device (backend , wires = 1 )
303323 self .verify_catalyst_adjoint_against_pennylane (func , dev , pred , jnp .pi )
304324
325+ @pytest .mark .usefixtures ("use_both_frontend" )
305326 def test_adjoint_while_loop (self , backend ):
306327 """
307328 Tests that the correct gates are applied in reverse in a while loop with a statically
@@ -322,6 +343,7 @@ def loop_body(carried):
322343 dev = qml .device (backend , wires = 1 )
323344 self .verify_catalyst_adjoint_against_pennylane (func , dev , 10 )
324345
346+ @pytest .mark .usefixtures ("use_both_frontend" )
325347 def test_adjoint_for_loop (self , backend ):
326348 """Tests the correct application of gates (with dynamic wires)"""
327349
@@ -335,6 +357,7 @@ def loop_body(i):
335357 dev = qml .device (backend , wires = 5 )
336358 self .verify_catalyst_adjoint_against_pennylane (func , dev , 4 )
337359
360+ @pytest .mark .usefixtures ("use_both_frontend" )
338361 def test_adjoint_while_nested (self , backend ):
339362 """Tests the correct handling of nested while loops."""
340363
@@ -367,6 +390,7 @@ def cond_otherwise():
367390 func , dev , 10 , jnp .array ([2 , 4 , 3 , 5 , 1 , 7 , 4 , 6 , 9 , 10 ])
368391 )
369392
393+ @pytest .mark .usefixtures ("use_both_frontend" )
370394 def test_adjoint_nested_with_control_flow (self , backend ):
371395 """
372396 Tests that nested adjoint ops produce correct results in the presence of nested control
@@ -420,6 +444,7 @@ def pennylane_workflow(*args):
420444
421445 assert_allclose (catalyst_workflow (jnp .pi ), pennylane_workflow (jnp .pi ))
422446
447+ @pytest .mark .usefixtures ("use_both_frontend" )
423448 def test_adjoint_for_nested (self , backend ):
424449 """
425450 Tests the adjoint op with nested and interspersed for/while loops that produce classical
@@ -542,6 +567,7 @@ def cond_fn():
542567 # It returns `-1` instead of `0`
543568 assert circuit () == qml .wires .Wires ([0 ])
544569
570+ @pytest .mark .usefixtures ("use_both_frontend" )
545571 def test_adjoint_ctrl_ctrl_subroutine (self , backend ):
546572 """https://github.com/PennyLaneAI/catalyst/issues/589"""
547573
@@ -585,7 +611,7 @@ def qfunc(x):
585611 qml .RY (x , wires = 0 )
586612 qml .Hadamard (0 )
587613
588- adj_op = adjoint (qfunc )(0.7 )
614+ adj_op = catalyst . adjoint (qfunc )(0.7 )
589615 decomp = adj_op .decomposition ()
590616
591617 assert len (decomp ) == 2
@@ -602,7 +628,7 @@ def qfunc(x, w):
602628 qml .CNOT (wires = [1 , w ])
603629
604630 with pytest .raises (ValueError , match = "Eagerly computing the adjoint" ):
605- adjoint (qfunc , lazy = False )(0.1 , 0 )
631+ catalyst . adjoint (qfunc , lazy = False )(0.1 , 0 )
606632
607633
608634#####################################################################################
0 commit comments