|
22 | 22 | import itertools |
23 | 23 | import math |
24 | 24 |
|
| 25 | +import jax |
25 | 26 | from jax import tree_util |
26 | 27 | from jax._src import core |
27 | 28 | from jax._src import dispatch |
@@ -459,86 +460,146 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): |
459 | 460 | def ragged_all_to_all( |
460 | 461 | operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *, |
461 | 462 | axis_name, axis_index_groups = None): |
462 | | - """Ragged version of :func:`all_to_all`. |
463 | | -
|
464 | | - For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent |
465 | | - and the outermost (ragged) dimension. ``axis_index_groups`` is default to all |
466 | | - replicas (e.g. there is only one group and covers all axis indices). |
467 | | -
|
468 | | - Ragged arrays are defined by a set of three arrays: |
469 | | - * ``data``: the ``data`` array is "ragged" along its outermost dimension, |
470 | | - along which each indexed element has variable size. |
471 | | - * ``offsets``: the ``offsets`` array indexes the outermost dimension of the |
472 | | - ``data`` array, and represents the starting offset of each ragged element of |
473 | | - the ``data`` array. |
474 | | - * ``sizes``: the ``sizes`` array represents the size of each ragged element of |
475 | | - the ``data`` array, where the size is specified in units of sub-elements. A |
476 | | - sub-element is defined as the suffix of the ``data`` array shape obtained by |
477 | | - removing the outermost "ragged" dimension. |
478 | | - The ``offsets`` and ``sizes`` arrays must have the same size. |
479 | | -
|
480 | | - # Example ragged tensor |
481 | | - data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}} |
482 | | - offsets: [3] = {0, 1, 4} |
483 | | - sizes: [3] = {1, 3, 4} |
484 | | -
|
485 | | - # Index 'data' at 'offsets'[0], 'sizes'[0]' |
486 | | - {a,b,c} |
487 | | -
|
488 | | - # Index 'data' at 'offsets'[1], 'sizes'[1]' |
489 | | - {d,e,f},{g,h,i},{j,k,l} |
490 | | -
|
491 | | - # Index 'data' at 'offsets'[2], 'sizes'[2]' |
492 | | - {m,n,o},{p,q,r},{s,t,u},{v,w,x} |
493 | | -
|
494 | | -
|
495 | | - ``output_offsets`` must be sharded in a way that each replica has offsets in |
496 | | - the target replica output perspective. |
497 | | -
|
498 | | - For i-th output offset, the current replica will send |
499 | | - `operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th |
500 | | - replica that will be written to |
501 | | - `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th |
502 | | - replica ``output``. |
503 | | -
|
504 | | - For example, if we have 2 replicas: |
505 | | -
|
506 | | - replica 0: |
507 | | - operand: [1, 2, 2] |
508 | | - output: [0, 0, 0, 0] |
509 | | - input_offsets: [0, 1] |
510 | | - send_sizes: [1, 2] |
511 | | - output_offsets: [0, 0] |
512 | | - recv_sizes: [1, 1] |
513 | | -
|
514 | | - replica 1: |
515 | | - operand: [3, 4, 0] |
516 | | - output: [0, 0, 0, 0] |
517 | | - input_offsets: [0, 1] |
518 | | - send_sizes: [1, 1] |
519 | | - output_offsets: [1, 2] |
520 | | - recv_sizes: [2, 1] |
521 | | -
|
522 | | - replica 0's result will be: [1, 3, 0, 0] |
523 | | - replica 1's result will be: [2, 2, 4, 0] |
| 463 | + """Ragged version of :func:`all_to_all` collective. |
| 464 | +
|
| 465 | + We say data are "ragged" when they can be represented as a list of arrays |
| 466 | + whose shapes differ only in the size of the leading axis. For example, these |
| 467 | + data are ragged, comprising four component arrays:: |
| 468 | +
|
| 469 | + ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)] |
| 470 | +
|
| 471 | + We often instead want a contiguous representation, e.g. for batching. But |
| 472 | + because the shapes of the components differ, we can't apply ``jnp.stack`` to |
| 473 | + represent these data by a single rectangular array with the leading axis |
| 474 | + indexing the component arrays. So instead of stacking, we concatenate along |
| 475 | + the leading axis and keep track of offsets and sizes. |
| 476 | +
|
| 477 | + That is, we can represent ragged data contiguously using a triple of dense |
| 478 | + arrays ``(data, offsets, sizes)``: |
| 479 | + * ``data``: the concatenated component arrays, |
| 480 | + * ``offsets``: 1D array of indices into the leading axis of ``data`` |
| 481 | + indicating where the data for each component array begins, |
| 482 | + * ``sizes``: 1D array of sizes of the leading axis of each component array. |
| 483 | + We refer to this triple as a ragged array. (In general, offsets can't be
| 484 | + computed from sizes, since the representation permits internal padding.)
| 485 | +
|
| 486 | + For example:: |
| 487 | + data: f32[8,3] = jnp.array([ |
| 488 | + [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x], |
| 489 | + ]) |
| 490 | + offsets: i32[3] = jnp.array([0, 1, 4]) |
| 491 | + sizes: i32[3] = jnp.array([1, 3, 4]) |
| 492 | +
|
| 493 | + # To extract the first component array, of type f32[1,3] |
| 494 | + data[offsets[0]:offsets[0]+sizes[0]] |
| 495 | +
|
| 496 | + # To extract the second component array, of type f32[3,3] |
| 497 | + data[offsets[1]:offsets[1]+sizes[1]] |
| 498 | +
|
| 499 | + # To extract the third component array, of type f32[4,3] |
| 500 | + data[offsets[2]:offsets[2]+sizes[2]] |
| 501 | +
|
| 502 | + The ``ragged_all_to_all`` collective operation communicates slices of ragged |
| 503 | + arrays between devices. Each caller is both a sender and a receiver. The |
| 504 | + ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the |
| 505 | + caller's ``operand`` to be sent. Received results are returned in an array |
| 506 | + that has the same value as the argument ``output`` except with received values
| 507 | + written at some slices. The ``output_offsets`` argument does *not* indicate |
| 508 | + the offsets at which all the received results are written; instead, |
| 509 | + ``output_offsets`` indicates the offsets at which the *sent* slices are |
| 510 | + written on their corresponding receivers. The sizes of received slices are |
| 511 | + indicated by ``recv_sizes``. See below for details. |
| 512 | +
|
| 513 | + The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
| 514 | + ``recv_sizes`` must all be the same length, and that length must be divisible |
| 515 | + by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and |
| 516 | + ``recv_sizes`` must satisfy:: |
| 517 | +
|
| 518 | + jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True)) |
| 519 | +
|
| 520 | + Specifically, given a call:: |
| 521 | +
|
| 522 | + result = ragged_all_to_all(operand, output, input_offsets, send_sizes, |
| 523 | + output_offsets, recv_sizes, axis_name) |
| 524 | +
|
| 525 | + the caller sends data like:: |
| 526 | +
|
| 527 | + assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes) |
| 528 | + N = len(input_offsets) |
| 529 | + slices_per_device, leftover = divmod(N, jax.lax.axis_size(axis_name))
| 530 | + assert not leftover |
| 531 | +
|
| 532 | + for i in range(N): |
| 533 | + dst_idx = i // slices_per_device |
| 534 | + SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]], |
| 535 | + axis_name=axis_name, to_axis_index=dst_idx) |
| 536 | +
|
| 537 | + and receives data in ``result`` like:: |
| 538 | +
|
| 539 | + result = output |
| 540 | + output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True) |
| 541 | + for i in range(N): |
| 542 | + src_idx = i // slices_per_device |
| 543 | + result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i] |
| 544 | + ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx)) |
| 545 | +
|
| 546 | + where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local |
| 547 | + ``output_offsets`` does not indicate the offsets at which its local ``result`` |
| 548 | + is updated; instead, it indicates where the corresponding sent slices are |
| 549 | + written on their destination instances. To compute the local offsets at which |
| 550 | + received data are written, we apply an ``all_to_all`` on ``output_offsets``. |
| 551 | +
|
| 552 | + For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with |
| 553 | + these arguments in each mapped function instance:: |
| 554 | +
|
| 555 | + axis index 0: |
| 556 | + operand = [1, 2, 2] |
| 557 | + output = [0, 0, 0, 0] |
| 558 | + input_offsets = [0, 1] |
| 559 | + send_sizes = [1, 2] |
| 560 | + output_offsets = [0, 0] |
| 561 | + recv_sizes = [1, 1] |
| 562 | +
|
| 563 | + axis index 1: |
| 564 | + operand = [3, 4, 0] |
| 565 | + output = [0, 0, 0, 0] |
| 566 | + input_offsets = [0, 1] |
| 567 | + send_sizes = [1, 1] |
| 568 | + output_offsets = [1, 2] |
| 569 | + recv_sizes = [2, 1] |
| 570 | +
|
| 571 | + then:: |
| 572 | +
|
| 573 | + axis index 0: |
| 574 | + result = [1, 3, 0, 0] |
| 575 | +
|
| 576 | + axis index 1: |
| 577 | + result = [2, 2, 4, 0] |
524 | 578 |
|
525 | 579 | Args: |
526 | | - operand: array with ragged dimension along its outermost dimension. |
527 | | - output: array of ragged input offsets. |
528 | | - input_offsets: array of ragged input send sizes. |
529 | | - send_sizes: array of ragged output data. |
530 | | - output_offsets: array of ragged offsets in the target replica output. |
531 | | - recv_sizes: array of ragged output receive sizes. |
532 | | - axis_name: hashable Python object used to name a pmapped axis (see the |
533 | | - :func:`jax.pmap` documentation for more details). |
| 580 | + operand: data array of shape (N, A, B, ...) representing concatenated |
| 581 | + (possibly padded) ragged data to be sent. |
| 582 | + output: data array of shape (M, A, B, ...) to update with received data. |
| 583 | + input_offsets: 1D integer array of shape (K,) representing the offsets of |
| 584 | + leading-axis slices into ``operand`` to be sent. |
| 585 | + send_sizes: 1D integer array of shape (K,) representing the sizes of
| 586 | + leading-axis slices into ``operand`` to be sent. |
| 587 | + output_offsets: 1D integer array of shape (K,) representing where the |
| 588 | + corresponding sent data is written on its receiver.
| 589 | + recv_sizes: 1D integer array of shape (K,) representing sizes of |
| 590 | + leading-axis slices into ``output`` to update with received data. |
| 591 | + axis_name: name of the mapped axis over which to perform the communication. |
534 | 592 | axis_index_groups: optional list of lists containing axis indices (e.g. for |
535 | 593 | an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the |
536 | 594 | first two and last two replicas). Groups must cover all axis indices |
537 | 595 | exactly once, and all groups must be the same size. Otherwise, the |
538 | 596 | behavior is undefined. |
539 | 597 |
|
540 | 598 | Returns: |
541 | | - array with shape equal to ``output``. |
| 599 | + Array of shape (M, A, B, ...) with the same value as ``output``, except
| 600 | + with received data written into slices starting at |
| 601 | + ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size |
| 602 | + ``recv_sizes``. |
542 | 603 | """ |
543 | 604 |
|
544 | 605 | if not isinstance(axis_name, (tuple, list)): |
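
A minimal end-to-end sketch of the worked example in the docstring above,
assuming two devices are available; the mesh axis name ``'x'`` and the wrapper
``f`` are illustrative, and the ``shard_map`` import path varies across JAX
versions::

  from functools import partial

  import jax
  import jax.numpy as jnp
  from jax.experimental.shard_map import shard_map
  from jax.sharding import Mesh, PartitionSpec as P

  mesh = Mesh(jax.devices()[:2], ('x',))

  @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
  def f(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
    return jax.lax.ragged_all_to_all(
        operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
        axis_name='x')

  # Each device gets one half of each leading axis, reproducing the
  # per-instance values from the docstring example.
  operand        = jnp.array([1., 2., 2.,   3., 4., 0.])
  output         = jnp.zeros(8)
  input_offsets  = jnp.array([0, 1,   0, 1])
  send_sizes     = jnp.array([1, 2,   1, 1])
  output_offsets = jnp.array([0, 0,   1, 2])
  recv_sizes     = jnp.array([1, 1,   2, 1])

  result = f(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes)
  # result == [1., 3., 0., 0.,   2., 2., 4., 0.], i.e. [1, 3, 0, 0] on axis
  # index 0 and [2, 2, 4, 0] on axis index 1.
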
@@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval( |
1210 | 1271 | effects = {*map(core.NamedAxisEffect, axis_name)} |
1211 | 1272 | return out_aval, effects |
1212 | 1273 |
|
| 1274 | +def _ragged_all_to_all_jvp(primals, tangents, **params): |
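| | +  # The collective is linear in (operand, output) jointly, so the tangent is
| | +  # the same ragged_all_to_all applied to the tangents, with the offsets and
| | +  # sizes held constant.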
| 1275 | + operand, output, *sizes_and_offsets = primals |
| 1276 | + operand_dot, output_dot, *_ = tangents |
| 1277 | + result = ragged_all_to_all_p.bind( |
| 1278 | + operand, output, *sizes_and_offsets, **params) |
| 1279 | + if type(operand_dot) is type(output_dot) is ad.Zero: |
| 1280 | + result_dot = ad.Zero.from_primal_value(result) |
| 1281 | + else: |
| 1282 | + operand_dot = ad.instantiate_zeros(operand_dot) |
| 1283 | + output_dot = ad.instantiate_zeros(output_dot) |
| 1284 | + result_dot = ragged_all_to_all_p.bind( |
| 1285 | + operand_dot, output_dot, *sizes_and_offsets, **params) |
| 1286 | + return result, result_dot |
| 1287 | + |
| 1288 | +def _ragged_all_to_all_transpose( |
| 1289 | + t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, |
| 1290 | + *, axis_name, axis_index_groups): |
| 1291 | + if type(t) is ad.Zero: |
| 1292 | + operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None |
| 1293 | + output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None |
| 1294 | + else: |
| 1295 | + zero = ad.zeros_like_aval(operand.aval) |
| 1296 | + output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True) |
| 1297 | + input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True) |
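| | +    # Reverse the communication pattern: slices received at output_offsets_
| | +    # (sizes recv_sizes) are sent back to input_offsets_ (sizes send_sizes).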
| 1298 | + operand_t = ragged_all_to_all_p.bind( |
| 1299 | + t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes, |
| 1300 | + axis_name=axis_name, axis_index_groups=axis_index_groups) |
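| | +    # mask is 1 on the leading-axis positions that were overwritten by
| | +    # received data; zeroing those out of t gives the output cotangent.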
| 1301 | + mask = jax.numpy.cumsum( |
| 1302 | + jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ |
| 1303 | + .at[output_offsets_ + recv_sizes].add(-1)) |
| 1304 | + output_t = jax.numpy.where(mask, 0, t) |
| 1305 | + return [operand_t, output_t] + [None] * 4 |
| 1306 | + |
1213 | 1307 | ragged_all_to_all_p = core.Primitive('ragged_all_to_all') |
1214 | 1308 | ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) |
| 1309 | +ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp |
| 1310 | +ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose |
1215 | 1311 | mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) |
1216 | 1312 | batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') |
1217 | 1313 |
|
|