Skip to content
Open
13 changes: 13 additions & 0 deletions .github/workflows/check-catalyst.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,10 @@ jobs:
python3 -m pip install oqc-qcaas-client
make frontend

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim

- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down Expand Up @@ -558,6 +562,11 @@ jobs:
python3 -m pip install -r requirements.txt
make frontend

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim


- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down Expand Up @@ -620,6 +629,10 @@ jobs:
python3 -m pip install -r requirements.txt
make frontend

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim

- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down
10 changes: 10 additions & 0 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim
from pennylane.capture.primitives import grad_prim
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP
Expand Down Expand Up @@ -200,6 +201,15 @@ def __init__(self):
super().__init__()


@WorkflowInterpreter.register_primitive(grad_prim)
def handle_grad(self, *args, jaxpr, n_consts, **kwargs):
"""Translate a grad equation."""
f = partial(copy(self).eval, jaxpr, args[:n_consts])
new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:]).jaxpr

return grad_prim.bind(*args, jaxpr=new_jaxpr, n_consts=n_consts, **kwargs)


# pylint: disable=unused-argument, too-many-arguments
@WorkflowInterpreter.register_primitive(qnode_prim)
def handle_qnode(
Expand Down
43 changes: 34 additions & 9 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
VarianceOp,
)
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
from pennylane.capture.primitives import grad_prim as pl_grad_prim

from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
Expand Down Expand Up @@ -644,9 +645,12 @@
consts = []
offset = len(args) - len(jaxpr.consts)
for i, jax_array_or_tracer in enumerate(jaxpr.consts):
if not isinstance(
jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer
):
if isinstance(jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer):
# There are some cases where this value cannot be converted into
# a jax.numpy.array.
# in that case we get it from the arguments.
consts.append(args[offset + i])
else:
# ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
# element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to
# cast such constants to numpy array types.
Expand All @@ -656,11 +660,6 @@
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
constval = StableHLOConstantOp(attr).results
consts.append(constval)
else:
# There are some cases where this value cannot be converted into
# a jax.numpy.array.
# in that case we get it from the arguments.
consts.append(args[offset + i])

method, h, argnums = grad_params.method, grad_params.h, grad_params.expanded_argnums
mlir_ctx = ctx.module_context.context
Expand All @@ -673,7 +672,6 @@
argnum_numpy = np.array(new_argnums)
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))

symbol_ref = get_symbolref(ctx, func_op)
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
Expand All @@ -692,6 +690,32 @@
).results


# pylint: disable=too-many-arguments
def _capture_grad_lowering(ctx, *args, argnum, jaxpr, n_consts, method, h, fn, scalar_out):
mlir_ctx = ctx.module_context.context
if h:
f64 = ir.F64Type.get(mlir_ctx)
finiteDiffParam = ir.FloatAttr.get(f64, h)
else:
finiteDiffParam = None

new_argnums = [num+n_consts for num in argnum]
argnum_numpy = np.array(new_argnums)
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *new_argnums), fn=fn)
symbol_ref = get_symbolref(ctx, func_op)
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)

return GradOp(

Check notice on line 710 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L710

Trailing whitespace (trailing-whitespace)
flat_output_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(args),
diffArgIndices=diffArgIndices,
finiteDiffParam=finiteDiffParam,
).results

# value_and_grad
#
@value_and_grad_p.def_impl
Expand Down Expand Up @@ -2542,6 +2566,7 @@
(while_p, _while_loop_lowering),
(for_p, _for_loop_lowering),
(grad_p, _grad_lowering),
(pl_grad_prim, _capture_grad_lowering),
(func_p, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
Expand Down
115 changes: 63 additions & 52 deletions frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,31 @@
from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval


def _only_single_expval(call_jaxpr: core.ClosedJaxpr) -> bool:
found_expval = False
for eqn in call_jaxpr.eqns:
name = eqn.primitive.name
if name in {"probs", "counts", "sample"}:
return False
elif name == "expval":
if found_expval:
return False
found_expval = True
return True


def _calculate_diff_method(qn: qml.QNode, call_jaxpr: core.ClosedJaxpr):
diff_method = str(qn.diff_method)
if diff_method != "best":
return diff_method

device_name = getattr(getattr(qn, "device", None), "name", None)

if device_name and "lightning" in device_name and _only_single_expval(call_jaxpr):
return "adjoint"
return "parameter-shift"


def get_call_jaxpr(jaxpr):
"""Extracts the `call_jaxpr` from a JAXPR if it exists.""" ""
for eqn in jaxpr.eqns:
Expand All @@ -45,28 +70,35 @@ def get_call_equation(jaxpr):
raise AssertionError("No call_jaxpr found in the JAXPR.")


def lower_jaxpr(ctx, jaxpr, context=None):
def lower_jaxpr(ctx, jaxpr, metadata=None, fn=None):
"""Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p

Args:
ctx: LoweringRuleContext
jaxpr: JAXPR to be lowered
context: additional context to distinguish different FuncOps
metadata: additional metadata to distinguish different FuncOps

Returns:
FuncOp
"""
equation = get_call_equation(jaxpr)
call_jaxpr = equation.params["call_jaxpr"]
callable_ = equation.params.get("fn")
if callable_ is None:
callable_ = equation.params.get("qnode")
pipeline = equation.params.get("pipeline")
return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context)

if fn is None or isinstance(fn, qml.QNode):
equation = get_call_equation(jaxpr)
call_jaxpr = equation.params["call_jaxpr"]
pipeline = equation.params.get("pipeline")
callable_ = equation.params.get("fn")
if callable_ is None:
callable_ = equation.params.get("qnode", None)
else:
call_jaxpr = jaxpr
pipeline = ()
callable_ = fn

return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, metadata=metadata)


# pylint: disable=too-many-arguments, too-many-positional-arguments
def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False):
def lower_callable(ctx, callable_, call_jaxpr, pipeline=(), metadata=None, public=False):
"""Lowers _callable to MLIR.

If callable_ is a qnode, then we will first create a module, then
Expand All @@ -86,33 +118,33 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, publ
if pipeline is None:
pipeline = tuple()

if not isinstance(callable_, qml.QNode):
return get_or_create_funcop(
ctx, callable_, call_jaxpr, pipeline, context=context, public=public
)

return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
if isinstance(callable_, qml.QNode):
return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=metadata)
return get_or_create_funcop(
ctx, callable_, call_jaxpr, pipeline, metadata=metadata, public=public
)


# pylint: disable=too-many-arguments, too-many-positional-arguments
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False):
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=None, public=False):
"""Get funcOp from cache, or create it from scratch

Args:
ctx: LoweringRuleContext
callable_: python function
call_jaxpr: jaxpr representing callable_
context: additional context to distinguish different FuncOps
metadata: additional metadata to distinguish different FuncOps
public: whether the visibility should be marked public

Returns:
FuncOp
"""
if context is None:
context = tuple()
key = (callable_, *context, *pipeline)
if func_op := get_cached(ctx, key):
return func_op
if metadata is None:
metadata = tuple()
key = (callable_, *metadata, *pipeline)
if callable_ is not None:
if func_op := get_cached(ctx, key):
return func_op
func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public)
cache(ctx, key, func_op)
return func_op
Expand All @@ -135,10 +167,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):

kwargs = {}
kwargs["ctx"] = ctx.module_context
if not isinstance(callable_, functools.partial):
name = callable_.__name__
else:
if isinstance(callable_, functools.partial):
name = callable_.func.__name__ + ".partial"
else:
name = callable_.__name__

kwargs["name"] = name
kwargs["jaxpr"] = call_jaxpr
Expand All @@ -154,28 +186,7 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
if isinstance(callable_, qml.QNode):
func_op.attributes["qnode"] = ir.UnitAttr.get()

diff_method = str(callable_.diff_method)

if diff_method == "best":

def only_single_expval():
found_expval = False
for eqn in call_jaxpr.eqns:
name = eqn.primitive.name
if name in {"probs", "counts", "sample"}:
return False
elif name == "expval":
if found_expval:
return False
found_expval = True
return True

device_name = getattr(getattr(callable_, "device", None), "name", None)

if device_name and "lightning" in device_name and only_single_expval():
diff_method = "adjoint"
else:
diff_method = "parameter-shift"
diff_method = _calculate_diff_method(callable_, call_jaxpr)

func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method)

Expand All @@ -195,7 +206,7 @@ def only_single_expval():
return func_op


def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata):
"""A wrapper around lower_qnode_to_funcop that will cache the FuncOp.

Args:
Expand All @@ -205,11 +216,11 @@ def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
Returns:
FuncOp
"""
if context is None:
context = tuple()
if metadata is None:
metadata = tuple()
if callable_.static_argnums:
return lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)
key = (callable_, *context, *pipeline)
key = (callable_, *metadata, *pipeline)
if func_op := get_cached(ctx, key):
return func_op
func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)
Expand Down
Loading
Loading