 import jax
 from jax import tree_util
 from jax._src import core
+from jax._src import config
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
@@ -325,9 +326,10 @@ def ppermute(x, axis_name, perm):
325326 """
326327 if not isinstance (axis_name , (list , tuple )):
327328 axis_name = (axis_name ,)
328- return tree_util .tree_map (
329- partial (ppermute_p .bind , axis_name = axis_name ,
330- perm = tuple (map (tuple , perm ))), x )
329+ def bind (leaf ):
330+ leaf = insert_collective_pbroadcast (axis_name , leaf )
331+ return ppermute_p .bind (leaf , axis_name = axis_name , perm = tuple (map (tuple , perm )))
332+ return tree_util .tree_map (bind , x )
331333
332334def pshuffle (x , axis_name , perm ):
333335 """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding
@@ -447,6 +449,7 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):
       else:  # concat_axis < split_axis
         x = lax.expand_dims(x, (concat_axis,))  # insert the new axis
         split_axis += 1   # we have a new axis before split_axis now
+    x = insert_collective_pbroadcast(axis_name, x)
     result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                                axis_name=axis_name,
                                axis_index_groups=axis_index_groups,
@@ -975,6 +978,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):

 def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
   _check_axis_names(axis_name)
+  collective_vma_rule('ppermute', axis_name, x)
   return x

 ppermute_p = core.Primitive('ppermute')
@@ -1189,7 +1193,8 @@ def _all_to_all_effectful_abstract_eval(
   assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
   shape[split_axis] //= axis_size
   shape[concat_axis] *= axis_size
-  out_aval = input_aval.update(shape=tuple(shape), weak_type=False)
+  vma = collective_vma_rule('all_to_all', axis_name, input_aval)
+  out_aval = input_aval.update(shape=tuple(shape), weak_type=False, vma=vma)
   effects = {*map(core.NamedAxisEffect, axis_name)}
   return out_aval, effects

@@ -1313,6 +1318,19 @@ def _ragged_all_to_all_transpose(
 mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
 batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')

+def insert_collective_pbroadcast(axis_name, x):
+  if not config.varying_axes_in_types.value:
+    return x
+
+  from jax.experimental import shard_map
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
+  aval = core.get_aval(x)
+  names_union = set(axis_name) | aval.vma
+  pbroadcast_axis_name = tuple(n for n in names_union if n not in aval.vma)
+  if pbroadcast_axis_name:
+    x = shard_map.pbroadcast(x, pbroadcast_axis_name)
+  return x
+
 def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
   """Gather values of x across all replicas.

@@ -1382,6 +1400,7 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
   axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
   def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
     return all_gather_p.bind(
         leaf,
         all_gather_dimension=canonicalize_axis(
@@ -1434,6 +1453,19 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
       **other_args).results


+def collective_vma_rule(prim_name, axis_name, x_aval):
+  if not config.varying_axes_in_types.value:
+    return frozenset()
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
+  if any(a not in x_aval.vma for a in axis_name):
+    raise ValueError(
+        f"Collective {prim_name} must be applied to a device-varying "
+        f"type, but got {x_aval.vma} for collective acting "
+        f"over axis name {axis_name}. Please open an issue at "
+        "https://github.com/jax-ml/jax/issues and as a temporary "
+        "workaround pass the check_rep=False argument to shard_map")
+  return x_aval.vma
+
 def _all_gather_effectful_abstract_eval(
     x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
 ):
@@ -1445,7 +1477,9 @@ def _all_gather_effectful_abstract_eval(
     new_shape[all_gather_dimension] *= axis_size
   else:
     new_shape.insert(all_gather_dimension, axis_size)
-  return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}
+  out_vma = collective_vma_rule('all_gather', axis_name, x_aval)
+  return (x_aval.update(shape=new_shape, vma=out_vma),
+          {*map(core.NamedAxisEffect, axis_name)})

 def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
   return (psum_scatter(cts, axis_name=axis_name,
@@ -1582,7 +1616,9 @@ def _reduce_scatter_effectful_abstract_eval(
15821616 f"{ scatter_dim_input_size } must match shard count "
15831617 f"{ axis_size } " )
15841618 del new_shape [scatter_dimension ]
1585- return x_aval .update (shape = new_shape ), {* map (core .NamedAxisEffect , axis_name )}
1619+ vma = collective_vma_rule ('reduce_scatter' , axis_name , x_aval )
1620+ return (x_aval .update (shape = new_shape , vma = vma ),
1621+ {* map (core .NamedAxisEffect , axis_name )})
15861622
15871623
15881624def _reduce_scatter_transpose_rule (cts , x , * , axis_name , scatter_dimension ,
@@ -1726,13 +1762,11 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
     axis_name = axis_name,
   axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
-  bind = partial(
-      reduce_scatter_p.bind,
-      axis_name=axis_name,
-      scatter_dimension=scatter_dimension,
-      axis_index_groups=axis_index_groups,
-      axis_size=axis_size,
-      tiled=tiled)
+  def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
+    return reduce_scatter_p.bind(
+        leaf, axis_name=axis_name, scatter_dimension=scatter_dimension,
+        axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled)
   return tree_util.tree_map(bind, x)

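For orientation, here is a small usage sketch of the behaviour this diff enables: inside shard_map, a collective such as ppermute applied to a value that is still replicated over the collective's axis now gets an automatic pbroadcast (via the new insert_collective_pbroadcast helper) so the primitive binds on a device-varying operand, and the abstract-eval rules propagate the vma set through collective_vma_rule. The sketch below is illustrative only and not part of the commit; the mesh construction, array shape, and permutation are assumptions, and it presumes at least two devices with the varying_axes_in_types behaviour enabled.

# Illustrative sketch only (not from the commit). Assumes >= 2 devices; the
# mesh, shape, and permutation are made up for the example.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))

def body(a):
  # `a` arrives replicated over 'x' (its varying-axes set is empty), which is
  # exactly the case where the ppermute wrapper now pbroadcasts over 'x'
  # before binding ppermute_p.
  return jax.lax.ppermute(a, 'x', perm=[(0, 1), (1, 0)])

f = shard_map(body, mesh=mesh, in_specs=P(), out_specs=P('x'))
print(f(jnp.arange(4.)))  # per-device results stacked along the 'x' axis

The same eager pbroadcast is applied by the wrappers for all_to_all, all_gather, and psum_scatter elsewhere in this diff.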