
Commit 8301c30

yashk2810 and mattjj authored and committed
Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.
* Wherever we were internally setting `check_rep=False` explicitly, such as in `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, and `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead, pass through the incoming `check_rep` value so that it can be True unless the user passed `check_rep=False`.
* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used, so that if `check_rep=False` on `shard_map`, JAX sets `vma` in `ShapedArray` to an empty `frozenset`.
* Because of point (2), if `check_rep=True`, we can't internally set the `in_specs` and `out_specs` of shmap to all manual axes of the mesh on the 0th dim. They need to be whatever axes the argument was actually varying over.

Co-authored-by: Matthew Johnson <[email protected]>

PiperOrigin-RevId: 745276474
1 parent b4629c2 commit 8301c30
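
For orientation, here is a rough sketch (not part of the diff) of the gating pattern these changes introduce: the new internal `config._check_rep` state is set from the user's `check_rep` argument wherever `extend_axis_env_nd` is entered, and the vma-related helpers bail out early when it is False. The `illustrative_vma_rule` name below is made up; its guards mirror the ones added to `core.standard_vma_rule` and the collective abstract-eval rules.

from jax._src import config

def illustrative_vma_rule(*avals):
  # Mirrors the guards added in this commit: with check_rep=False the
  # varying-axes machinery is skipped and everything reports an empty vma.
  if not config.varying_axes_in_types.value:
    return frozenset()
  if not config._check_rep.value:
    return frozenset()
  # Otherwise a primitive's output varies over the union of its inputs' axes.
  return frozenset().union(*[a.vma for a in avals]) if avals else frozenset()

# Inside shard_map's tracing paths the state is scoped roughly like:
#   with core.extend_axis_env_nd(mesh.shape.items()), config._check_rep(check_rep):
#     ...trace the body...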

File tree: 12 files changed, +355 -195 lines


jax/_src/checkify.py

Lines changed: 5 additions & 3 deletions

@@ -966,13 +966,15 @@ def shard_map_error_check(
   new_vals_in = [*err_vals, *vals_in]
   in_avals = list(map(core.get_aval, new_vals_in))
   auto = kwargs.get('auto')
+  check_rep = kwargs.get('check_rep')
   for i, v in enumerate(in_avals):
     if not (sharder := core.shard_aval_handlers.get(type(v))):
       raise ValueError(f'Unsupported aval type: {type(v)}')
-    in_avals[i] = sharder(mesh, auto, new_in_names[i], v)
+    in_avals[i] = sharder(mesh, auto, check_rep, new_in_names[i], v)

   with (shard_map._extend_axis_env(mesh, auto),
-        mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
+        mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto)),
+        config._check_rep(check_rep)):
     # jaxpr to checked_jaxpr
     checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
         pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
@@ -985,7 +987,7 @@ def expand_errors_leading_dim(*xs):
     errs = [lax.expand_dims(e, [0]) for e in errs]
     return *errs, *outs

-  with core.extend_axis_env_nd(mesh.shape.items()):
+  with core.extend_axis_env_nd(mesh.shape.items()), config._check_rep(check_rep):
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(expand_errors_leading_dim,
                      debug_info=checked_jaxpr.jaxpr.debug_info),

jax/_src/config.py

Lines changed: 8 additions & 0 deletions

@@ -240,6 +240,7 @@ def trace_context():
       disable_jit.value,
       debug_key_reuse.value,
       jax_xla_profile_version.value,
+      _check_rep.value,
       # Technically this affects jaxpr->stablehlo lowering, not tracing.
       hlo_source_file_canonicalization_regex.value,
       pgle_profiling_runs.value,
@@ -1099,6 +1100,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
           ' transpose rewrite machinery in shard_map'),
     include_in_jit_key=True)

+# TODO make it so people don't use this, this is internal...
+_check_rep = bool_state(
+    name='check_rep',
+    default=False,
+    help='internal implementation detail of shard_map, DO NOT USE',
+    include_in_jit_key=True)
+
 softmax_custom_jvp = bool_state(
     name='jax_softmax_custom_jvp',
     default=False,
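
Since `_check_rep` is defined with `bool_state(..., include_in_jit_key=True)` and added to `trace_context()`, it behaves like any other JAX config state: readable via `.value`, overridable for a scope by calling it as a context manager, and folded into the tracing/jit cache key. A minimal usage sketch (illustrative, not from the diff):

from jax._src import config

print(config._check_rep.value)      # False: the default declared above
with config._check_rep(True):
  # Scoped override, which is how the shard_map internals use it around
  # extend_axis_env_nd in the hunks elsewhere in this commit.
  print(config._check_rep.value)    # True
print(config._check_rep.value)      # False again once the scope exits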

jax/_src/core.py

Lines changed: 13 additions & 4 deletions

@@ -1895,12 +1895,15 @@ def str_short_aval(shape, dtype, mesh, spec, vma,
   return f'{dt_str}[{shapestr}]{vma}{mesh_axes}'

 def get_vma(vma, mesh):
+  assert isinstance(vma, frozenset)
+  return vma
+  if mesh.empty:
+    return vma
   for i in vma:
     if mesh._name_to_type[i] != AxisType.Manual:
       raise ValueError(
           "Axes mentioned in `vma` field of ShapedArray should"
           f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}")
-  assert isinstance(vma, frozenset)
   return vma

 class ShapedArray(UnshapedArray):
@@ -1994,6 +1997,8 @@ def primal_dtype_to_tangent_dtype(primal_dtype):
 def standard_insert_pbroadcast(*args):
   if not config.varying_axes_in_types.value:
     return args
+  if not config._check_rep.value:
+    return args
   if not args:
     return args
   # TODO(yashkatariya): Move pbroadcast out of shard_map
@@ -2005,6 +2010,10 @@ def standard_insert_pbroadcast(*args):
           if out_vma - src else arg for arg, src in zip(args, in_vma)]

 def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]:
+  if not config.varying_axes_in_types.value:
+    return frozenset()
+  if not config._check_rep.value:
+    return frozenset()
   avals = tuple(a for a in avals if a is not abstract_token)
   if not avals:
     return frozenset()
@@ -2567,9 +2576,9 @@ def unmapped_aval(size: AxisSize, axis: int | None,

 def _map_shaped_array(
     size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
-  assert axis is None or aval.shape[axis] == size
-  # TODO: Extend the named shape
-  if axis is None: return aval
+  # assert axis is None or aval.shape[axis] == size
+  if axis is None:
+    return aval
   sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis))
   return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
                      weak_type=aval.weak_type, sharding=sharding, vma=aval.vma)
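
As a worked illustration of what `standard_insert_pbroadcast` does once the gate is on (simplified, not the real helper): it takes the union of the operands' varying manual axes and pbroadcasts each operand that is missing some of those axes, so the primitive's inputs end up consistently typed. The function below is a standalone sketch over plain frozensets.

def insert_pbroadcast_sketch(args, vmas):
  # args: operand values; vmas: the frozenset of manual axes each one varies over.
  out_vma = frozenset().union(*vmas) if vmas else frozenset()
  fixed = []
  for arg, src in zip(args, vmas):
    missing = out_vma - src
    # The real code calls jax.experimental.shard_map.pbroadcast(arg, tuple(missing)).
    fixed.append(('pbroadcast', arg, tuple(sorted(missing))) if missing else arg)
  return fixed

# e.g. insert_pbroadcast_sketch(['a', 'b'], [frozenset({'x'}), frozenset()])
# -> ['a', ('pbroadcast', 'b', ('x',))]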

jax/_src/interpreters/batching.py

Lines changed: 25 additions & 6 deletions

@@ -408,6 +408,10 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
   @property
   def aval(self):
     aval = core.get_aval(self.val)
+    if self._trace.axis_data.spmd_name is not None:
+      if config._check_rep.value and config.varying_axes_in_types.value:
+        aval = aval.update(
+            vma=aval.vma - frozenset(self._trace.axis_data.spmd_name))
     if self.batch_dim is not_mapped:
       return aval
     elif type(self.batch_dim) is int:
@@ -771,10 +775,17 @@ def _batch_jaxpr2(
       handle_ragged(closed_jaxpr.in_avals, dim, aval)
       if isinstance(dim, RaggedAxis) else (dim, aval)
       for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
-  avals_in2 = [core.unmapped_aval(axis_data.size, b, aval,
-                                  axis_data.explicit_mesh_axis)
-               if b is not not_mapped else aval
-               for aval, b in unsafe_zip(avals_in, in_axes2)]
+  avals_in2 = []
+  for aval, b in unsafe_zip(avals_in, in_axes2):
+    if b is not_mapped:
+      avals_in2.append(aval)
+    else:
+      aval = core.unmapped_aval(
+          axis_data.size, b, aval, axis_data.explicit_mesh_axis)
+      if axis_data.spmd_name is not None:
+        if config._check_rep.value and config.varying_axes_in_types.value:
+          aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name))  # type: ignore
+      avals_in2.append(aval)
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
   return core.ClosedJaxpr(jaxpr_out, consts), out_axes()

@@ -1111,8 +1122,16 @@ def broadcast(x, sz, axis, mesh_axis=None):
   # TODO(dougalm, yashkatariya): Delete this context manager once we figure
   # out how to ensure jaxpr arguments always have the context mesh.
   with mesh_lib.use_abstract_mesh(sharding.mesh):
-    return jax.lax.broadcast_in_dim(x, shape, broadcast_dims,
-                                    out_sharding=sharding)
+    x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
+    if config._check_rep.value and config.varying_axes_in_types.value:
+      # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026
+      spmd_names = core.get_axis_env().spmd_axis_names
+      if len(spmd_names) > 1:
+        raise NotImplementedError
+      if spmd_names:
+        from jax.experimental.shard_map import pbroadcast
+        x = pbroadcast(x, tuple(spmd_names))
+    return x

 def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
   if dst == jumble_axis:

jax/_src/interpreters/partial_eval.py

Lines changed: 18 additions & 0 deletions

@@ -501,6 +501,24 @@ def partial_eval_wrapper_nounits(
   store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env))
   return (*out_consts, *res)

+@lu.transformation_with_aux2
+def partial_eval_wrapper_nounits2(
+    f: Callable,
+    store: lu.Store,
+    in_knowns: Sequence[bool],
+    in_avals: Sequence[AbstractValue],
+    *in_consts: Any):
+  in_avals_, in_consts_ = iter(in_avals), iter(in_consts)
+  in_pvals = [PartialVal.known(next(in_consts_)) if known else
+              PartialVal.unknown(next(in_avals_)) for known in in_knowns]
+  sentinel = object()
+  assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
+  jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals)
+  out_knowns, _, out_consts = partition_pvals(out_pvals)
+  res_avals = [core.typeof(r) for r in res]
+  store.store((*maybe_fwds, out_knowns, res_avals, jaxpr, env))
+  return (*out_consts, *res)
+
 custom_partial_eval_rules: dict[Primitive, Callable] = {}
 call_partial_eval_rules: dict[Primitive, Callable] = {}
 call_param_updaters: dict[Primitive, Callable] = {}

jax/_src/lax/control_flow/loops.py

Lines changed: 4 additions & 8 deletions

@@ -550,9 +550,10 @@ def _split_leading(sz, x):
 def _concat(a, b): return lax.concatenate([a, b], 0)

 def _empty_array(prefix, length_spec, aval):
+  from jax.experimental.shard_map import pbroadcast
   sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec))
-  return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
-                       out_sharding=sharding)
+  empty = pbroadcast(lax.empty(aval.dtype), tuple(aval.vma))
+  return lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding)

 eval_jaxpr_p = core.Primitive('eval_jaxpr')
 eval_jaxpr_p.multiple_results = True
@@ -2248,12 +2249,7 @@ def _batch_and_remainder(x, batch_size: int):
   return scan_tree, remainder_tree

 @api_boundary
-def map(
-    f,
-    xs,
-    *,
-    batch_size: int | None = None,
-):
+def map(f, xs, *, batch_size: int | None = None):
   """Map a function over leading array axes.

   Like Python's builtin map, except inputs and outputs are in the form of

jax/_src/lax/parallel.py

Lines changed: 15 additions & 2 deletions

@@ -117,6 +117,8 @@ def psum(x, axis_name, *, axis_index_groups=None):
   """
   if not isinstance(axis_name, (tuple, list)):
     axis_name = (axis_name,)
+  if not axis_name:
+    return x
   if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
     raise ValueError("axis_index_groups only supported for sums over just named axes")
   _validate_reduce_axis_index_groups(axis_index_groups)
@@ -141,7 +143,7 @@ def pos_reduce(x):
     size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes])
     out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves)
   else:
-    if config.varying_axes_in_types.value:
+    if config.varying_axes_in_types.value and config._check_rep.value:
       out_flat = bind_psum2_p(leaves, axes=tuple(axis_name),
                               axis_index_groups=axis_index_groups)
     else:
@@ -828,6 +830,9 @@ def _psum2_abstract_eval(name, *args, axes, axis_index_groups):
   if not config.varying_axes_in_types.value:
     return psum_p.abstract_eval(
         *args, axes=axes, axis_index_groups=axis_index_groups)
+  if not config._check_rep.value:
+    return psum_p.abstract_eval(
+        *args, axes=axes, axis_index_groups=axis_index_groups)

   assert isinstance(axes, tuple)
   _check_axis_names(axes)
@@ -863,6 +868,9 @@ def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
   if not config.varying_axes_in_types.value:
     return _allreduce_effectful_abstract_eval(
         *args, axes=axes, axis_index_groups=axis_index_groups)
+  if not config._check_rep.value:
+    return _allreduce_effectful_abstract_eval(
+        *args, axes=axes, axis_index_groups=axis_index_groups)
   return _psum2_abstract_eval(name, *args, axes=axes,
                               axis_index_groups=axis_index_groups)

@@ -1411,6 +1419,8 @@ def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in,
 def insert_collective_pbroadcast(axis_name, x):
   if not config.varying_axes_in_types.value:
     return x
+  if not config._check_rep.value:
+    return x

   from jax.experimental import shard_map
   axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
@@ -1546,6 +1556,8 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
 def collective_vma_rule(prim_name, axis_name, x_aval):
   if not config.varying_axes_in_types.value:
     return frozenset()
+  if not config._check_rep.value:
+    return frozenset()
   axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
   if any(a not in x_aval.vma for a in axis_name):
     raise ValueError(
@@ -1912,7 +1924,8 @@ def _axis_index_effectful_abstract_eval(*, axis_name):
   mesh = get_abstract_mesh()
   sharding = NamedSharding(mesh, P())
   vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset())
-         if config.varying_axes_in_types.value else frozenset())
+         if config.varying_axes_in_types.value and config._check_rep.value
+         else frozenset())
   return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect

 def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name):
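
One behavioral detail worth calling out from the first psum hunk: an empty tuple of axis names now returns the input unchanged before any validation or primitive bind. A tiny illustration of that path, as I read the change:

import jax
import jax.numpy as jnp

x = jnp.arange(4.)
y = jax.lax.psum(x, axis_name=())   # hits the new `if not axis_name: return x` early return
assert (y == x).all()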

jax/_src/lax/slicing.py

Lines changed: 4 additions & 1 deletion

@@ -173,6 +173,8 @@ def dynamic_slice(
   else:
     dynamic_sizes = []
     static_sizes = core.canonicalize_shape(slice_sizes)  # type: ignore
+  operand, *start_indices = core.standard_insert_pbroadcast(
+      operand, *start_indices)
   return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
                               slice_sizes=tuple(static_sizes))

@@ -234,7 +236,8 @@ def dynamic_update_slice(
   """
   start_indices = _dynamic_slice_indices(
       operand, start_indices, allow_negative_indices)
-  operand, update = core.standard_insert_pbroadcast(operand, update)
+  operand, update, *start_indices = core.standard_insert_pbroadcast(
+      operand, update, *start_indices)
   return dynamic_update_slice_p.bind(operand, update, *start_indices)

jax/_src/state/types.py

Lines changed: 2 additions & 2 deletions

@@ -456,15 +456,15 @@ def shaped_array_ref(
     shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef:
   return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type))

-def _shard_ref(mesh, auto, names, ref_aval: AbstractRef):
+def _shard_ref(mesh, auto, check_rep, names, ref_aval: AbstractRef):
   del mesh
   if names:
     # Can't actually shard a ref, can only close over it.
     raise NotImplementedError("Can't shard a Ref.")
   return ref_aval
 core.shard_aval_handlers[AbstractRef] = _shard_ref

-def _unshard_ref(mesh, names, ref_aval: AbstractRef):
+def _unshard_ref(mesh, check_rep, names, ref_aval: AbstractRef):
   del mesh
   if names:
     # Can't actually shard a ref, can only close over it.
