Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
from pennylane.capture.primitives import grad_prim
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP
from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires
Expand Down Expand Up @@ -185,12 +186,19 @@
self._pass_pipeline = []
self.init_qreg = None

# Compiler options for the new decomposition system

Check notice on line 189 in frontend/catalyst/from_plxpr/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/from_plxpr.py#L189

Missing function or method docstring (missing-function-docstring)
self.requires_decompose_lowering = False
self.decompose_tkwargs = {} # target gateset

super().__init__()

@WorkflowInterpreter.register_primitive(grad_prim)
def handle_grad(self, *args, jaxpr, n_consts, **kwargs):

f = partial(copy(self).eval, jaxpr, args[:n_consts])
new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:]).jaxpr

return grad_prim.bind(*args, jaxpr=new_jaxpr, n_consts=n_consts, **kwargs)

# pylint: disable=unused-argument, too-many-arguments
@WorkflowInterpreter.register_primitive(qnode_prim)
Expand Down
42 changes: 35 additions & 7 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@
)
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp

from pennylane.capture.primitives import grad_prim as pl_grad_prim

from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
ClosedJaxpr,
Expand Down Expand Up @@ -644,9 +646,14 @@
consts = []
offset = len(args) - len(jaxpr.consts)
for i, jax_array_or_tracer in enumerate(jaxpr.consts):
if not isinstance(
if isinstance(
jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer
):
# There are some cases where this value cannot be converted into
# a jax.numpy.array.
# in that case we get it from the arguments.
consts.append(args[offset + i])
else:
# ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
# element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to
# cast such constants to numpy array types.
Expand All @@ -656,11 +663,6 @@
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
constval = StableHLOConstantOp(attr).results
consts.append(constval)
else:
# There are some cases where this value cannot be converted into
# a jax.numpy.array.
# in that case we get it from the arguments.
consts.append(args[offset + i])

method, h, argnums = grad_params.method, grad_params.h, grad_params.expanded_argnums
mlir_ctx = ctx.module_context.context
Expand All @@ -673,7 +675,6 @@
argnum_numpy = np.array(new_argnums)
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))

symbol_ref = get_symbolref(ctx, func_op)
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
Expand All @@ -692,6 +693,32 @@
).results


def _capture_grad_lowering(ctx, *args, argnum, jaxpr, n_consts, method, h, fn, scalar_out):

Check notice on line 696 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L696

Too many arguments (8/5) (too-many-arguments)
mlir_ctx = ctx.module_context.context
if h:
f64 = ir.F64Type.get(mlir_ctx)
finiteDiffParam = ir.FloatAttr.get(f64, h)
else:
finiteDiffParam = None

argnum_numpy = np.array(argnum)
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnum), fn=fn)
symbol_ref = get_symbolref(ctx, func_op)
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)

Check notice on line 710 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L710

Trailing whitespace (trailing-whitespace)
return GradOp(
flat_output_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(args),
diffArgIndices=diffArgIndices,
finiteDiffParam=finiteDiffParam,
).results



# value_and_grad
#
@value_and_grad_p.def_impl
Expand Down Expand Up @@ -2542,6 +2569,7 @@
(while_p, _while_loop_lowering),
(for_p, _for_loop_lowering),
(grad_p, _grad_lowering),
(pl_grad_prim, _capture_grad_lowering),
(func_p, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
Expand Down
115 changes: 62 additions & 53 deletions frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,29 @@
from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval


def _only_single_expval(call_jaxpr : core.ClosedJaxpr) -> bool:
found_expval = False
for eqn in call_jaxpr.eqns:
name = eqn.primitive.name
if name in {"probs", "counts", "sample"}:
return False
elif name == "expval":
if found_expval:
return False
found_expval = True
return True

def _calculate_diff_method(qn: qml.QNode, call_jaxpr: core.ClosedJaxpr):
diff_method = str(qn.diff_method)
if diff_method != "best":
return diff_method

Check notice on line 48 in frontend/catalyst/jax_primitives_utils.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives_utils.py#L48

Trailing whitespace (trailing-whitespace)
device_name = getattr(getattr(qn, "device", None), "name", None)

if device_name and "lightning" in device_name and _only_single_expval(call_jaxpr):
return "adjoint"
return "parameter-shift"

def get_call_jaxpr(jaxpr):
"""Extracts the `call_jaxpr` from a JAXPR if it exists.""" ""
for eqn in jaxpr.eqns:
Expand All @@ -45,28 +68,35 @@
raise AssertionError("No call_jaxpr found in the JAXPR.")


def lower_jaxpr(ctx, jaxpr, context=None):
def lower_jaxpr(ctx, jaxpr, metadata=None, fn=None):
"""Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p
Args:
ctx: LoweringRuleContext
jaxpr: JAXPR to be lowered
context: additional context to distinguish different FuncOps
metadata: additional metadata to distinguish different FuncOps
Returns:
FuncOp
"""
equation = get_call_equation(jaxpr)
call_jaxpr = equation.params["call_jaxpr"]
callable_ = equation.params.get("fn")
if callable_ is None:
callable_ = equation.params.get("qnode")
pipeline = equation.params.get("pipeline")
return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context)

if fn is None or isinstance(fn, qml.QNode):
equation = get_call_equation(jaxpr)
call_jaxpr = equation.params["call_jaxpr"]
pipeline = equation.params.get("pipeline")
callable_ = equation.params.get("fn")
if callable_ is None:
callable_ = equation.params.get("qnode", None)
else:
call_jaxpr = jaxpr
pipeline = ()
callable_ = fn

Check notice on line 94 in frontend/catalyst/jax_primitives_utils.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives_utils.py#L94

Trailing whitespace (trailing-whitespace)
return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, metadata=metadata)


# pylint: disable=too-many-arguments, too-many-positional-arguments
def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False):
def lower_callable(ctx, callable_, call_jaxpr, pipeline=(), metadata=None, public=False):
"""Lowers _callable to MLIR.
If callable_ is a qnode, then we will first create a module, then
Expand All @@ -75,7 +105,7 @@
add more than one FuncOps. This depends on the contents of call_jaxpr.
Args:
ctx: LoweringRuleContext
ctx: LoweringRulemetadata
callable_: python function
call_jaxpr: jaxpr representing callable_
public: whether the visibility should be marked public
Expand All @@ -86,33 +116,33 @@
if pipeline is None:
pipeline = tuple()

if not isinstance(callable_, qml.QNode):
return get_or_create_funcop(
ctx, callable_, call_jaxpr, pipeline, context=context, public=public
)

return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
if isinstance(callable_, qml.QNode):
return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=metadata)
return get_or_create_funcop(
ctx, callable_, call_jaxpr, pipeline, metadata=metadata, public=public
)


# pylint: disable=too-many-arguments, too-many-positional-arguments
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False):
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=None, public=False):
"""Get funcOp from cache, or create it from scratch
Args:
ctx: LoweringRuleContext
callable_: python function
call_jaxpr: jaxpr representing callable_
context: additional context to distinguish different FuncOps
metadata: additional metadata to distinguish different FuncOps
public: whether the visibility should be marked public
Returns:
FuncOp
"""
if context is None:
context = tuple()
key = (callable_, *context, *pipeline)
if func_op := get_cached(ctx, key):
return func_op
if metadata is None:
metadata = tuple()
key = (callable_, *metadata, *pipeline)
if callable_ is not None:
if func_op := get_cached(ctx, key):
return func_op
func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public)
cache(ctx, key, func_op)
return func_op
Expand All @@ -135,10 +165,10 @@

kwargs = {}
kwargs["ctx"] = ctx.module_context
if not isinstance(callable_, functools.partial):
name = callable_.__name__
else:
if isinstance(callable_, functools.partial):
name = callable_.func.__name__ + ".partial"
else:
name = callable_.__name__

kwargs["name"] = name
kwargs["jaxpr"] = call_jaxpr
Expand All @@ -154,28 +184,7 @@
if isinstance(callable_, qml.QNode):
func_op.attributes["qnode"] = ir.UnitAttr.get()

diff_method = str(callable_.diff_method)

if diff_method == "best":

def only_single_expval():
found_expval = False
for eqn in call_jaxpr.eqns:
name = eqn.primitive.name
if name in {"probs", "counts", "sample"}:
return False
elif name == "expval":
if found_expval:
return False
found_expval = True
return True

device_name = getattr(getattr(callable_, "device", None), "name", None)

if device_name and "lightning" in device_name and only_single_expval():
diff_method = "adjoint"
else:
diff_method = "parameter-shift"
diff_method = _calculate_diff_method(callable_, call_jaxpr)

func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method)

Expand All @@ -195,7 +204,7 @@
return func_op


def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata):
"""A wrapper around lower_qnode_to_funcop that will cache the FuncOp.
Args:
Expand All @@ -205,11 +214,11 @@
Returns:
FuncOp
"""
if context is None:
context = tuple()
if metadata is None:
metadata = tuple()
if callable_.static_argnums:
return lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)
key = (callable_, *context, *pipeline)
key = (callable_, *metadata, *pipeline)
if func_op := get_cached(ctx, key):
return func_op
func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)
Expand Down
Loading