Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=8bc0ae6da4f6d334d03be7b76be841d72a56d330
jax=5712de44e97c455faed1fd45532e821ca66d025a
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203
Expand Down
10 changes: 9 additions & 1 deletion frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import textwrap

import jax
from jax._src.dispatch import jaxpr_replicas
from jax._src import core
from jax._src.interpreters.pxla import _jaxpr_replicas as jaxpr_replicas
from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex
from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext
Expand Down Expand Up @@ -138,13 +139,20 @@ def custom_lower_jaxpr_to_module(
module_name = _module_name_regex.sub("_", module_name)
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)

const_args = core.jaxpr_const_args(jaxpr.jaxpr)
const_arg_avals = [core.shaped_abstractify(c) for c in const_args]
num_const_args = len(const_arg_avals)
in_avals = const_arg_avals + jaxpr.in_avals

# Use main_function=False to preserve the function name (e.g., "jit_func")
# instead of renaming it to "main"
lower_jaxpr_to_fun(
ctx,
func_name,
jaxpr,
effects,
num_const_args=num_const_args,
in_avals=in_avals,
main_function=False,
replicated_args=replicated_args,
arg_shardings=arg_shardings,
Expand Down
38 changes: 38 additions & 0 deletions frontend/catalyst/jax_extras/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@

# TODO: Remove this patch when JAX/PennyLane are updated to use the new JAX 0.7+ API.
# pylint: disable=protected-access
def patch_primitives():

Check notice on line 238 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L238

Too many statements (91/50) (too-many-statements)
"""Patch PennyLane/JAX primitives to make them compatible with JAX 0.7+.

JAX 0.7+ requires all primitive parameters to be hashable, but PennyLane
Expand All @@ -250,10 +250,10 @@
return value

try:
from pennylane.capture.primitives import ctrl_transform_prim

Check notice on line 253 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L253

Import outside toplevel (pennylane.capture.primitives.ctrl_transform_prim) (import-outside-toplevel)
from pennylane.capture.primitives import cond_prim

Check notice on line 254 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L254

Import outside toplevel (pennylane.capture.primitives.cond_prim) (import-outside-toplevel)
from pennylane.ops.op_math.controlled import Controlled

Check notice on line 255 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L255

Import outside toplevel (pennylane.ops.op_math.controlled.Controlled) (import-outside-toplevel)
from jax._src.interpreters import partial_eval as pe

Check notice on line 256 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L256

Import outside toplevel (jax._src.interpreters.partial_eval) (import-outside-toplevel)

original_ctrl_bind = ctrl_transform_prim.bind
original_cond_bind = cond_prim.bind
Expand Down Expand Up @@ -297,14 +297,14 @@

# patch DynamicJaxprTrace members: makevar and getvar
try:
from jax._src.interpreters import partial_eval as pe

Check notice on line 300 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L300

Import outside toplevel (jax._src.interpreters.partial_eval) (import-outside-toplevel)

def patched_makevar(self, tracer):
assert tracer.val is None, "a jaxpr variable must be created only once per tracer"
tracer.val = self.frame.newvar(tracer.aval)
return tracer.val

def patched_getvar(self, tracer): # pylint: disable=unused-argument

Check notice on line 307 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L307

Either all return statements in a function should return an expression, or none of them should. (inconsistent-return-statements)
if var := tracer.val:
return var
raise jax.core.escaped_tracer_error(tracer)
Expand All @@ -315,15 +315,15 @@
# Patch make_eqn to handle both single aval and list of avals
# original_make_eqn = pe.DynamicJaxprTrace.make_eqn

import jax._src.source_info_util as source_info_util

Check notice on line 318 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L318

Import outside toplevel (jax._src.source_info_util) (import-outside-toplevel)

Check notice on line 318 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L318

Use 'from jax._src import source_info_util' instead (consider-using-from-import)
from jax._src.core import JaxprEqnContext, Var

Check notice on line 319 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L319

Import outside toplevel (jax._src.core.JaxprEqnContext, jax._src.core.Var) (import-outside-toplevel)
from jax._src.interpreters.partial_eval import DynamicJaxprTracer

Check notice on line 320 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L320

Import outside toplevel (jax._src.interpreters.partial_eval.DynamicJaxprTracer) (import-outside-toplevel)
from jax._src.interpreters.partial_eval import TracingEqn

Check notice on line 321 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L321

Import outside toplevel (jax._src.interpreters.partial_eval.TracingEqn) (import-outside-toplevel)
from jax._src.interpreters.partial_eval import compute_on

Check notice on line 322 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L322

Import outside toplevel (jax._src.interpreters.partial_eval.compute_on) (import-outside-toplevel)
from jax._src import config

Check notice on line 323 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L323

Import outside toplevel (jax._src.config) (import-outside-toplevel)
from jax._src.interpreters.partial_eval import xla_metadata_lib

Check notice on line 324 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L324

Import outside toplevel (jax._src.interpreters.partial_eval.xla_metadata_lib) (import-outside-toplevel)

def internal_make_eqn(

Check notice on line 326 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L326

Too many positional arguments (9/5) (too-many-positional-arguments)
self,
in_tracers,
out_avals,
Expand All @@ -349,7 +349,7 @@
eqn = TracingEqn(in_tracers, outvars, primitive, params, effects, source_info, ctx)
return eqn, out_tracers
else:
outvars = list(map(lambda aval: self.frame.newvar(aval), out_avals))

Check notice on line 352 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L352

Lambda may not be necessary (unnecessary-lambda)
if config.enable_checks.value:
assert all(isinstance(x, DynamicJaxprTracer) for x in in_tracers)
assert all(isinstance(v, Var) for v in outvars)
Expand All @@ -362,7 +362,7 @@

pe.DynamicJaxprTrace.make_eqn_internal = internal_make_eqn

def patched_make_eqn(

Check notice on line 365 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L365

Too many positional arguments (9/5) (too-many-positional-arguments)
self,
in_tracers,
out_avals,
Expand Down Expand Up @@ -412,8 +412,8 @@

pe.JaxprStackFrame.eqns = property(patched_eqns_getter, patched_eqns_setter)

import jax._src.lax.lax as lax

Check notice on line 415 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L415

Use 'from jax._src.lax import lax' instead (consider-using-from-import)

Check notice on line 415 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L415

Import outside toplevel (jax._src.lax.lax) (import-outside-toplevel)
import jax._src.core as core

Check notice on line 416 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L416

Import outside toplevel (jax._src.core) (import-outside-toplevel)

Check notice on line 416 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L416

Use 'from jax._src import core' instead (consider-using-from-import)

def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params):
eqn, out_tracer = trace.make_eqn(
Expand All @@ -424,5 +424,43 @@

lax._dyn_shape_staging_rule = patched_dyn_shape_staging_rule

# Patch multi_broadcast_in_dim to handle dynamic shapes in JAX 0.7+
import jax._src.interpreters.mlir as mlir

Check notice on line 428 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L428

Use 'from jax._src.interpreters import mlir' instead (consider-using-from-import)

Check notice on line 428 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L428

Import outside toplevel (jax._src.interpreters.mlir) (import-outside-toplevel)

def patched_multi_broadcast_in_dim(ctx, ops, ops_avals, out_shape, out_sharding=None):
"""Patched version that uses DShapedArray for dynamic shapes."""
out = []
for op, op_aval in zip(ops, ops_avals):
op_aval_shape = op_aval.shape
op_aval_sharding = getattr(op_aval, "sharding", None)

# Use DShapedArray if shape contains dynamic dimensions
if core.is_constant_shape(out_shape):
out_aval = core.ShapedArray(out_shape, op_aval.dtype, sharding=out_sharding)
else:
# DShapedArray doesn't support sharding parameter
out_aval = core.DShapedArray(
out_shape, op_aval.dtype, weak_type=getattr(op_aval, "weak_type", False)
)

if core.definitely_equal_shape(op_aval_shape, out_shape):
if out_sharding is None or op_aval_sharding == out_sharding:
out.append(op)
else:
out.append(mlir.lower_with_sharding_in_types(ctx, op, out_aval))
else:
assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape)
broadcast_dimensions = list(
range(len(out_shape) - len(op_aval_shape), len(out_shape))
)
b_out = mlir.broadcast_in_dim(
ctx, op, out_aval, broadcast_dimensions=broadcast_dimensions
)
b_out = mlir.lower_with_sharding_in_types(ctx, b_out, out_aval)
out.append(b_out)
return out

mlir.multi_broadcast_in_dim = patched_multi_broadcast_in_dim

except ImportError:
pass

Check notice on line 466 in frontend/catalyst/jax_extras/patches.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_extras/patches.py#L238-L466

Complex Method
6 changes: 6 additions & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,7 @@
Note: This only contains the boolean measurement result,
not the qubit output
"""
from catalyst.jax_extras.tracing import new_inner_tracer

Check notice on line 1410 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L1410

Import outside toplevel (catalyst.jax_extras.tracing.new_inner_tracer) (import-outside-toplevel)

# Define output shapes
out_shapes = (core.ShapedArray((), bool), qubit.aval)
Expand Down Expand Up @@ -2074,6 +2074,7 @@
[mlir.ir_constant(c) for c in true_jaxpr.consts], # is never hit in our tests
*flat_args_plus_consts,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

YieldOp(out)
Expand All @@ -2094,6 +2095,7 @@
[mlir.ir_constants(c) for c in otherwise_jaxpr.consts],
*flat_args_plus_consts,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

YieldOp(out)
Expand Down Expand Up @@ -2186,6 +2188,7 @@
[mlir.ir_constants(c) for c in cond_jaxpr.consts],
*params,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

pred_extracted = TensorExtractOp(ir.IntegerType.get_signless(1), pred, []).result
Expand All @@ -2207,6 +2210,7 @@
[mlir.ir_constants(c) for c in cond_jaxpr.consts],
*params,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

YieldOp(out)
Expand Down Expand Up @@ -2344,6 +2348,7 @@
[mlir.ir_constants(c) for c in body_jaxpr.consts],
*loop_params,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

YieldOp(out)
Expand Down Expand Up @@ -2474,6 +2479,7 @@
[mlir.ir_constants(c) for c in jaxpr.consts],
*list(chain(consts, cargs, adjoint_block.arguments)),
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

QYieldOp([out[-1]])
Expand Down
7 changes: 7 additions & 0 deletions frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
# to avoid elimination by the compiler
kwargs["main_function"] = public

const_args = core.jaxpr_const_args(call_jaxpr.jaxpr)
const_arg_avals = [core.shaped_abstractify(c) for c in const_args]
num_const_args = len(const_arg_avals)

kwargs["in_avals"] = const_arg_avals + call_jaxpr.in_avals
kwargs["num_const_args"] = num_const_args

func_op = mlir.lower_jaxpr_to_fun(**kwargs)

if isinstance(callable_, qml.QNode):
Expand Down
Loading