Commit 133b1ca: Reshard dot_product_attention

1 parent 2912a1c

1 file changed

jax/_src/nn/functions.py: 52 additions & 4 deletions
@@ -17,8 +17,10 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
+from collections import defaultdict
 from functools import partial
 import operator
+import itertools
 import numpy as np
 from typing import Any, Literal, overload
 import warnings
@@ -903,7 +905,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
 
 def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
                                 scale, q_seqlen, kv_seqlen, local_window_size,
-                                return_residual):
+                                return_residual, out_sharding = None, s_sharding = None):
   logits_dtype = jnp.promote_types(query.dtype, np.float32)
 
   # If the query and logits dtypes are different, then the default precision
@@ -920,13 +922,15 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   # Explicit precision will fail on platforms that don't support it. For example,
   # some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32.
   # Use the default precision as a fallback in these cases.
+  out_sharding1 = permute_shard(out_sharding, "*BTNH", "BNTS", S=s_sharding)
   try:
     logits = jnp_einsum.einsum(
         "BTNH,BSNH->BNTS",
         query,
         key,
         precision=precision,
         preferred_element_type=logits_dtype,
+        out_sharding=out_sharding1
     )
   except:  # pylint: disable=bare-except
     logits = jnp_einsum.einsum(
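The logits einsum above writes its result in BNTS layout, while the caller-supplied out_sharding describes the BTNH layout of the final attention output, so the new out_sharding1 line permutes the spec and binds the key/value sequence axis S from s_sharding; the second einsum further down ('BNTS,BSNH->BTNH') already produces BTNH output and therefore takes out_sharding unchanged. A minimal sketch of that permutation done by hand with a plain PartitionSpec (the mesh axis names 'data' and 'model' are illustrative assumptions, not taken from this commit):

# Conceptual sketch only, not the commit's helper: derive a logits-layout
# ("BNTS") spec from an output-layout ("BTNH") spec by reordering entries and
# taking the S entry from s_sharding. Axis names are assumptions.
from jax.sharding import PartitionSpec as P

out_spec = P('data', None, 'model', None)  # B, T, N, H
s_sharding = None                          # sharding for the key/value sequence axis

# BTNH -> BNTS: keep B, swap T and N, use the S binding in place of H.
logits_spec = P(out_spec[0], out_spec[2], out_spec[1], s_sharding)
print(logits_spec)  # PartitionSpec('data', 'model', None, None)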
@@ -949,7 +953,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   padded_logits = padded_logits.astype(np.float32)
   probs = softmax(padded_logits, axis=-1).astype(key.dtype)
 
-  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value)
+  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value, out_sharding=out_sharding)
   if q_seqlen is not None:
     mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
     encoded *= mask.astype(encoded.dtype)
@@ -972,7 +976,9 @@ def _dot_product_attention_xla(
     q_seqlen: Array | None,
     kv_seqlen: Array | None,
     local_window_size: tuple[int, int] | None,
-    return_residual: bool = False):
+    return_residual: bool = False,
+    out_sharding: Any = None,
+    s_sharding : Any = None):
 
   B, T, N, H = query.shape
   _, S, K, _ = key.shape
@@ -991,7 +997,7 @@ def _reshape_to_grouped(t):
   bias = _reshape_to_grouped(bias)
   mask = _reshape_to_grouped(mask)
   vmapped_fn = api.vmap(
-      _dot_product_attention_core,
+      partial(_dot_product_attention_core, out_sharding=out_sharding, s_sharding=s_sharding),
      in_axes=(3, None, None, 2, 2, None, None, None, None, None, None),
      out_axes=3,
  )
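Binding the new keywords with functools.partial keeps them out of vmap's in_axes, which still lists only the eleven positional array arguments of _dot_product_attention_core. A small self-contained sketch of that pattern (the function and shapes below are illustrative, not from functions.py):

# Illustrative pattern: keyword configuration bound via partial is closed
# over, so vmap only maps the positional arguments described by in_axes.
from functools import partial
import jax
import jax.numpy as jnp

def core(x, y, scale=1.0):
  return scale * (x @ y)

batched = jax.vmap(partial(core, scale=0.5), in_axes=(0, None))
xs = jnp.ones((4, 2, 3))  # batch of four (2, 3) matrices
y = jnp.ones((3, 5))
print(batched(xs, y).shape)  # (4, 2, 5)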
@@ -1056,6 +1062,8 @@ def dot_product_attention(
     local_window_size: int | tuple[int, int] | None = None,
     implementation: Literal['xla', 'cudnn'] | None = None,
     return_residual: bool = False,
+    out_sharding=None,
+    s_sharding=None,
 ):
   r"""Scaled dot product attention function.
 
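With this change the public dot_product_attention accepts out_sharding for the (B, T, N, H) output and s_sharding for the key/value sequence axis of the intermediate logits, which has no counterpart in the output sharding. A hedged sketch of the kind of values these keywords are meant to carry; the mesh construction and axis name are assumptions, and the block only builds the sharding objects rather than calling the function:

# Sketch of plausible argument values for the new keywords.
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# One-dimensional mesh over whatever devices are available (assumed setup).
mesh = jax.make_mesh((len(jax.devices()),), ('data',))

out_sharding = NamedSharding(mesh, P('data', None, None, None))  # B, T, N, H
s_sharding = None  # leave the logits' key/value sequence axis unsharded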
@@ -1185,6 +1193,8 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
           kv_seqlen=key_value_seq_lengths,
           local_window_size=local_window_size,
           return_residual=return_residual,
+          out_sharding=out_sharding,
+          s_sharding=s_sharding,
       )
     case 'cudnn':
       use_padding = (
@@ -1490,3 +1500,41 @@ def log1mexp(x: ArrayLike) -> Array:
 )
 
 log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))
+
+def permute_shard(outer_t, in_spec: str, out_spec: str, **bindings):
+  """
+  Permute a tuple according to einops-like dimension specifications.
+
+  Args:
+    t: Input sharding or partition spec or tuple to permute
+    in_spec: Input specification string (e.g., "BTNH")
+    out_spec: Output specification string (e.g., "BNTH")
+    **bindings: Values for dimensions in out_spec not present in in_spec
+
+  Returns:
+    Permuted tuple according to out_spec
+
+  Examples:
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ACB")
+    PartitionSpec('A', 'C', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "XC", "CX")
+    PartitionSpec('C', 'A', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ANZ", N="D")
+    PartitionSpec('A', 'D', None)
+  """
+  t = outer_t.spec if isinstance(outer_t, NamedSharding) else outer_t
+  d = len(t) - len(in_spec)
+  assert d >= 0, f"Tuple length {len(t)} is less than named dimensions {len(in_spec)}"
+  np_t = np.array(t)
+  dim_map = {dim: [i+d] for (i, dim) in enumerate(in_spec)}
+  dim_map[in_spec[0]] = range(0, d+1)
+  new_spec = P(*list(itertools.chain.from_iterable(
+      map(str, np_t[dim_map[dim_name]]) if dim_name in dim_map else [bindings.get(dim_name, None)]
+      for dim_name in out_spec
+  )))
+  if isinstance(outer_t, NamedSharding):
+    return NamedSharding(outer_t.mesh, new_spec)
+  else:
+    return new_spec
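A quick usage sketch for the new helper, assuming a JAX checkout that includes this commit; permute_shard lives in the private module jax._src.nn.functions and is not public API:

# Usage sketch under the assumption that this commit is checked out.
from jax._src.nn.functions import permute_shard
from jax.sharding import PartitionSpec as P

# Reorder a fully named spec: "ABC" -> "ACB".
print(permute_shard(P('data', 'seq', 'model'), "ABC", "ACB"))
# PartitionSpec('data', 'model', 'seq')

# Bind a dimension that is absent from the input spec ("S" here), the way the
# attention code supplies S=s_sharding for the logits layout.
print(permute_shard(P('data', 'seq', 'model', None), "BTNH", "BNTS", S=None))
# PartitionSpec('data', 'model', 'seq', None)

Note that the first character of in_spec also absorbs any extra leading dimensions (dim_map[in_spec[0]] spans them), which is what the docstring's second example shows by folding 'A' and 'B' into a single 'X' and what the "*BTNH" call in _dot_product_attention_core relies on.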
