Skip to content

Commit 8664570

Browse files
authored
fix adjoint and ctrl of qubit unitary and gphase (#2158)
**Context:** The special handling for qubit unitary and gphase bypassed the handling for operator modifiers. **Description of the Change:** Handle qubit unitary and gphase inside `interpret_operation` instead of registering custom handling for the primitive. **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** Fixes #2151 [sc-102523]
1 parent 1dab5d8 commit 8664570

File tree

6 files changed

+162
-102
lines changed

6 files changed

+162
-102
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
[(#2128)](https://github.com/PennyLaneAI/catalyst/pull/2128)
3636
[(#2133)](https://github.com/PennyLaneAI/catalyst/pull/2133)
3737

38+
* Fixes the translation of `QubitUnitary` and `GlobalPhase` ops
39+
when they are modified by adjoint or control.
40+
[(##2158)](https://github.com/PennyLaneAI/catalyst/pull/2158)
41+
3842
* Fixes the translation of a workflow with different transforms applied to different qnodes.
3943
[(#2167)](https://github.com/PennyLaneAI/catalyst/pull/2167)
4044

frontend/catalyst/from_plxpr/control_flow.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,10 @@ def handle_while_loop(
296296
jaxpr = ClosedJaxpr(jaxpr_body_fn, consts_body)
297297

298298
f = partial(_calling_convention, self, jaxpr)
299-
converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr
300-
299+
converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg)
300+
new_consts_body = converted_body_jaxpr_branch.consts
301301
converted_body_closed_jaxpr_branch = ClosedJaxpr(
302-
convert_constvars_jaxpr(converted_body_jaxpr_branch), ()
302+
convert_constvars_jaxpr(converted_body_jaxpr_branch.jaxpr), ()
303303
)
304304

305305
# Convert for condition from plxpr to Catalyst jaxpr
@@ -324,21 +324,22 @@ def remove_qreg(*args_plus_qreg):
324324

325325
return converter(jaxpr, *args)
326326

327-
converted_cond_jaxpr_branch = jax.make_jaxpr(remove_qreg)(*args_plus_qreg).jaxpr
327+
converted_cond_jaxpr_branch = jax.make_jaxpr(remove_qreg)(*args_plus_qreg)
328328
converted_cond_closed_jaxpr_branch = ClosedJaxpr(
329-
convert_constvars_jaxpr(converted_cond_jaxpr_branch), ()
329+
convert_constvars_jaxpr(converted_cond_jaxpr_branch.jaxpr), ()
330330
)
331331

332+
new_consts_cond = converted_cond_jaxpr_branch.consts
332333
# Build Catalyst compatible input values
333-
while_loop_invals = [*consts_cond, *consts_body, *args_plus_qreg]
334+
while_loop_invals = [*new_consts_cond, *new_consts_body, *args_plus_qreg]
334335

335336
# Perform the binding
336337
outvals = while_p.bind(
337338
*while_loop_invals,
338339
cond_jaxpr=converted_cond_closed_jaxpr_branch,
339340
body_jaxpr=converted_body_closed_jaxpr_branch,
340-
cond_nconsts=len(consts_cond),
341-
body_nconsts=len(consts_body),
341+
cond_nconsts=len(new_consts_cond),
342+
body_nconsts=len(new_consts_body),
342343
nimplicit=0,
343344
preserve_dimensions=True,
344345
)

frontend/catalyst/from_plxpr/qfunc_interpreter.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,16 @@ def interpret_operation(self, op, is_adjoint=False, control_values=(), control_w
156156
if any(not qreg.is_qubit_mode() and qreg.expired for qreg in in_qregs + in_ctrl_qregs):
157157
raise CompileError(f"Deallocated qubits cannot be used, but used in {op.name}.")
158158

159-
out_qubits = qinst_p.bind(
159+
bind_fn = _special_op_bind_call.get(type(op), qinst_p.bind)
160+
161+
out_qubits = bind_fn(
160162
*[*in_qubits, *op.data, *in_ctrl_qubits, *control_values],
161163
op=op.name,
162164
qubits_len=len(op.wires),
163165
params_len=len(op.data),
164166
ctrl_len=len(control_wires),
165167
adjoint=is_adjoint,
166168
)
167-
168169
out_non_ctrl_qubits = out_qubits[: len(out_qubits) - len(control_wires)]
169170
out_ctrl_qubits = out_qubits[-len(control_wires) :]
170171

@@ -275,6 +276,27 @@ def __call__(self, jaxpr, *args):
275276
return self.eval(jaxpr.jaxpr, jaxpr.consts, *args)
276277

277278

279+
# pylint: disable=unused-argument
280+
def _qubit_unitary_bind_call(*invals, op, qubits_len, params_len, ctrl_len, adjoint):
281+
wires = invals[:qubits_len]
282+
mat = invals[qubits_len]
283+
ctrl_inputs = invals[qubits_len + 1 :]
284+
return unitary_p.bind(
285+
mat, *wires, *ctrl_inputs, qubits_len=qubits_len, ctrl_len=ctrl_len, adjoint=adjoint
286+
)
287+
288+
289+
# pylint: disable=unused-argument
290+
def _gphase_bind_call(*invals, op, qubits_len, params_len, ctrl_len, adjoint):
291+
return gphase_p.bind(*invals[qubits_len:], ctrl_len=ctrl_len, adjoint=adjoint)
292+
293+
294+
_special_op_bind_call = {
295+
qml.QubitUnitary: _qubit_unitary_bind_call,
296+
qml.GlobalPhase: _gphase_bind_call,
297+
}
298+
299+
278300
# pylint: disable=unused-argument
279301
@PLxPRToQuantumJaxprInterpreter.register_primitive(qml.allocation.allocate_prim)
280302
def handle_qml_alloc(self, *, num_wires, state=None, restored=False):
@@ -441,22 +463,6 @@ def wrapper(*args):
441463
return ()
442464

443465

444-
@PLxPRToQuantumJaxprInterpreter.register_primitive(qml.QubitUnitary._primitive)
445-
def handle_qubit_unitary(self, *invals, n_wires):
446-
"""Handle the conversion from plxpr to Catalyst jaxpr for the QubitUnitary primitive"""
447-
in_qregs, in_qubits = get_in_qubit_values(invals[1:], self.qubit_index_recorder, self.init_qreg)
448-
outvals = unitary_p.bind(invals[0], *in_qubits, qubits_len=n_wires, ctrl_len=0, adjoint=False)
449-
for in_qreg, w, new_wire in zip(in_qregs, invals[1:], outvals):
450-
in_qreg[in_qreg.global_index_to_local_index(w)] = new_wire
451-
452-
453-
# pylint: disable=unused-argument
454-
@PLxPRToQuantumJaxprInterpreter.register_primitive(qml.GlobalPhase._primitive)
455-
def handle_global_phase(self, phase, *wires, n_wires):
456-
"""Handle the conversion from plxpr to Catalyst jaxpr for the GlobalPhase primitive"""
457-
gphase_p.bind(phase, ctrl_len=0, adjoint=False)
458-
459-
460466
@PLxPRToQuantumJaxprInterpreter.register_primitive(qml.BasisState._primitive)
461467
def handle_basis_state(self, *invals, n_wires):
462468
"""Handle the conversion from plxpr to Catalyst jaxpr for the BasisState primitive"""
@@ -590,13 +596,16 @@ def calling_convention(*args_plus_qreg):
590596
init_qreg.insert_all_dangling_qubits()
591597
return *retvals, converter.init_qreg.get()
592598

593-
_, args_tree = tree_flatten((consts, args, [qreg]))
594-
converted_jaxpr_branch = jax.make_jaxpr(calling_convention)(*consts, *args, qreg).jaxpr
599+
converted_jaxpr_branch = jax.make_jaxpr(calling_convention)(*args, qreg)
595600

596-
converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
601+
converted_closed_jaxpr_branch = ClosedJaxpr(
602+
convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
603+
)
604+
new_consts = converted_jaxpr_branch.consts
605+
_, args_tree = tree_flatten((new_consts, args, [qreg]))
597606
# Perform the binding
598607
outvals = adjoint_p.bind(
599-
*consts,
608+
*new_consts,
600609
*args,
601610
qreg,
602611
jaxpr=converted_closed_jaxpr_branch,

frontend/test/pytest/test_adjoint.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
import pennylane.numpy as pnp
2323
import pytest
2424
from numpy.testing import assert_allclose
25+
from pennylane import adjoint, cond, for_loop, qjit, while_loop
2526
from pennylane.ops.op_math.adjoint import Adjoint, AdjointOperation
2627

27-
from catalyst import adjoint, cond, debug, for_loop, measure, qjit, while_loop
28+
import catalyst
29+
from catalyst import debug, measure, qjit
2830

2931
# pylint: disable=too-many-lines,missing-class-docstring,missing-function-docstring,too-many-public-methods
3032

@@ -50,8 +52,17 @@ def pennylane_workflow(*args):
5052
qml.adjoint(quantum_func)(*args)
5153
return qml.state()
5254

53-
assert_allclose(catalyst_workflow(*args), pennylane_workflow(*args))
55+
capture_enabled = qml.capture.enabled()
56+
qml.capture.disable()
57+
try:
58+
pl_res = pennylane_workflow(*args)
59+
finally:
60+
if capture_enabled:
61+
qml.capture.enable()
5462

63+
assert_allclose(catalyst_workflow(*args), pl_res)
64+
65+
@pytest.mark.usefixtures("use_both_frontend")
5566
def test_adjoint_func(self, backend):
5667
"""Ensures that catalyst.adjoint accepts simple Python functions as argument. Makes sure
5768
that simple quantum gates are adjointed correctly."""
@@ -82,6 +93,7 @@ def PL_workflow():
8293
desired = PL_workflow()
8394
assert_allclose(actual, desired)
8495

96+
@pytest.mark.usefixtures("use_both_frontend")
8597
@pytest.mark.parametrize("theta, val", [(jnp.pi, 0), (-100.0, 1)])
8698
def test_adjoint_op(self, theta, val, backend):
8799
"""Ensures that catalyst.adjoint accepts single PennyLane operators classes as argument."""
@@ -91,19 +103,20 @@ def test_adjoint_op(self, theta, val, backend):
91103
@qml.qnode(device)
92104
def C_workflow(theta, val):
93105
adjoint(qml.RY)(jnp.pi, val)
94-
adjoint(qml.RZ)(theta, wires=val)
106+
adjoint(qml.RZ)(theta, val)
95107
return qml.state()
96108

97109
@qml.qnode(device)
98110
def PL_workflow(theta, val):
99111
qml.adjoint(qml.RY)(jnp.pi, val)
100-
qml.adjoint(qml.RZ)(theta, wires=val)
112+
qml.adjoint(qml.RZ)(theta, val)
101113
return qml.state()
102114

103115
actual = C_workflow(theta, val)
104116
desired = PL_workflow(theta, val)
105117
assert_allclose(actual, desired)
106118

119+
@pytest.mark.usefixtures("use_both_frontend")
107120
@pytest.mark.parametrize("theta, val", [(np.pi, 0), (-100.0, 2)])
108121
def test_adjoint_bound_op(self, theta, val, backend):
109122
"""Ensures that catalyst.adjoint accepts single PennyLane operators objects as argument."""
@@ -129,6 +142,7 @@ def PL_workflow(theta, val):
129142
desired = PL_workflow(theta, val)
130143
assert_allclose(actual, desired, atol=1e-6, rtol=1e-6)
131144

145+
@pytest.mark.usefixtures("use_both_frontend")
132146
@pytest.mark.parametrize("w, p", [(0, 0.5), (0, -100.0), (1, 123.22)])
133147
def test_adjoint_param_fun(self, w, p, backend):
134148
"""Ensures that catalyst.adjoint accepts parameterized Python functions as arguments."""
@@ -144,21 +158,22 @@ def func(w, theta1, theta2, theta3=1):
144158
@qml.qnode(device)
145159
def C_workflow(w, theta):
146160
qml.PauliX(wires=0)
147-
adjoint(func)(w, theta, theta2=theta)
161+
adjoint(func)(w, theta, theta)
148162
qml.PauliY(wires=0)
149163
return qml.state()
150164

151165
@qml.qnode(device)
152166
def PL_workflow(w, theta):
153167
qml.PauliX(wires=0)
154-
qml.adjoint(func)(w, theta, theta2=theta)
168+
qml.adjoint(func)(w, theta, theta)
155169
qml.PauliY(wires=0)
156170
return qml.state()
157171

158172
actual = C_workflow(w, p)
159173
desired = PL_workflow(w, p)
160174
assert_allclose(actual, desired)
161175

176+
@pytest.mark.usefixtures("use_both_frontend")
162177
def test_adjoint_nested_fun(self, backend):
163178
"""Ensures that catalyst.adjoint allows arbitrary nesting."""
164179

@@ -186,6 +201,7 @@ def PL_workflow():
186201

187202
assert_allclose(C_workflow(), PL_workflow())
188203

204+
@pytest.mark.usefixtures("use_both_frontend")
189205
def test_adjoint_qubitunitary(self, backend):
190206
"""Ensures that catalyst.adjoint supports QubitUnitary oprtations."""
191207

@@ -204,6 +220,7 @@ def func():
204220

205221
self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2))
206222

223+
@pytest.mark.usefixtures("use_both_frontend")
207224
def test_adjoint_qubitunitary_dynamic_variable_loop(self, backend):
208225
"""Ensures that catalyst.adjoint supports QubitUnitary oprtations."""
209226

@@ -228,6 +245,7 @@ def loop_body(_i, s):
228245

229246
self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=2), _input)
230247

248+
@pytest.mark.usefixtures("use_both_frontend")
231249
def test_adjoint_multirz(self, backend):
232250
"""Ensures that catalyst.adjoint supports MultiRZ operations."""
233251

@@ -275,6 +293,7 @@ def C_workflow():
275293

276294
C_workflow()
277295

296+
@pytest.mark.usefixtures("use_both_frontend")
278297
def test_adjoint_classical_loop(self, backend):
279298
"""Checks that catalyst.adjoint supports purely-classical Control-flows."""
280299

@@ -288,6 +307,7 @@ def loop(_i, s):
288307

289308
self.verify_catalyst_adjoint_against_pennylane(func, qml.device(backend, wires=3), 0)
290309

310+
@pytest.mark.usefixtures("use_both_frontend")
291311
@pytest.mark.parametrize("pred", [True, False])
292312
def test_adjoint_cond(self, backend, pred):
293313
"""Tests that the correct gates are applied in reverse in a conditional branch"""
@@ -302,6 +322,7 @@ def cond_fn():
302322
dev = qml.device(backend, wires=1)
303323
self.verify_catalyst_adjoint_against_pennylane(func, dev, pred, jnp.pi)
304324

325+
@pytest.mark.usefixtures("use_both_frontend")
305326
def test_adjoint_while_loop(self, backend):
306327
"""
307328
Tests that the correct gates are applied in reverse in a while loop with a statically
@@ -322,6 +343,7 @@ def loop_body(carried):
322343
dev = qml.device(backend, wires=1)
323344
self.verify_catalyst_adjoint_against_pennylane(func, dev, 10)
324345

346+
@pytest.mark.usefixtures("use_both_frontend")
325347
def test_adjoint_for_loop(self, backend):
326348
"""Tests the correct application of gates (with dynamic wires)"""
327349

@@ -335,6 +357,7 @@ def loop_body(i):
335357
dev = qml.device(backend, wires=5)
336358
self.verify_catalyst_adjoint_against_pennylane(func, dev, 4)
337359

360+
@pytest.mark.usefixtures("use_both_frontend")
338361
def test_adjoint_while_nested(self, backend):
339362
"""Tests the correct handling of nested while loops."""
340363

@@ -367,6 +390,7 @@ def cond_otherwise():
367390
func, dev, 10, jnp.array([2, 4, 3, 5, 1, 7, 4, 6, 9, 10])
368391
)
369392

393+
@pytest.mark.usefixtures("use_both_frontend")
370394
def test_adjoint_nested_with_control_flow(self, backend):
371395
"""
372396
Tests that nested adjoint ops produce correct results in the presence of nested control
@@ -420,6 +444,7 @@ def pennylane_workflow(*args):
420444

421445
assert_allclose(catalyst_workflow(jnp.pi), pennylane_workflow(jnp.pi))
422446

447+
@pytest.mark.usefixtures("use_both_frontend")
423448
def test_adjoint_for_nested(self, backend):
424449
"""
425450
Tests the adjoint op with nested and interspersed for/while loops that produce classical
@@ -542,6 +567,7 @@ def cond_fn():
542567
# It returns `-1` instead of `0`
543568
assert circuit() == qml.wires.Wires([0])
544569

570+
@pytest.mark.usefixtures("use_both_frontend")
545571
def test_adjoint_ctrl_ctrl_subroutine(self, backend):
546572
"""https://github.com/PennyLaneAI/catalyst/issues/589"""
547573

@@ -585,7 +611,7 @@ def qfunc(x):
585611
qml.RY(x, wires=0)
586612
qml.Hadamard(0)
587613

588-
adj_op = adjoint(qfunc)(0.7)
614+
adj_op = catalyst.adjoint(qfunc)(0.7)
589615
decomp = adj_op.decomposition()
590616

591617
assert len(decomp) == 2
@@ -602,7 +628,7 @@ def qfunc(x, w):
602628
qml.CNOT(wires=[1, w])
603629

604630
with pytest.raises(ValueError, match="Eagerly computing the adjoint"):
605-
adjoint(qfunc, lazy=False)(0.1, 0)
631+
catalyst.adjoint(qfunc, lazy=False)(0.1, 0)
606632

607633

608634
#####################################################################################

frontend/test/pytest/test_global_phase.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import numpy as np
1818
import pennylane as qml
1919
import pytest
20-
21-
from catalyst import cond, qjit
20+
from pennylane import cond, qjit
2221

2322

23+
@pytest.mark.usefixtures("use_both_frontend")
2424
def test_global_phase(backend):
2525
"""Test vanilla global phase"""
2626
dev = qml.device(backend, wires=1)
@@ -36,6 +36,7 @@ def qnn():
3636
assert np.allclose(expected, observed)
3737

3838

39+
@pytest.mark.usefixtures("use_both_frontend")
3940
@pytest.mark.parametrize("inp", [True, False])
4041
def test_global_phase_in_region(backend, inp):
4142
"""Test global phase in region"""
@@ -52,11 +53,13 @@ def cir():
5253
cir()
5354
return qml.state()
5455

55-
expected = qnn(inp)
5656
observed = qjit(qnn)(inp)
57+
qml.capture.disable()
58+
expected = qnn(inp)
5759
assert np.allclose(expected, observed)
5860

5961

62+
@pytest.mark.usefixtures("use_both_frontend")
6063
def test_global_phase_control(backend):
6164
"""Test global phase controlled"""
6265

0 commit comments

Comments
 (0)