Commit 66a6eb2

mattjj authored and Google-ML-Automation committed
add autodiff rules for jax.lax.ragged_all_to_all collective
Also update the ragged_all_to_all docstring. Pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; I'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
1 parent 3a26804 commit 66a6eb2

3 files changed (+242, -72 lines)


jax/_src/lax/lax.py

Lines changed: 1 addition & 1 deletion
@@ -8197,7 +8197,7 @@ def _const(example, val):
 def _zero(x):
   x_aval = core.get_aval(x)
   return full_like(x, shape=(), fill_value=0,
-                   sharding=x_aval.sharding.with_spec(P()))
+                   sharding=x_aval.sharding.with_spec(P()))

 _ones: Callable = partial(full_like, fill_value=1)
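For context, ``_zero`` builds a scalar zero with the same dtype as its argument (the ``sharding`` argument above additionally marks it as replicated). Ignoring sharding, a minimal sketch of the equivalent public call, purely for illustration:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
    z = lax.full_like(x, fill_value=0, shape=())  # scalar zero with x's dtype
    assert z.shape == () and z.dtype == x.dtype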

jax/_src/lax/parallel.py

Lines changed: 167 additions & 71 deletions
@@ -22,6 +22,7 @@
 import itertools
 import math

+import jax
 from jax import tree_util
 from jax._src import core
 from jax._src import dispatch
@@ -459,86 +460,146 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):
 def ragged_all_to_all(
     operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
     axis_name, axis_index_groups = None):
-  """Ragged version of :func:`all_to_all`.
-
-  For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
-  and the outermost (ragged) dimension. ``axis_index_groups`` is default to all
-  replicas (e.g. there is only one group and covers all axis indices).
-
-  Ragged arrays are defined by a set of three arrays:
-  * ``data``: the ``data`` array is "ragged" along its outermost dimension,
-    along which each indexed element has variable size.
-  * ``offsets``: the ``offsets`` array indexes the outermost dimension of the
-    ``data`` array, and represents the starting offset of each ragged element of
-    the ``data`` array.
-  * ``sizes``: the ``sizes`` array represents the size of each ragged element of
-    the ``data`` array, where the size is specified in units of sub-elements. A
-    sub-element is defined as the suffix of the ``data`` array shape obtained by
-    removing the outermost "ragged" dimension.
-  The ``offsets`` and ``sizes`` arrays must have the same size.
-
-  # Example ragged tensor
-  data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
-  offsets: [3] = {0, 1, 4}
-  sizes: [3] = {1, 3, 4}
-
-  # Index 'data' at 'offsets'[0], 'sizes'[0]'
-  {a,b,c}
-
-  # Index 'data' at 'offsets'[1], 'sizes'[1]'
-  {d,e,f},{g,h,i},{j,k,l}
-
-  # Index 'data' at 'offsets'[2], 'sizes'[2]'
-  {m,n,o},{p,q,r},{s,t,u},{v,w,x}
-
-
-  ``output_offsets`` must be sharded in a way that each replica has offsets in
-  the target replica output perspective.
-
-  For i-th output offset, the current replica will send
-  `operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th
-  replica that will be written to
-  `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th
-  replica ``output``.
-
-  For example, if we have 2 replicas:
-
-  replica 0:
-    operand: [1, 2, 2]
-    output: [0, 0, 0, 0]
-    input_offsets: [0, 1]
-    send_sizes: [1, 2]
-    output_offsets: [0, 0]
-    recv_sizes: [1, 1]
-
-  replica 1:
-    operand: [3, 4, 0]
-    output: [0, 0, 0, 0]
-    input_offsets: [0, 1]
-    send_sizes: [1, 1]
-    output_offsets: [1, 2]
-    recv_sizes: [2, 1]
-
-  replica 0's result will be: [1, 3, 0, 0]
-  replica 1's result will be: [2, 2, 4, 0]
+  """Ragged version of :func:`all_to_all` collective.
+
+  We say data are "ragged" when they can be represented as a list of arrays
+  whose shapes differ only in the size of the leading axis. For example, these
+  data are ragged, comprising four component arrays::
+
+    ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)]
+
+  We often instead want a contiguous representation, e.g. for batching. But
+  because the shapes of the components differ, we can't apply ``jnp.stack`` to
+  represent these data by a single rectangular array with the leading axis
+  indexing the component arrays. So instead of stacking, we concatenate along
+  the leading axis and keep track of offsets and sizes.
+
+  That is, we can represent ragged data contiguously using a triple of dense
+  arrays ``(data, offsets, sizes)``:
+  * ``data``: the concatenated component arrays,
+  * ``offsets``: 1D array of indices into the leading axis of ``data``
+    indicating where the data for each component array begins,
+  * ``sizes``: 1D array of sizes of the leading axis of each component array.
+  We refer to this triple as a ragged array. (Offsets can't be computed from
+  sizes in general to allow for internal padding.)
+
+  For example::
+
+    data: f32[8,3] = jnp.array([
+        [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x],
+    ])
+    offsets: i32[3] = jnp.array([0, 1, 4])
+    sizes: i32[3] = jnp.array([1, 3, 4])
+
+    # To extract the first component array, of type f32[1,3]
+    data[offsets[0]:offsets[0]+sizes[0]]
+
+    # To extract the second component array, of type f32[3,3]
+    data[offsets[1]:offsets[1]+sizes[1]]
+
+    # To extract the third component array, of type f32[4,3]
+    data[offsets[2]:offsets[2]+sizes[2]]
+
+  The ``ragged_all_to_all`` collective operation communicates slices of ragged
+  arrays between devices. Each caller is both a sender and a receiver. The
+  ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the
+  caller's ``operand`` to be sent. Received results are returned in an array
+  that has the same value as the argument ``output`` except with received
+  values written at some slices. The ``output_offsets`` argument does *not*
+  indicate the offsets at which all the received results are written; instead,
+  ``output_offsets`` indicates the offsets at which the *sent* slices are
+  written on their corresponding receivers. The sizes of received slices are
+  indicated by ``recv_sizes``. See below for details.
+
+  The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
+  ``recv_sizes`` must all be the same length, and that length must be divisible
+  by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and
+  ``recv_sizes`` must satisfy::
+
+    jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True))
+
+  Specifically, given a call::
+
+    result = ragged_all_to_all(operand, output, input_offsets, send_sizes,
+                               output_offsets, recv_sizes, axis_name)
+
+  the caller sends data like::
+
+    assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes)
+    N = len(input_offsets)
+    slices_per_device, leftover = divmod(N, lax.axis_size(axis_name))
+    assert not leftover
+
+    for i in range(N):
+      dst_idx = i // slices_per_device
+      SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]],
+           axis_name=axis_name, to_axis_index=dst_idx)
+
+  and receives data in ``result`` like::
+
+    result = output
+    output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
+    for i in range(N):
+      src_idx = i // slices_per_device
+      result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i]
+                         ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx))
+
+  where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local
+  ``output_offsets`` does not indicate the offsets at which its local ``result``
+  is updated; instead, it indicates where the corresponding sent slices are
+  written on their destination instances. To compute the local offsets at which
+  received data are written, we apply an ``all_to_all`` on ``output_offsets``.
+
+  For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with
+  these arguments in each mapped function instance::
+
+    axis index 0:
+      operand = [1, 2, 2]
+      output = [0, 0, 0, 0]
+      input_offsets = [0, 1]
+      send_sizes = [1, 2]
+      output_offsets = [0, 0]
+      recv_sizes = [1, 1]
+
+    axis index 1:
+      operand = [3, 4, 0]
+      output = [0, 0, 0, 0]
+      input_offsets = [0, 1]
+      send_sizes = [1, 1]
+      output_offsets = [1, 2]
+      recv_sizes = [2, 1]
+
+  then::
+
+    axis index 0:
+      result = [1, 3, 0, 0]
+
+    axis index 1:
+      result = [2, 2, 4, 0]

   Args:
-    operand: array with ragged dimension along its outermost dimension.
-    output: array of ragged input offsets.
-    input_offsets: array of ragged input send sizes.
-    send_sizes: array of ragged output data.
-    output_offsets: array of ragged offsets in the target replica output.
-    recv_sizes: array of ragged output receive sizes.
-    axis_name: hashable Python object used to name a pmapped axis (see the
-      :func:`jax.pmap` documentation for more details).
+    operand: data array of shape (N, A, B, ...) representing concatenated
+      (possibly padded) ragged data to be sent.
+    output: data array of shape (M, A, B, ...) to update with received data.
+    input_offsets: 1D integer array of shape (K,) representing the offsets of
+      leading-axis slices into ``operand`` to be sent.
+    send_sizes: 1D integer array of shape (K,) representing the sizes of
+      leading-axis slices into ``operand`` to be sent.
+    output_offsets: 1D integer array of shape (K,) representing where the
+      corresponding sent data is written on each corresponding receiver.
+    recv_sizes: 1D integer array of shape (K,) representing sizes of
+      leading-axis slices into ``output`` to update with received data.
+    axis_name: name of the mapped axis over which to perform the communication.
     axis_index_groups: optional list of lists containing axis indices (e.g. for
       an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
       first two and last two replicas). Groups must cover all axis indices
       exactly once, and all groups must be the same size. Otherwise, the
       behavior is undefined.

   Returns:
-    array with shape equal to ``output``.
+    Array of shape (M, A, B, ...) with the same value as ``output``, except
+    with received data written into slices starting at
+    ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with sizes
+    given by ``recv_sizes``.
   """

   if not isinstance(axis_name, (tuple, list)):
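To make the docstring's SEND/RECEIVE pseudocode concrete, here is a hedged single-process sketch that emulates the collective with plain NumPy loops over axis indices; ``emulate_ragged_all_to_all`` and its per-index argument layout are illustrative inventions, not part of the JAX API:

    import numpy as np

    def emulate_ragged_all_to_all(operands, outputs, input_offsets, send_sizes,
                                  output_offsets, recv_sizes):
      """Emulate the collective: entry j of each list belongs to axis index j."""
      num_instances = len(operands)
      results = [np.array(o) for o in outputs]
      n = len(input_offsets[0])
      slices_per_instance = n // num_instances
      for src in range(num_instances):
        for i in range(n):
          dst = i // slices_per_instance
          start = input_offsets[src][i]
          sent = operands[src][start:start + send_sizes[src][i]]
          # The sender's output_offsets say where its slice lands on the receiver;
          # recv_sizes is implied by the matching send_sizes, so it is unused here.
          out_start = output_offsets[src][i]
          results[dst][out_start:out_start + send_sizes[src][i]] = sent
      return results

    # The two-instance example from the docstring:
    operands = [np.array([1, 2, 2]), np.array([3, 4, 0])]
    outputs = [np.zeros(4, dtype=int), np.zeros(4, dtype=int)]
    input_offsets = [[0, 1], [0, 1]]
    send_sizes = [[1, 2], [1, 1]]
    output_offsets = [[0, 0], [1, 2]]
    recv_sizes = [[1, 1], [2, 1]]
    print(emulate_ragged_all_to_all(operands, outputs, input_offsets, send_sizes,
                                    output_offsets, recv_sizes))
    # -> [array([1, 3, 0, 0]), array([2, 2, 4, 0])]

Note that transposing ``recv_sizes`` across instances recovers ``send_sizes`` in this example, which is exactly the documented ``all_to_all`` constraint between the two.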
@@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval(
   effects = {*map(core.NamedAxisEffect, axis_name)}
   return out_aval, effects

+def _ragged_all_to_all_jvp(primals, tangents, **params):
+  operand, output, *sizes_and_offsets = primals
+  operand_dot, output_dot, *_ = tangents
+  result = ragged_all_to_all_p.bind(
+      operand, output, *sizes_and_offsets, **params)
+  if type(operand_dot) is type(output_dot) is ad.Zero:
+    result_dot = ad.Zero.from_primal_value(result)
+  else:
+    operand_dot = ad.instantiate_zeros(operand_dot)
+    output_dot = ad.instantiate_zeros(output_dot)
+    result_dot = ragged_all_to_all_p.bind(
+        operand_dot, output_dot, *sizes_and_offsets, **params)
+  return result, result_dot
+
+def _ragged_all_to_all_transpose(
+    t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
+    *, axis_name, axis_index_groups):
+  if type(t) is ad.Zero:
+    operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
+    output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None
+  else:
+    zero = ad.zeros_like_aval(operand.aval)
+    output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
+    input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True)
+    operand_t = ragged_all_to_all_p.bind(
+        t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes,
+        axis_name=axis_name, axis_index_groups=axis_index_groups)
+    mask = jax.numpy.cumsum(
+        jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
+          .at[output_offsets_ + recv_sizes].add(-1))
+    output_t = jax.numpy.where(mask, 0, t)
+  return [operand_t, output_t] + [None] * 4
+
 ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
 ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
+ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp
+ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose
 mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
 batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')
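A note on the transpose rule above: the cotangent for ``output`` is the incoming cotangent ``t`` with the received slices zeroed out (those positions were overwritten, so no gradient flows back to ``output`` there), while the cotangent for ``operand`` routes ``t`` backwards through another ``ragged_all_to_all`` with the all_to_all'd offsets swapped. The 0/1 mask over received positions comes from scattering +1 at each slice start and -1 just past its end, then taking a cumulative sum. A standalone sketch of that trick (the helper name here is illustrative, not JAX API):

    import jax.numpy as jnp

    def received_mask(length, starts, sizes):
      # 1 inside each received slice [start, start + size), 0 elsewhere.
      deltas = (jnp.zeros(length, dtype='int32')
                .at[starts].set(1)
                .at[starts + sizes].add(-1))
      return jnp.cumsum(deltas)

    t = jnp.array([10., 20., 30., 40.])   # incoming cotangent for `result`
    starts = jnp.array([1])               # local offsets of received slices
    sizes = jnp.array([2])                # the matching recv_sizes
    mask = received_mask(t.shape[0], starts, sizes)
    print(mask)                           # [0 1 1 0]
    print(jnp.where(mask, 0, t))          # [10.  0.  0. 40.]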

tests/ragged_collective_test.py

Lines changed: 74 additions & 0 deletions
@@ -125,6 +125,80 @@ def fwd(
         c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32)
     )

+  @parameterized.named_parameters(
+      dict(
+          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
+      ),
+  )
+  def test_ragged_all_to_all_grad(self, axis_name, mesh_axes):
+    device_type = jax.devices()[0].platform
+    if device_type == 'tpu' and jtu.get_tpu_version() < 4:
+      raise unittest.SkipTest(
+          'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
+          f' v{jtu.get_tpu_version()}'
+      )
+    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
+    operand = jax.device_put(
+        jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output = jax.device_put(
+        jnp.zeros((2, 4), dtype=jnp.float32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    input_offsets = jax.device_put(
+        jnp.array([[0, 1], [0, 1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    send_sizes = jax.device_put(
+        jnp.array([[1, 2], [1, 1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    output_offsets = jax.device_put(
+        jnp.array([[0, 0], [1, 2]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+    recv_sizes = jax.device_put(
+        jnp.array([[1, 1], [2, 1]], dtype=jnp.int32),
+        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
+    )
+
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=(
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+            P(axis_name, None),
+        ),
+        out_specs=P(axis_name),
+        check_rep=False,
+    )
+    def fwd(
+        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
+    ):
+      operand = operand.reshape(operand.shape[1:])
+      output = output.reshape(output.shape[1:])
+      input_offsets = input_offsets.reshape(input_offsets.shape[1:])
+      send_sizes = send_sizes.reshape(send_sizes.shape[1:])
+      output_offsets = output_offsets.reshape(output_offsets.shape[1:])
+      recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
+      return lax.ragged_all_to_all(
+          operand,
+          output,
+          input_offsets,
+          send_sizes,
+          output_offsets,
+          recv_sizes,
+          axis_name=axis_name,
+      )
+
+    args = input_offsets, send_sizes, output_offsets, recv_sizes
+    jtu.check_grads(lambda op, out: fwd(op, out, *args), (operand, output), order=1)
+
   @parameterized.named_parameters(
       dict(
           testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)
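With the JVP and transpose rules registered, reverse-mode differentiation through ``ragged_all_to_all`` composes with ``shard_map``, as ``jtu.check_grads`` exercises above. A hedged usage sketch, reusing the ``fwd`` function and arrays defined in the test (every name here is an assumption carried over from that context):

    # Gradients of a scalar loss with respect to `operand` and `output`;
    # the integer offset/size arrays are closed over and not differentiated.
    loss = lambda op, out: fwd(op, out, input_offsets, send_sizes,
                               output_offsets, recv_sizes).sum()
    d_operand, d_output = jax.grad(loss, argnums=(0, 1))(operand, output)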
