
Commit 12c3057

ghpvnist authored and Google-ML-Automation committed
Introduce lax.ragged_all_to_all primitive
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```
For now, the API assumes `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dimension, and `axis_index_groups` defaults to all replicas (i.e. there is only one group, covering all axis indices, iota-like as in the example above). The current API is inspired by https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html, which essentially also performs a ragged all-to-all.

PiperOrigin-RevId: 704550890
1 parent 944d822 commit 12c3057
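For reference, a minimal sketch that reproduces the module above by lowering the new primitive. It mirrors the `testRaggedAllToAll` test added in `tests/lax_test.py` below; it only lowers the computation, it does not execute the collective:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Single ragged buffer of 6 floats, split into 3 ragged elements.
operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
output = jnp.zeros(6, dtype=jnp.float32)
input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)

# Lower to StableHLO and inspect the emitted custom call.
mlir_module = jax.jit(lax.ragged_all_to_all).lower(
    operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
).as_text()
print("stablehlo.custom_call @ragged_all_to_all" in mlir_module)  # True
```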

File tree

5 files changed, +162 -0 lines changed

jax/_src/lax/parallel.py

Lines changed: 107 additions & 0 deletions
@@ -457,6 +457,55 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):

  return tree_util.tree_map(bind, x)


def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
  """Ragged version of :func:`all_to_all`.

  For now, ``split_axis`` and ``concat_axis`` from :func:`all_to_all` are
  equivalent and fixed to the outermost (ragged) dimension, and
  ``axis_index_groups`` defaults to all replicas (i.e. there is only one
  group, covering all axis indices).

  Ragged arrays are defined by a set of three arrays:
  * ``data``: the ``data`` array is "ragged" along its outermost dimension,
    along which each indexed element has variable size.
  * ``offsets``: the ``offsets`` array indexes the outermost dimension of the
    ``data`` array, and represents the starting offset of each ragged element
    of the ``data`` array.
  * ``sizes``: the ``sizes`` array represents the size of each ragged element
    of the ``data`` array, where the size is specified in units of
    sub-elements. A sub-element is defined as the suffix of the ``data``
    array shape obtained by removing the outermost "ragged" dimension.
  The ``offsets`` and ``sizes`` arrays must have the same size.

  # Example ragged tensor
  data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
  offsets: [3] = {0, 1, 4}
  sizes: [3] = {1, 3, 4}

  # Index 'data' at 'offsets'[0], 'sizes'[0]
  {a,b,c}

  # Index 'data' at 'offsets'[1], 'sizes'[1]
  {d,e,f},{g,h,i},{j,k,l}

  # Index 'data' at 'offsets'[2], 'sizes'[2]
  {m,n,o},{p,q,r},{s,t,u},{v,w,x}

  Args:
    operand: array with ragged dimension along its outermost dimension.
    output: array of ragged output data.
    input_offsets: array of ragged input offsets.
    send_sizes: array of ragged input send sizes.
    output_offsets: array of ragged output offsets.
    recv_sizes: array of ragged output receive sizes.

  Returns:
    array with shape equal to ``output``.
  """
  return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
                                  output_offsets, recv_sizes)

ragged_all_to_all_p = core.Primitive('ragged_all_to_all')


def axis_index(axis_name):
  """Return the index along the mapped axis ``axis_name``.
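As an aside, the ragged indexing scheme in the docstring can be reproduced with plain NumPy; here is a minimal illustrative sketch of the data/offsets/sizes example above (not part of the diff):

```python
import numpy as np

# The docstring's ragged tensor: data has shape [8, 3], ragged along dim 0.
data = np.array([list("abc"), list("def"), list("ghi"), list("jkl"),
                 list("mno"), list("pqr"), list("stu"), list("vwx")])
offsets = np.array([0, 1, 4])  # starting row of each ragged element
sizes = np.array([1, 3, 4])    # row count of each ragged element

for i, (off, size) in enumerate(zip(offsets, sizes)):
    # Each ragged element is a contiguous block of `size` sub-elements.
    print(f"element {i}:", data[off:off + size].tolist())
# element 0: [['a', 'b', 'c']]
# element 1: [['d', 'e', 'f'], ['g', 'h', 'i'], ['j', 'k', 'l']]
# element 2: [['m', 'n', 'o'], ['p', 'q', 'r'], ['s', 't', 'u'], ['v', 'w', 'x']]
```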
@@ -1052,6 +1101,64 @@ def _all_to_all_effectful_abstract_eval(

batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name')


def _ragged_all_to_all_lowering(ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
  N = input_offsets.type.shape[0]
  backend_config = ir.DictAttr.get({
      'replica_groups': ir.DenseIntElementsAttr.get(
          np.arange(0, N, 1, dtype=np.int64), shape=[1, N]
      )
  })
  return hlo.CustomCallOp(
      result=[output.type],
      inputs=[operand, output, input_offsets, send_sizes, output_offsets,
              recv_sizes],
      call_target_name=ir.StringAttr.get('ragged_all_to_all'),
      backend_config=backend_config,
      api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4),
  ).results


@ragged_all_to_all_p.def_abstract_eval
def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
  if operand.shape != output.shape:
    raise ValueError('ragged_all_to_all input and output shapes must be equal.')
  if not dtypes.issubdtype(input_offsets.dtype, np.integer):
    raise ValueError("ragged_all_to_all input_offsets must be integer type.")
  if not dtypes.issubdtype(send_sizes.dtype, np.integer):
    raise ValueError("ragged_all_to_all send_sizes must be integer type.")
  if not dtypes.issubdtype(output_offsets.dtype, np.integer):
    raise ValueError("ragged_all_to_all output_offsets must be integer type.")
  if not dtypes.issubdtype(recv_sizes.dtype, np.integer):
    raise ValueError("ragged_all_to_all recv_sizes must be integer type.")
  if len(input_offsets.shape) != 1 or input_offsets.shape[0] < 1:
    raise ValueError(
        "ragged_all_to_all input_offsets must be rank 1 with positive dimension"
        " size, but got shape {}".format(input_offsets.shape)
    )
  if len(send_sizes.shape) != 1 or send_sizes.shape[0] < 1:
    raise ValueError(
        "ragged_all_to_all send_sizes must be rank 1 with positive dimension"
        " size, but got shape {}".format(send_sizes.shape)
    )
  if len(output_offsets.shape) != 1 or output_offsets.shape[0] < 1:
    raise ValueError(
        "ragged_all_to_all output_offsets must be rank 1 with positive"
        " dimension size, but got shape {}".format(output_offsets.shape)
    )
  if len(recv_sizes.shape) != 1 or recv_sizes.shape[0] < 1:
    raise ValueError(
        "ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
        " size, but got shape {}".format(recv_sizes.shape)
    )
  return output.update(
      shape=list(output.shape),
      dtype=output.dtype,
      weak_type=output.weak_type,
  )


ragged_all_to_all_p.def_impl(partial(dispatch.apply_primitive, ragged_all_to_all_p))
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)


def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
  """Gather values of x across all replicas.
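The `backend_config` built by `_ragged_all_to_all_lowering` encodes a single iota replica group over all `N` replicas; a small sketch of the payload it constructs (illustrative only):

```python
import numpy as np

# One group containing every replica, laid out as a [1, N] int64 array,
# matching dense<[[0, 1, 2]]> : tensor<1x3xi64> in the test's MLIR output.
N = 3
replica_groups = np.arange(0, N, 1, dtype=np.int64).reshape(1, N)
print(replica_groups)  # [[0 1 2]]
```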

jax/experimental/jax2tf/tests/primitives_test.py

Lines changed: 2 additions & 0 deletions
@@ -183,6 +183,8 @@ def test_primitive_coverage(self):
        continue
      if p.name == "pallas_call":
        continue
      if p.name == "ragged_all_to_all":
        continue
      if p.name == "ffi_call":
        continue
      if p.name == "tpu_custom_call":

jax/extend/core/primitives.py

Lines changed: 1 addition & 0 deletions
@@ -204,6 +204,7 @@
    pmin_p as pmin_p,
    ppermute_p as ppermute_p,
    psum_p as psum_p,
    ragged_all_to_all_p as ragged_all_to_all_p,
)

from jax._src.lax.ann import (

jax/lax/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -362,6 +362,8 @@
  psum_p as psum_p,
  psum_scatter as psum_scatter,
  pswapaxes as pswapaxes,
  ragged_all_to_all as ragged_all_to_all,
  ragged_all_to_all_p as ragged_all_to_all_p,
)
from jax._src.lax.other import (
  conv_general_dilated_local as conv_general_dilated_local,

tests/lax_test.py

Lines changed: 50 additions & 0 deletions
@@ -1346,6 +1346,56 @@ def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype,
    numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers)
    self._CheckAgainstNumpy(numpy_op, op, args_maker)

  def testRaggedAllToAllErrors(self):
    operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
    output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
    input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)

    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input and output shapes must be equal."):
      jax.jit(lax.ragged_all_to_all).lower(operand, jnp.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32), input_offsets, send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be integer type."):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([], dtype=jnp.int32), output_offsets, recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([], dtype=jnp.int32), recv_sizes)
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32))
    with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (0,)"):
      jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([], dtype=jnp.int32))

  def testRaggedAllToAll(self):
    operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32)
    output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
    input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32)
    recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32)
    mlir_module = jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes).as_text()
    self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module)
    self.assertIn(
        "backend_config = {replica_groups = dense<[[0, 1, 2]]> :"
        " tensor<1x3xi64>}}",
        mlir_module,
    )

  @jtu.sample_product(
    [
      {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1},