Skip to content

Commit 1e193e3

Browse files
authored
[Capture] handle grad and jacobian (#2078)
**Context:** Handles the `grad` primitive from pennylane and adds a lowering for it. This requires PennyLane PR: PennyLaneAI/pennylane#8357 Once that merges in, we can retarget this branch against the newer pennylane. **Description of the Change:** Instead of translating the pennylane grad prim to the catalyst one, we just register a lowering for the grad primitive. We also have to register a translation rule where we are translating the target jaxpr. **Benefits:** Can handle grad and jacobian calls. **Possible Drawbacks:** **Related GitHub Issues:** [sc-100560]
1 parent 2df7c95 commit 1e193e3

File tree

7 files changed

+365
-198
lines changed

7 files changed

+365
-198
lines changed

.dep-versions

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ enzyme=v0.0.186
1010

1111
# For a custom PL version, update the package version here and at
1212
# 'doc/requirements.txt'
13-
pennylane=0.44.0.dev17
13+
pennylane=0.44.0.dev20
1414

1515
# For a custom LQ/LK version, update the package version here and at
1616
# 'doc/requirements.txt'

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
* Pytree inputs can now be used when program capture is enabled.
2121
[(#2165)](https://github.com/PennyLaneAI/catalyst/pull/2165)
2222

23+
* `qml.grad` and `qml.jacobian` can now be used with `qjit` when program capture is enabled.
24+
[(#2078)](https://github.com/PennyLaneAI/catalyst/pull/2078)
25+
2326
<h3>Breaking changes 💔</h3>
2427

2528
<h3>Deprecations 👋</h3>

doc/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ lxml_html_clean
3333
--extra-index-url https://test.pypi.org/simple/
3434
pennylane-lightning-kokkos==0.43.0
3535
pennylane-lightning==0.43.0
36-
pennylane==0.44.0.dev17
36+
pennylane==0.44.0.dev20

frontend/catalyst/from_plxpr/from_plxpr.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
"""
1717
# pylint: disable=protected-access
1818

19-
2019
import warnings
2120
from copy import copy
2221
from functools import partial
@@ -28,6 +27,7 @@
2827
from jax.extend.linear_util import wrap_init
2928
from pennylane.capture import PlxprInterpreter, qnode_prim
3029
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
30+
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
3131
from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires
3232
from pennylane.transforms import cancel_inverses as pl_cancel_inverses
3333
from pennylane.transforms import commute_controlled as pl_commute_controlled
@@ -157,6 +157,18 @@ def __init__(self):
157157
super().__init__()
158158

159159

160+
@WorkflowInterpreter.register_primitive(pl_jac_prim)
161+
def handle_grad(self, *args, jaxpr, n_consts, **kwargs):
162+
"""Translate a grad equation."""
163+
f = partial(copy(self).eval, jaxpr, args[:n_consts])
164+
new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:])
165+
166+
new_args = (*new_jaxpr.consts, *args[n_consts:])
167+
return pl_jac_prim.bind(
168+
*new_args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **kwargs
169+
)
170+
171+
160172
# pylint: disable=unused-argument, too-many-arguments
161173
@WorkflowInterpreter.register_primitive(qnode_prim)
162174
def handle_qnode(

frontend/catalyst/jax_primitives.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
VarianceOp,
101101
)
102102
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
103+
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
103104

104105
from catalyst.compiler import get_lib_path
105106
from catalyst.jax_extras import (
@@ -644,9 +645,12 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
644645
consts = []
645646
offset = len(args) - len(jaxpr.consts)
646647
for i, jax_array_or_tracer in enumerate(jaxpr.consts):
647-
if not isinstance(
648-
jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer
649-
):
648+
if isinstance(jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer):
649+
# There are some cases where this value cannot be converted into
650+
# a jax.numpy.array.
651+
# in that case we get it from the arguments.
652+
consts.append(args[offset + i])
653+
else:
650654
# ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
651655
# element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to
652656
# cast such constants to numpy array types.
@@ -656,11 +660,6 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
656660
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
657661
constval = StableHLOConstantOp(attr).results
658662
consts.append(constval)
659-
else:
660-
# There are some cases where this value cannot be converted into
661-
# a jax.numpy.array.
662-
# in that case we get it from the arguments.
663-
consts.append(args[offset + i])
664663

665664
method, h, argnums = grad_params.method, grad_params.h, grad_params.expanded_argnums
666665
mlir_ctx = ctx.module_context.context
@@ -673,7 +672,6 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
673672
argnum_numpy = np.array(new_argnums)
674673
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
675674
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))
676-
677675
symbol_ref = get_symbolref(ctx, func_op)
678676
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
679677
flat_output_types = util.flatten(output_types)
@@ -692,6 +690,30 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
692690
).results
693691

694692

693+
# pylint: disable=too-many-arguments
694+
def _capture_grad_lowering(ctx, *args, argnums, jaxpr, n_consts, method, h, fn, scalar_out):
695+
mlir_ctx = ctx.module_context.context
696+
f64 = ir.F64Type.get(mlir_ctx)
697+
finiteDiffParam = ir.FloatAttr.get(f64, h)
698+
699+
new_argnums = [num + n_consts for num in argnums]
700+
argnum_numpy = np.array(new_argnums)
701+
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
702+
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *new_argnums), fn=fn)
703+
symbol_ref = get_symbolref(ctx, func_op)
704+
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
705+
flat_output_types = util.flatten(output_types)
706+
707+
return GradOp(
708+
flat_output_types,
709+
ir.StringAttr.get(method),
710+
symbol_ref,
711+
mlir.flatten_lowering_ir_args(args),
712+
diffArgIndices=diffArgIndices,
713+
finiteDiffParam=finiteDiffParam,
714+
).results
715+
716+
695717
# value_and_grad
696718
#
697719
@value_and_grad_p.def_impl
@@ -2542,6 +2564,7 @@ def subroutine_lowering(*args, **kwargs):
25422564
(while_p, _while_loop_lowering),
25432565
(for_p, _for_loop_lowering),
25442566
(grad_p, _grad_lowering),
2567+
(pl_jac_prim, _capture_grad_lowering),
25452568
(func_p, _func_lowering),
25462569
(jvp_p, _jvp_lowering),
25472570
(vjp_p, _vjp_lowering),

frontend/catalyst/jax_primitives_utils.py

Lines changed: 64 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,31 @@
2929
from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
3030

3131

32+
def _only_single_expval(call_jaxpr: core.ClosedJaxpr) -> bool:
33+
found_expval = False
34+
for eqn in call_jaxpr.eqns:
35+
name = eqn.primitive.name
36+
if name in {"probs", "counts", "sample"}:
37+
return False
38+
elif name == "expval":
39+
if found_expval:
40+
return False
41+
found_expval = True
42+
return True
43+
44+
45+
def _calculate_diff_method(qn: qml.QNode, call_jaxpr: core.ClosedJaxpr):
46+
diff_method = str(qn.diff_method)
47+
if diff_method != "best":
48+
return diff_method
49+
50+
device_name = getattr(getattr(qn, "device", None), "name", None)
51+
52+
if device_name and "lightning" in device_name and _only_single_expval(call_jaxpr):
53+
return "adjoint"
54+
return "parameter-shift"
55+
56+
3257
def get_call_jaxpr(jaxpr):
3358
"""Extracts the `call_jaxpr` from a JAXPR if it exists.""" ""
3459
for eqn in jaxpr.eqns:
@@ -45,28 +70,36 @@ def get_call_equation(jaxpr):
4570
raise AssertionError("No call_jaxpr found in the JAXPR.")
4671

4772

48-
def lower_jaxpr(ctx, jaxpr, context=None):
73+
def lower_jaxpr(ctx, jaxpr, metadata=None, fn=None):
4974
"""Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p
5075
5176
Args:
5277
ctx: LoweringRuleContext
5378
jaxpr: JAXPR to be lowered
54-
context: additional context to distinguish different FuncOps
79+
metadata: additional metadata to distinguish different FuncOps
80+
fn (Callable | None): the function the jaxpr corresponds to. Used for naming and caching.
5581
5682
Returns:
5783
FuncOp
5884
"""
59-
equation = get_call_equation(jaxpr)
60-
call_jaxpr = equation.params["call_jaxpr"]
61-
callable_ = equation.params.get("fn")
62-
if callable_ is None:
63-
callable_ = equation.params.get("qnode")
64-
pipeline = equation.params.get("pipeline")
65-
return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context)
85+
86+
if fn is None or isinstance(fn, qml.QNode):
87+
equation = get_call_equation(jaxpr)
88+
call_jaxpr = equation.params["call_jaxpr"]
89+
pipeline = equation.params.get("pipeline")
90+
callable_ = equation.params.get("fn")
91+
if callable_ is None:
92+
callable_ = equation.params.get("qnode", None)
93+
else:
94+
call_jaxpr = jaxpr
95+
pipeline = ()
96+
callable_ = fn
97+
98+
return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, metadata=metadata)
6699

67100

68101
# pylint: disable=too-many-arguments, too-many-positional-arguments
69-
def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False):
102+
def lower_callable(ctx, callable_, call_jaxpr, pipeline=(), metadata=None, public=False):
70103
"""Lowers _callable to MLIR.
71104
72105
If callable_ is a qnode, then we will first create a module, then
@@ -86,33 +119,33 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, publ
86119
if pipeline is None:
87120
pipeline = tuple()
88121

89-
if not isinstance(callable_, qml.QNode):
90-
return get_or_create_funcop(
91-
ctx, callable_, call_jaxpr, pipeline, context=context, public=public
92-
)
93-
94-
return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
122+
if isinstance(callable_, qml.QNode):
123+
return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=metadata)
124+
return get_or_create_funcop(
125+
ctx, callable_, call_jaxpr, pipeline, metadata=metadata, public=public
126+
)
95127

96128

97129
# pylint: disable=too-many-arguments, too-many-positional-arguments
98-
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False):
130+
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=None, public=False):
99131
"""Get funcOp from cache, or create it from scratch
100132
101133
Args:
102134
ctx: LoweringRuleContext
103135
callable_: python function
104136
call_jaxpr: jaxpr representing callable_
105-
context: additional context to distinguish different FuncOps
137+
metadata: additional metadata to distinguish different FuncOps
106138
public: whether the visibility should be marked public
107139
108140
Returns:
109141
FuncOp
110142
"""
111-
if context is None:
112-
context = tuple()
113-
key = (callable_, *context, *pipeline)
114-
if func_op := get_cached(ctx, key):
115-
return func_op
143+
if metadata is None:
144+
metadata = tuple()
145+
key = (callable_, *metadata, *pipeline)
146+
if callable_ is not None:
147+
if func_op := get_cached(ctx, key):
148+
return func_op
116149
func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public)
117150
cache(ctx, key, func_op)
118151
return func_op
@@ -135,10 +168,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
135168

136169
kwargs = {}
137170
kwargs["ctx"] = ctx.module_context
138-
if not isinstance(callable_, functools.partial):
139-
name = callable_.__name__
140-
else:
171+
if isinstance(callable_, functools.partial):
141172
name = callable_.func.__name__ + ".partial"
173+
else:
174+
name = callable_.__name__
142175

143176
kwargs["name"] = name
144177
kwargs["jaxpr"] = call_jaxpr
@@ -154,28 +187,7 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
154187
if isinstance(callable_, qml.QNode):
155188
func_op.attributes["qnode"] = ir.UnitAttr.get()
156189

157-
diff_method = str(callable_.diff_method)
158-
159-
if diff_method == "best":
160-
161-
def only_single_expval():
162-
found_expval = False
163-
for eqn in call_jaxpr.eqns:
164-
name = eqn.primitive.name
165-
if name in {"probs", "counts", "sample"}:
166-
return False
167-
elif name == "expval":
168-
if found_expval:
169-
return False
170-
found_expval = True
171-
return True
172-
173-
device_name = getattr(getattr(callable_, "device", None), "name", None)
174-
175-
if device_name and "lightning" in device_name and only_single_expval():
176-
diff_method = "adjoint"
177-
else:
178-
diff_method = "parameter-shift"
190+
diff_method = _calculate_diff_method(callable_, call_jaxpr)
179191

180192
func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method)
181193

@@ -195,7 +207,7 @@ def only_single_expval():
195207
return func_op
196208

197209

198-
def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
210+
def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata):
199211
"""A wrapper around lower_qnode_to_funcop that will cache the FuncOp.
200212
201213
Args:
@@ -205,11 +217,11 @@ def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
205217
Returns:
206218
FuncOp
207219
"""
208-
if context is None:
209-
context = tuple()
220+
if metadata is None:
221+
metadata = tuple()
210222
if callable_.static_argnums:
211223
return lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)
212-
key = (callable_, *context, *pipeline)
224+
key = (callable_, *metadata, *pipeline)
213225
if func_op := get_cached(ctx, key):
214226
return func_op
215227
func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)

0 commit comments

Comments
 (0)