
Commit 85eb58a

yashk2810 authored and charleshofer committed

Remove jax_varying_axes_in_types config and rewrite from shard_map_p

PiperOrigin-RevId: 748545142

1 parent 46fb400 commit 85eb58a

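Every hunk below follows the same pattern: a rule that was gated on both jax_varying_axes_in_types and check_rep now gates on check_rep alone, since the varying-axes-in-types behavior is no longer optional. A minimal before/after sketch of that guard shape, using hypothetical stand-in flags rather than JAX's real config objects:

  # Hypothetical stand-ins for the two config flags; not JAX's config module.
  class _Flag:
    def __init__(self, value: bool):
      self.value = value

  varying_axes_in_types = _Flag(True)  # removed by this commit
  _check_rep = _Flag(True)             # the surviving gate

  def vma_rule_before(axes: frozenset) -> frozenset:
    if not varying_axes_in_types.value:  # old early-out, deleted everywhere below
      return frozenset()
    if not _check_rep.value:
      return frozenset()
    return axes

  def vma_rule_after(axes: frozenset) -> frozenset:
    if not _check_rep.value:  # the single remaining gate
      return frozenset()
    return axes

  assert vma_rule_before(frozenset({'i'})) == vma_rule_after(frozenset({'i'}))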
7 files changed: 52 additions, 115 deletions

jax/_src/config.py (0 additions, 9 deletions)

@@ -235,7 +235,6 @@ def trace_context():
     threefry_partitionable.value,
     threefry_gpu_kernel_lowering.value,
     use_direct_linearize.value,
-    varying_axes_in_types.value,
     softmax_custom_jvp.value,
     disable_jit.value,
     debug_key_reuse.value,

@@ -1092,14 +1091,6 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
     help=('Use direct linearization instead JVP followed by partial eval'),
     include_in_jit_key=True)

-varying_axes_in_types = bool_state(
-    name='jax_varying_axes_in_types',
-    default=True,
-    help=('Adds varying manual axes to ShapedArray to track which mesh axes the'
-          ' array is varying over. This will help to remove the efficient'
-          ' transpose rewrite machinery in shard_map'),
-    include_in_jit_key=True)
-
 # TODO make it so people don't use this, this is internal...
 _check_rep = bool_state(
     name='check_rep',

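The removed flag was registered with include_in_jit_key=True and listed in trace_context(), so it contributed to jit's compilation-cache key; deleting it removes one entry from that key. A hedged, simplified sketch of that mechanism (the names and shapes below are illustrative, not JAX internals):

  # Hypothetical, simplified model of how include_in_jit_key flags fold into
  # the jit cache key.
  def trace_context(flags: dict[str, bool]) -> tuple:
    return tuple(sorted(flags.items()))

  def cache_key(fun_name: str, flags: dict[str, bool]) -> tuple:
    return (fun_name, trace_context(flags))

  before = cache_key('f', {'jax_varying_axes_in_types': True, 'check_rep': True})
  after = cache_key('f', {'check_rep': True})
  assert before != after  # the removed flag no longer splits cache entries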
jax/_src/core.py (0 additions, 6 deletions)

@@ -2006,8 +2006,6 @@ def pvary(x, axis_name):
 pvary_p.def_impl(lambda *args, axes, axis_index_groups: args)

 def _pvary_abstract_eval(*args, axes, axis_index_groups):
-  if not config.varying_axes_in_types.value:
-    return args
   if not config._check_rep.value:
     return args
   assert isinstance(axes, tuple)

@@ -2027,8 +2025,6 @@ def _pvary_abstract_eval(*args, axes, axis_index_groups):


 def standard_insert_pvary(*args):
-  if not config.varying_axes_in_types.value:
-    return args
   if not config._check_rep.value:
     return args
   if not args:

@@ -2040,8 +2036,6 @@ def standard_insert_pvary(*args):
           if out_vma - src else arg for arg, src in zip(args, in_vma)]

 def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]:
-  if not config.varying_axes_in_types.value:
-    return frozenset()
   if not config._check_rep.value:
     return frozenset()
   avals = tuple(a for a in avals if a is not abstract_token)

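These functions maintain the vma set that the deleted help string describes: the manual mesh axes a ShapedArray is varying over. A toy sketch of that bookkeeping, under the assumption that pvary unions axes in and that standard primitives propagate the union (Aval here is a hypothetical stand-in, not jax's ShapedArray):

  from typing import NamedTuple

  class Aval(NamedTuple):
    shape: tuple
    vma: frozenset  # manual mesh axes this value varies over

  def pvary_abstract_eval(aval: Aval, axes: tuple) -> Aval:
    # pvary widens the varying set: the result also varies over `axes`.
    return aval._replace(vma=aval.vma | frozenset(axes))

  def standard_vma_rule(*avals: Aval) -> frozenset:
    # net effect once standard_insert_pvary has lined the inputs up: the
    # output varies over the union of the inputs' varying axes.
    return frozenset().union(*(a.vma for a in avals))

  x = Aval(shape=(8,), vma=frozenset({'i'}))
  y = pvary_abstract_eval(x, axes=('j',))
  assert standard_vma_rule(x, y) == frozenset({'i', 'j'})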
jax/_src/interpreters/batching.py (3 additions, 3 deletions)

@@ -407,7 +407,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
   def aval(self):
     aval = core.get_aval(self.val)
     if self._trace.axis_data.spmd_name is not None:
-      if config._check_rep.value and config.varying_axes_in_types.value:
+      if config._check_rep.value:
         aval = aval.update(
             vma=aval.vma - frozenset(self._trace.axis_data.spmd_name))
     if self.batch_dim is not_mapped:

@@ -776,7 +776,7 @@ def _batch_jaxpr2(
     aval = core.unmapped_aval(
         axis_data.size, b, aval, axis_data.explicit_mesh_axis)
     if axis_data.spmd_name is not None:
-      if config._check_rep.value and config.varying_axes_in_types.value:
+      if config._check_rep.value:
        aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name))  # type: ignore
     avals_in2.append(aval)
   jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)

@@ -1111,7 +1111,7 @@ def broadcast(x, sz, axis, mesh_axis=None):
   # out how to ensure jaxpr arguments always have the context mesh.
   with mesh_lib.use_abstract_mesh(sharding.mesh):
     x = jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
-  if config._check_rep.value and config.varying_axes_in_types.value:
+  if config._check_rep.value:
     # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026
     spmd_names = core.get_axis_env().spmd_axis_names
     if len(spmd_names) > 1:

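Both surviving gates in the first two hunks wrap symmetric set updates on vma: the tracer's visible aval subtracts the spmd axis names, and building the batched jaxpr's input avals unions them back. A minimal sketch of just that set arithmetic, with made-up axis names:

  spmd_name = ('data',)
  vma = frozenset({'data', 'model'})

  tracer_vma = vma - frozenset(spmd_name)           # BatchTracer.aval path
  assert tracer_vma == frozenset({'model'})

  jaxpr_in_vma = tracer_vma | frozenset(spmd_name)  # _batch_jaxpr2 path
  assert jaxpr_in_vma == vma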
jax/_src/lax/parallel.py (2 additions, 13 deletions)

@@ -144,7 +144,7 @@ def pos_reduce(x):
     size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes])
     out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves)
   else:
-    if config.varying_axes_in_types.value and config._check_rep.value:
+    if config._check_rep.value:
       out_flat = bind_psum_invariant(
           leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups)
     else:

@@ -827,9 +827,6 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
   return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}

 def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups):
-  if not config.varying_axes_in_types.value:
-    return psum_p.abstract_eval(
-        *args, axes=axes, axis_index_groups=axis_index_groups)
   if not config._check_rep.value:
     return psum_p.abstract_eval(
         *args, axes=axes, axis_index_groups=axis_index_groups)

@@ -865,9 +862,6 @@ def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups):

 # TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval
 def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
-  if not config.varying_axes_in_types.value:
-    return _allreduce_effectful_abstract_eval(
-        *args, axes=axes, axis_index_groups=axis_index_groups)
   if not config._check_rep.value:
     return _allreduce_effectful_abstract_eval(
         *args, axes=axes, axis_index_groups=axis_index_groups)

@@ -1417,8 +1411,6 @@ def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in,
 batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')

 def insert_collective_pvary(axis_name, x):
-  if not config.varying_axes_in_types.value:
-    return x
   if not config._check_rep.value:
     return x

@@ -1551,8 +1543,6 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,


 def collective_vma_rule(prim_name, axis_name, x_aval):
-  if not config.varying_axes_in_types.value:
-    return frozenset()
   if not config._check_rep.value:
     return frozenset()
   axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name

@@ -1921,8 +1911,7 @@ def _axis_index_effectful_abstract_eval(*, axis_name):
   mesh = get_abstract_mesh()
   sharding = NamedSharding(mesh, P())
   vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset())
-         if config.varying_axes_in_types.value and config._check_rep.value
-         else frozenset())
+         if config._check_rep.value else frozenset())
   return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect

 def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name):

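With check_rep enabled, psum now always binds psum_invariant, and axis_index's abstract eval marks its result as varying whenever any mesh axis is manual. A hedged sketch of the intuition only, not the real abstract-eval rules:

  # Summing over a mesh axis makes the result invariant along it, so that
  # axis should drop out of the output's vma set.
  def psum_vma(in_vma: frozenset, axes: tuple) -> frozenset:
    return in_vma - frozenset(axes)

  assert psum_vma(frozenset({'i', 'j'}), axes=('i',)) == frozenset({'j'})

  # axis_index, per the last hunk above, goes the other way: under a manual
  # mesh axis its result varies over that axis.
  def axis_index_vma(axis_name: tuple, any_axis_manual: bool) -> frozenset:
    return frozenset(axis_name) if any_axis_manual else frozenset()

  assert axis_index_vma(('i',), any_axis_manual=True) == frozenset({'i'})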