Commit 1771936

yashk2810 authored and Google-ML-Automation committed
Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter primitives
PiperOrigin-RevId: 741661360
1 parent b719ac0 commit 1771936
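
Below is a minimal usage sketch of what this commit enables; it is editorial, not part of the diff, and simply mirrors the new shard_map test at the bottom of the change. With the varying-axes-in-types mode turned on, these collectives record the axes they act over in the output aval's vma (varying manual axes) set, and the positional APIs insert a pbroadcast when the operand is not yet varying over those axes. The sketch assumes at least two devices and toggles the mode through the internal jax._src.config state, the same handle the new test uses.

import jax
import numpy as np
from jax._src import config  # internal config state, used the same way as in the new test
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

def f(x):
  # x arrives replicated (empty vma), so all_gather over 'x' now triggers an
  # automatic pbroadcast over 'x' before the collective itself.
  return jax.lax.all_gather(x, 'x')

with config.varying_axes_in_types(True):
  mesh = jax.make_mesh((2,), ('x',))  # assumes >= 2 devices are available
  g = jax.jit(shard_map(f, mesh=mesh, in_specs=P(), out_specs=P('x')))
  print(g.trace(np.arange(8.)).jaxpr)  # expect pbroadcast[axes=('x',)] in the jaxpr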

2 files changed: +64 -13 lines changed


jax/_src/lax/parallel.py

Lines changed: 47 additions & 13 deletions
@@ -25,6 +25,7 @@
 import jax
 from jax import tree_util
 from jax._src import core
+from jax._src import config
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
@@ -325,9 +326,10 @@ def ppermute(x, axis_name, perm):
   """
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
-  return tree_util.tree_map(
-      partial(ppermute_p.bind, axis_name=axis_name,
-              perm=tuple(map(tuple, perm))), x)
+  def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
+    return ppermute_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm)))
+  return tree_util.tree_map(bind, x)
 
 def pshuffle(x, axis_name, perm):
   """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding
@@ -447,6 +449,7 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):
       else: # concat_axis < split_axis
         x = lax.expand_dims(x, (concat_axis,))  # insert the new axis
         split_axis += 1   # we have a new axis before split_axis now
+    x = insert_collective_pbroadcast(axis_name, x)
     result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                                axis_name=axis_name,
                                axis_index_groups=axis_index_groups,
@@ -975,6 +978,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
 
 def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
   _check_axis_names(axis_name)
+  collective_vma_rule('ppermute', axis_name, x)
   return x
 
 ppermute_p = core.Primitive('ppermute')
@@ -1189,7 +1193,8 @@ def _all_to_all_effectful_abstract_eval(
   assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
   shape[split_axis] //= axis_size
   shape[concat_axis] *= axis_size
-  out_aval = input_aval.update(shape=tuple(shape), weak_type=False)
+  vma = collective_vma_rule('all_to_all', axis_name, input_aval)
+  out_aval = input_aval.update(shape=tuple(shape), weak_type=False, vma=vma)
   effects = {*map(core.NamedAxisEffect, axis_name)}
   return out_aval, effects
 
@@ -1313,6 +1318,19 @@ def _ragged_all_to_all_transpose(
 mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
 batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')
 
+def insert_collective_pbroadcast(axis_name, x):
+  if not config.varying_axes_in_types.value:
+    return x
+
+  from jax.experimental import shard_map
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
+  aval = core.get_aval(x)
+  names_union = set(axis_name) | aval.vma
+  pbroadcast_axis_name = tuple(n for n in names_union if n not in aval.vma)
+  if pbroadcast_axis_name:
+    x = shard_map.pbroadcast(x, pbroadcast_axis_name)
+  return x
+
 def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
   """Gather values of x across all replicas.
 
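
For reference, here is a reduced, editorial sketch (not part of the diff) of the axis-selection logic inside the new insert_collective_pbroadcast helper: it pbroadcasts a value over exactly those collective axes the value is not already varying over. The standalone function name below is made up for illustration.

def missing_pbroadcast_axes(axis_name, vma):
  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  names_union = set(axis_name) | vma
  # Only axes the value is not yet varying over need a pbroadcast.
  return tuple(n for n in names_union if n not in vma)

assert missing_pbroadcast_axes('x', frozenset()) == ('x',)              # replicated operand
assert missing_pbroadcast_axes(('x', 'y'), frozenset({'y'})) == ('x',)  # already varying over 'y'
assert missing_pbroadcast_axes('x', frozenset({'x'})) == ()             # nothing to insert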
@@ -1382,6 +1400,7 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
   axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
   def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
     return all_gather_p.bind(
         leaf,
         all_gather_dimension=canonicalize_axis(
@@ -1434,6 +1453,19 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
       **other_args).results
 
 
+def collective_vma_rule(prim_name, axis_name, x_aval):
+  if not config.varying_axes_in_types.value:
+    return frozenset()
+  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
+  if any(a not in x_aval.vma for a in axis_name):
+    raise ValueError(
+        f"Collective {prim_name} must be applied to a device-varying "
+        f" type, but got {x_aval.vma} for collective acting "
+        f"over axis name {axis_name}. Please open an issue at "
+        "https://github.com/jax-ml/jax/issues and as a temporary "
+        "workaround pass the check_rep=False argument to shard_map")
+  return x_aval.vma
+
 def _all_gather_effectful_abstract_eval(
     x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
 ):
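
Similarly, an editorial sketch (not part of the diff) of the check that collective_vma_rule performs: every axis the collective acts over must already be in the operand aval's vma, otherwise the rule raises; on success the operand's vma is returned and threaded into the output aval by the updated abstract-eval rules below. The function name here is invented for illustration.

def check_collective_vma(prim_name, axis_name, vma):
  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  if any(a not in vma for a in axis_name):
    raise ValueError(f"Collective {prim_name} must be applied to a device-varying type")
  return vma  # the output aval keeps the same varying axes

assert check_collective_vma('all_gather', 'x', frozenset({'x'})) == frozenset({'x'})
try:
  check_collective_vma('all_gather', 'x', frozenset())  # a replicated operand is rejected
except ValueError:
  pass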
@@ -1445,7 +1477,9 @@ def _all_gather_effectful_abstract_eval(
     new_shape[all_gather_dimension] *= axis_size
   else:
     new_shape.insert(all_gather_dimension, axis_size)
-  return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}
+  out_vma = collective_vma_rule('all_gather', axis_name, x_aval)
+  return (x_aval.update(shape=new_shape, vma=out_vma),
+          {*map(core.NamedAxisEffect, axis_name)})
 
 def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
   return (psum_scatter(cts, axis_name=axis_name,
@@ -1582,7 +1616,9 @@ def _reduce_scatter_effectful_abstract_eval(
                        f"{scatter_dim_input_size} must match shard count "
                        f"{axis_size}")
     del new_shape[scatter_dimension]
-  return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}
+  vma = collective_vma_rule('reduce_scatter', axis_name, x_aval)
+  return (x_aval.update(shape=new_shape, vma=vma),
+          {*map(core.NamedAxisEffect, axis_name)})
 
 
 def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,
@@ -1726,13 +1762,11 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
     axis_name = axis_name,
   axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
-  bind = partial(
-      reduce_scatter_p.bind,
-      axis_name=axis_name,
-      scatter_dimension=scatter_dimension,
-      axis_index_groups=axis_index_groups,
-      axis_size=axis_size,
-      tiled=tiled)
+  def bind(leaf):
+    leaf = insert_collective_pbroadcast(axis_name, leaf)
+    return reduce_scatter_p.bind(
+        leaf, axis_name=axis_name, scatter_dimension=scatter_dimension,
+        axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled)
   return tree_util.tree_map(bind, x)
 
 
tests/shard_map_test.py

Lines changed: 17 additions & 0 deletions
@@ -2707,6 +2707,23 @@ def f(x):
   #   return jnp.sum(f(x, y))
   # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr)
 
+  @config.varying_axes_in_types(True)
+  def test_all_gather_with_vma_in_types(self):
+    mesh = jtu.create_mesh((2,), ('x',))
+    x = np.arange(8.)
+
+    def f(x):
+      self.assertEqual(x.aval.vma, frozenset())
+      out = jax.lax.all_gather(x, 'x')
+      self.assertEqual(out.aval.vma, frozenset({'x'}))
+      return out
+
+    f = jax.jit(shard_map(f, mesh=mesh, in_specs=P(), out_specs=P('x')))
+    jaxpr = f.trace(x).jaxpr
+    self.assertIn("pbroadcast[axes=('x',)", str(jaxpr))
+
+    f(x)  # doesn't crash
+
 
 class FunSpec(NamedTuple):
   name: str
