diff --git a/.dep-versions b/.dep-versions
index 50e555f78..943b6e869 100644
--- a/.dep-versions
+++ b/.dep-versions
@@ -1,7 +1,7 @@
 # Always update the version check in catalyst.__init__ when changing the JAX version.
 # To update JAX version alongside compatible dependency tags, run the following script:
 # python3 .github/workflows/set_dep_versions.py {JAX_version}
-jax=8bc0ae6da4f6d334d03be7b76be841d72a56d330
+jax=5712de44e97c455faed1fd45532e821ca66d025a
 stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
 llvm=113f01aa82d055410f22a9d03b3468fa68600589
 enzyme=v0.0.203
diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py
index 0b8e6c30b..14b685010 100644
--- a/frontend/catalyst/jax_extras/lowering.py
+++ b/frontend/catalyst/jax_extras/lowering.py
@@ -19,7 +19,8 @@
 import textwrap

 import jax
-from jax._src.dispatch import jaxpr_replicas
+from jax._src import core
+from jax._src.interpreters.pxla import _jaxpr_replicas as jaxpr_replicas
 from jax._src.effects import ordered_effects as jax_ordered_effects
 from jax._src.interpreters.mlir import _module_name_regex
 from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext
@@ -138,6 +139,11 @@ def custom_lower_jaxpr_to_module(
     module_name = _module_name_regex.sub("_", module_name)
     ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)

+    const_args = core.jaxpr_const_args(jaxpr.jaxpr)
+    const_arg_avals = [core.shaped_abstractify(c) for c in const_args]
+    num_const_args = len(const_arg_avals)
+    in_avals = const_arg_avals + jaxpr.in_avals
+
     # Use main_function=False to preserve the function name (e.g., "jit_func")
     # instead of renaming it to "main"
     lower_jaxpr_to_fun(
@@ -145,6 +151,8 @@ def custom_lower_jaxpr_to_module(
         func_name,
         jaxpr,
         effects,
+        num_const_args=num_const_args,
+        in_avals=in_avals,
         main_function=False,
         replicated_args=replicated_args,
         arg_shardings=arg_shardings,
diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py
index c28b1120e..82f2150e1 100644
--- a/frontend/catalyst/jax_extras/patches.py
+++ b/frontend/catalyst/jax_extras/patches.py
@@ -424,5 +424,43 @@ def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **

     lax._dyn_shape_staging_rule = patched_dyn_shape_staging_rule

+    # Patch multi_broadcast_in_dim to handle dynamic shapes in JAX 0.7+
+    import jax._src.interpreters.mlir as mlir
+
+    def patched_multi_broadcast_in_dim(ctx, ops, ops_avals, out_shape, out_sharding=None):
+        """Patched version that uses DShapedArray for dynamic shapes."""
+        out = []
+        for op, op_aval in zip(ops, ops_avals):
+            op_aval_shape = op_aval.shape
+            op_aval_sharding = getattr(op_aval, "sharding", None)
+
+            # Use DShapedArray if shape contains dynamic dimensions
+            if core.is_constant_shape(out_shape):
+                out_aval = core.ShapedArray(out_shape, op_aval.dtype, sharding=out_sharding)
+            else:
+                # DShapedArray doesn't support sharding parameter
+                out_aval = core.DShapedArray(
+                    out_shape, op_aval.dtype, weak_type=getattr(op_aval, "weak_type", False)
+                )
+
+            if core.definitely_equal_shape(op_aval_shape, out_shape):
+                if out_sharding is None or op_aval_sharding == out_sharding:
+                    out.append(op)
+                else:
+                    out.append(mlir.lower_with_sharding_in_types(ctx, op, out_aval))
+            else:
+                assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape)
+                broadcast_dimensions = list(
+                    range(len(out_shape) - len(op_aval_shape), len(out_shape))
+                )
+                b_out = mlir.broadcast_in_dim(
+                    ctx, op, out_aval, broadcast_dimensions=broadcast_dimensions
+                )
+                b_out = mlir.lower_with_sharding_in_types(ctx, b_out, out_aval)
+                out.append(b_out)
+        return out
+
+    mlir.multi_broadcast_in_dim = patched_multi_broadcast_in_dim
+
 except ImportError:
     pass
diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index f6e04dc44..b92bad38c 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -2074,6 +2074,7 @@ def emit_branches(preds, branch_jaxprs, ip):
             [mlir.ir_constant(c) for c in true_jaxpr.consts],  # is never hit in our tests
             *flat_args_plus_consts,
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )
         YieldOp(out)

@@ -2094,6 +2095,7 @@ def emit_branches(preds, branch_jaxprs, ip):
             [mlir.ir_constants(c) for c in otherwise_jaxpr.consts],
             *flat_args_plus_consts,
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )
         YieldOp(out)

@@ -2186,6 +2188,7 @@ def _while_loop_lowering(
             [mlir.ir_constants(c) for c in cond_jaxpr.consts],
             *params,
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )

         pred_extracted = TensorExtractOp(ir.IntegerType.get_signless(1), pred, []).result
@@ -2207,6 +2210,7 @@ def _while_loop_lowering(
             [mlir.ir_constants(c) for c in cond_jaxpr.consts],
             *params,
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )
         YieldOp(out)

@@ -2344,6 +2348,7 @@ def _cast_to_index(p):
             [mlir.ir_constants(c) for c in body_jaxpr.consts],
             *loop_params,
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )
         YieldOp(out)

@@ -2474,6 +2479,7 @@ def _adjoint_lowering(
             [mlir.ir_constants(c) for c in jaxpr.consts],
             *list(chain(consts, cargs, adjoint_block.arguments)),
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )
         QYieldOp([out[-1]])

diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py
index 31f02fe30..1e8a955fd 100644
--- a/frontend/catalyst/jax_primitives_utils.py
+++ b/frontend/catalyst/jax_primitives_utils.py
@@ -148,6 +148,13 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
     # to avoid elimination by the compiler
     kwargs["main_function"] = public

+    const_args = core.jaxpr_const_args(call_jaxpr.jaxpr)
+    const_arg_avals = [core.shaped_abstractify(c) for c in const_args]
+    num_const_args = len(const_arg_avals)
+
+    kwargs["in_avals"] = const_arg_avals + call_jaxpr.in_avals
+    kwargs["num_const_args"] = num_const_args
+
     func_op = mlir.lower_jaxpr_to_fun(**kwargs)

     if isinstance(callable_, qml.QNode):
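
Note: the same const-args prolog appears in both custom_lower_jaxpr_to_module and lower_callable_to_funcop above. As a reading aid only, here is a minimal sketch of that pattern factored into a hypothetical helper (const_arg_lowering_info is not part of this change); it assumes the internal jax._src.core API (jaxpr_const_args, shaped_abstractify) of the JAX commit pinned in .dep-versions.

    # Illustrative sketch, not part of the diff: mirrors the prolog added above.
    from jax._src import core

    def const_arg_lowering_info(closed_jaxpr):
        """Hypothetical helper: derive the (num_const_args, in_avals) pair passed to
        lower_jaxpr_to_fun once closed-over constants become leading function
        arguments (newer JAX internals, per the pinned commit)."""
        const_args = core.jaxpr_const_args(closed_jaxpr.jaxpr)                # hoisted constants
        const_arg_avals = [core.shaped_abstractify(c) for c in const_args]    # their abstract values
        return len(const_arg_avals), const_arg_avals + closed_jaxpr.in_avals  # prepend to jaxpr inputs

Callers would then forward these as num_const_args=... and in_avals=..., alongside threading jax_ctx.const_lowering into every nested mlir.jaxpr_subcomp call, as the hunks in jax_primitives.py do.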