
Commit d83fb10

Authored by rniczh, JerryChen97, paul0403, dime10, and maliasadi
Bump jax to 0.7.1 (#2134)
Bump jax to 0.7.1 (#2134)

**Description of the Change:**

1. The API of `lower_jaxpr_to_fun` changed; update its call sites in Catalyst accordingly.
2. The API of `jaxpr_subcomp` changed; update its call sites in Catalyst accordingly.
3. `_jaxpr_replicas` has moved to `jax._src.interpreters.pxla` in JAX.
4. As of jax 0.7.1, `mlir.multi_broadcast_in_dim` changed in a way that breaks dynamic shapes; patch it to restore dynamic-shape support.

Co-authored-by: JerryChen97 <[email protected]>
Co-authored-by: Paul <[email protected]>
Co-authored-by: David Ittah <[email protected]>
Co-authored-by: Ali Asadi <[email protected]>
Parent commit: 1242748
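For the `_jaxpr_replicas` relocation (item 3 above), one way to stay compatible across the version bump is a small import shim. This is only a sketch, not what the PR does (the `lowering.py` diff below pins the new location directly); both import paths appear verbatim in that diff:

```python
# Version-tolerant import of the jaxpr replica-count helper (sketch).
try:
    # jax >= 0.7.1: the helper moved (and was privatized) to pxla
    from jax._src.interpreters.pxla import _jaxpr_replicas as jaxpr_replicas
except ImportError:
    # jax <= 0.7.0: previous location
    from jax._src.dispatch import jaxpr_replicas
```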

7 files changed: +65 −6 lines

.dep-versions

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,7 +1,7 @@
 # Always update the version check in catalyst.__init__ when changing the JAX version.
 # To update JAX version alongside compatible dependency tags, run the following script:
 # python3 .github/workflows/set_dep_versions.py {JAX_version}
-jax=0.7.0
+jax=0.7.1
 stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
 llvm=113f01aa82d055410f22a9d03b3468fa68600589
 enzyme=v0.0.203
```

frontend/catalyst/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@

 import jaxlib as _jaxlib

-_jaxlib_version = "0.7.0"
+_jaxlib_version = "0.7.1"
 if _jaxlib.__version__ != _jaxlib_version:
     import warnings

```

frontend/catalyst/jax_extras/lowering.py

Lines changed: 15 additions & 2 deletions
```diff
@@ -19,9 +19,10 @@
 import textwrap

 import jax
-from jax._src.dispatch import jaxpr_replicas
+from jax._src import core
 from jax._src.effects import ordered_effects as jax_ordered_effects
 from jax._src.interpreters.mlir import _module_name_regex
+from jax._src.interpreters.pxla import _jaxpr_replicas as jaxpr_replicas
 from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext
 from jax.extend.core import ClosedJaxpr
 from jax.interpreters.mlir import (
@@ -44,7 +45,11 @@

 __all__ = ("jaxpr_to_mlir", "custom_lower_jaxpr_to_module")

-from catalyst.jax_extras.patches import _no_clean_up_dead_vars, get_aval2
+from catalyst.jax_extras.patches import (
+    _no_clean_up_dead_vars,
+    get_aval2,
+    patched_multi_broadcast_in_dim,
+)

 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())
@@ -67,6 +72,7 @@ def jaxpr_to_mlir(jaxpr, func_name, arg_names):
     with Patcher(
         (jax._src.interpreters.partial_eval, "get_aval", get_aval2),
         (jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars),
+        (jax._src.interpreters.mlir, "multi_broadcast_in_dim", patched_multi_broadcast_in_dim),
     ):
         nrep = jaxpr_replicas(jaxpr)
         effects = jax_ordered_effects.filter_in(jaxpr.effects)
@@ -141,13 +147,20 @@ def custom_lower_jaxpr_to_module(
     module_name = _module_name_regex.sub("_", module_name)
     ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)

+    const_args = core.jaxpr_const_args(jaxpr.jaxpr)
+    const_arg_avals = [core.shaped_abstractify(c) for c in const_args]
+    num_const_args = len(const_arg_avals)
+    in_avals = const_arg_avals + jaxpr.in_avals
+
     # Use main_function=False to preserve the function name (e.g., "jit_func")
     # instead of renaming it to "main"
     lower_jaxpr_to_fun(
         ctx,
         func_name,
         jaxpr,
         effects,
+        num_const_args=num_const_args,
+        in_avals=in_avals,
         main_function=False,
         replicated_args=replicated_args,
         arg_names=arg_names,
```
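The new `num_const_args`/`in_avals` keywords reflect jax 0.7.1's `lower_jaxpr_to_fun` treating hoisted constants as leading function arguments. A minimal sketch of where those avals come from, using the same internal helpers as the diff above (note `jaxpr_const_args` may return an empty list depending on jax configuration):

```python
import jax
import jax.numpy as jnp
from jax._src import core

closed = jax.make_jaxpr(lambda x: x + jnp.arange(3.0))(jnp.zeros(3))

# Constants hoisted out of the jaxpr become leading arguments, so their
# avals are prepended to the regular input avals before lowering.
const_args = core.jaxpr_const_args(closed.jaxpr)
const_arg_avals = [core.shaped_abstractify(c) for c in const_args]
in_avals = const_arg_avals + closed.in_avals
```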

frontend/catalyst/jax_extras/patches.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
 import jax._src.interpreters.partial_eval as pe
 from jax._src import config, core, source_info_util
 from jax._src.core import JaxprEqnContext, abstractify, standard_vma_rule
+from jax._src.interpreters import mlir
 from jax._src.interpreters.partial_eval import (
     DynamicJaxprTracer,
     TracingEqn,
@@ -54,6 +55,7 @@
     "patched_make_eqn",
     "patched_dyn_shape_staging_rule",
     "patched_pjit_staging_rule",
+    "patched_multi_broadcast_in_dim",
 )


@@ -322,6 +324,34 @@ def make_eqn_internal(out_avals_list, out_tracers):
     return make_eqn_internal(out_avals, out_tracers)


+def patched_multi_broadcast_in_dim(ctx, ops, ops_avals, out_shape, out_sharding=None):
+    """Patched version that uses DShapedArray for dynamic shapes."""
+    out = []
+    for op, op_aval in zip(ops, ops_avals):
+        op_aval_shape = op_aval.shape
+
+        # Use DShapedArray if the shape contains dynamic dimensions
+        if core.is_constant_shape(out_shape):
+            out_aval = core.ShapedArray(out_shape, op_aval.dtype, sharding=out_sharding)
+        else:
+            # DShapedArray doesn't support the sharding parameter
+            out_aval = core.DShapedArray(
+                out_shape, op_aval.dtype, weak_type=getattr(op_aval, "weak_type", False)
+            )
+
+        if core.definitely_equal_shape(op_aval_shape, out_shape):
+            out.append(op)
+        else:
+            assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape)
+            broadcast_dimensions = list(range(len(out_shape) - len(op_aval_shape), len(out_shape)))
+            b_out = mlir.broadcast_in_dim(
+                ctx, op, out_aval, broadcast_dimensions=broadcast_dimensions
+            )
+            b_out = mlir.lower_with_sharding_in_types(ctx, b_out, out_aval)
+            out.append(b_out)
+    return out
+
+
 def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params):
     """Patched _dyn_shape_staging_rule for dynamic shape handling."""
     eqn, out_tracer = trace.make_eqn(args, out_aval, prim, params, core.no_effects, source_info)
```
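`patched_multi_broadcast_in_dim` only diverges from the upstream implementation when `out_shape` is non-constant, which happens under Catalyst's dynamic-shapes mode. A hypothetical trigger is sketched below; the `abstracted_axes` keyword is real Catalyst API, but whether this exact program reaches the patched path depends on how the broadcast is staged:

```python
import catalyst
import jax.numpy as jnp

@catalyst.qjit(abstracted_axes={0: "n"})  # leading axis is dynamically sized
def f(x):
    # Adding a scalar broadcasts it into a dynamically shaped result,
    # exercising multi_broadcast_in_dim with a non-constant out_shape.
    return x + 1.0

f(jnp.zeros(4))
```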

frontend/catalyst/jax_primitives.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -2243,6 +2243,7 @@ def emit_branches(preds, branch_jaxprs, ip):
             [mlir.ir_constant(c) for c in true_jaxpr.consts],  # is never hit in our tests
             *flat_args_plus_consts,
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )

         YieldOp(out)
@@ -2263,6 +2264,7 @@ def emit_branches(preds, branch_jaxprs, ip):
             [mlir.ir_constants(c) for c in otherwise_jaxpr.consts],
             *flat_args_plus_consts,
             dim_var_values=jax_ctx.dim_var_values,
+            const_lowering=jax_ctx.const_lowering,
         )

         YieldOp(out)
@@ -2333,6 +2335,7 @@ def _switch_lowering(
         [mlir.ir_constant(const) for const in branch_jaxpr.consts],
         *flat_args_plus_consts,
         dim_var_values=jax_ctx.dim_var_values,
+        const_lowering=jax_ctx.const_lowering,
     )

     YieldOp(out)
@@ -2348,6 +2351,7 @@ def _switch_lowering(
         [mlir.ir_constant(const) for const in branch_jaxpr.consts],
         *flat_args_plus_consts,
         dim_var_values=jax_ctx.dim_var_values,
+        const_lowering=jax_ctx.const_lowering,
     )

     YieldOp(out)
@@ -2440,6 +2444,7 @@ def _while_loop_lowering(
         [mlir.ir_constants(c) for c in cond_jaxpr.consts],
         *params,
         dim_var_values=jax_ctx.dim_var_values,
+        const_lowering=jax_ctx.const_lowering,
     )

     pred_extracted = TensorExtractOp(ir.IntegerType.get_signless(1), pred, []).result
@@ -2461,6 +2466,7 @@ def _while_loop_lowering(
         [mlir.ir_constants(c) for c in cond_jaxpr.consts],
         *params,
         dim_var_values=jax_ctx.dim_var_values,
+        const_lowering=jax_ctx.const_lowering,
     )

     YieldOp(out)
@@ -2591,6 +2597,7 @@ def _for_loop_lowering(
         [mlir.ir_constants(c) for c in body_jaxpr.consts],
         *loop_params,
         dim_var_values=jax_ctx.dim_var_values,
+        const_lowering=jax_ctx.const_lowering,
     )
     YieldOp(out)

@@ -2720,6 +2727,7 @@ def _adjoint_lowering(
         [mlir.ir_constants(c) for c in jaxpr.consts],
         *list(chain(consts, cargs, adjoint_block.arguments)),
         dim_var_values=jax_ctx.dim_var_values,
+        const_lowering=jax_ctx.const_lowering,
     )

     QYieldOp([out[-1]])
@@ -2734,7 +2742,7 @@ def _adjoint_lowering(
     def adjoint_pass_injector(_op: ir.Operation) -> ir.WalkResult:
         if _op.name == "transform.named_sequence":
             with ir.InsertionPoint.at_block_begin(_op.regions[0].blocks[0]):
-                adjoint_lowering_pass_op = ApplyRegisteredPassOp(
+                ApplyRegisteredPassOp(
                     result=ir.OpaqueType.get("transform", 'op<"builtin.module">'),
                     target=_op.regions[0].blocks[0].arguments[0],  # just insert at beginning
                     pass_name="adjoint-lowering",
```

frontend/catalyst/jax_primitives_utils.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -178,6 +178,13 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
     kwargs["effects"] = []
     kwargs["main_function"] = False

+    const_args = core.jaxpr_const_args(call_jaxpr.jaxpr)
+    const_arg_avals = [core.shaped_abstractify(c) for c in const_args]
+    num_const_args = len(const_arg_avals)
+
+    kwargs["in_avals"] = const_arg_avals + call_jaxpr.in_avals
+    kwargs["num_const_args"] = num_const_args
+
     func_op = mlir.lower_jaxpr_to_fun(**kwargs)
     if public:
         func_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
```

frontend/catalyst/jax_tracer.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -713,7 +713,8 @@ def lower_jaxpr_to_mlir(jaxpr, func_name, arg_names):
     # JAX internally calls trace_to_jaxpr_dynamic2 during lowering of nested @jit primitives
     # (e.g., in jax.scipy.linalg.expm and jax.scipy.linalg.solve), which triggers two bugs:
     # 1. make_eqn signature changed to include out_tracers parameter
-    # 2. pjit_staging_rule creates JaxprEqn instead of TracingEqn (AssertionError at partial_eval.py:1790)
+    # 2. pjit_staging_rule creates JaxprEqn instead of TracingEqn
+    #    (AssertionError at partial_eval.py:1790)
     with transient_jax_config(
         {"jax_dynamic_shapes": True, "jax_use_shardy_partitioner": False}
     ), Patcher(
```
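The nested-`@jit` lowering path named in that comment is reachable from public jax APIs; a sketch of one trigger, taken from the functions the comment itself names (whether it still hits the patched assertion depends on the jax version):

```python
import catalyst
import jax.numpy as jnp
from jax.scipy.linalg import expm

@catalyst.qjit
def g(A):
    # jax lowers expm through nested @jit primitives, which calls
    # trace_to_jaxpr_dynamic2 during lowering.
    return expm(A)

g(jnp.eye(2))
```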
