
Commit 2d27e5f

Add out_sharding to dot_product_attention
1 parent 2912a1c commit 2d27e5f

File tree: 1 file changed, +57 -5 lines


jax/_src/nn/functions.py

Lines changed: 57 additions & 5 deletions
@@ -19,6 +19,7 @@
 from collections.abc import Sequence
 from functools import partial
 import operator
+import itertools
 import numpy as np
 from typing import Any, Literal, overload
 import warnings
@@ -903,7 +904,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
 
 def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
                                 scale, q_seqlen, kv_seqlen, local_window_size,
-                                return_residual):
+                                return_residual, out_sharding=None, s_sharding=None):
   logits_dtype = jnp.promote_types(query.dtype, np.float32)
 
   # If the query and logits dtypes are different, then the default precision
@@ -920,13 +921,15 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   # Explicit precision will fail on platforms that don't support it. For example,
   # some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32.
   # Use the default precision as a fallback in these cases.
+  out_sharding1 = permute_shard(out_sharding, "BTNH", "BNTS", S=s_sharding)
   try:
     logits = jnp_einsum.einsum(
         "BTNH,BSNH->BNTS",
         query,
         key,
         precision=precision,
         preferred_element_type=logits_dtype,
+        out_sharding=out_sharding1
     )
   except:  # pylint: disable=bare-except
     logits = jnp_einsum.einsum(
@@ -935,12 +938,13 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
         key,
         precision=None,
         preferred_element_type=logits_dtype,
+        out_sharding=out_sharding1
     )
 
   logits *= jnp.array(scale, dtype=logits.dtype)
 
   if bias is not None:
-    logits = (logits + bias).astype(logits.dtype)
+    logits = lax.add(logits, bias).astype(logits.dtype)
 
   padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
                                local_window_size)
@@ -949,7 +953,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
     padded_logits = padded_logits.astype(np.float32)
   probs = softmax(padded_logits, axis=-1).astype(key.dtype)
 
-  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value)
+  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value, out_sharding=out_sharding)
   if q_seqlen is not None:
     mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
     encoded *= mask.astype(encoded.dtype)
@@ -972,7 +976,9 @@ def _dot_product_attention_xla(
     q_seqlen: Array | None,
     kv_seqlen: Array | None,
     local_window_size: tuple[int, int] | None,
-    return_residual: bool = False):
+    return_residual: bool = False,
+    out_sharding: Any = None,
+    s_sharding: Any = None):
 
   B, T, N, H = query.shape
   _, S, K, _ = key.shape
@@ -991,7 +997,7 @@ def _reshape_to_grouped(t):
     bias = _reshape_to_grouped(bias)
     mask = _reshape_to_grouped(mask)
   vmapped_fn = api.vmap(
-      _dot_product_attention_core,
+      partial(_dot_product_attention_core, out_sharding=out_sharding, s_sharding=s_sharding),
       in_axes=(3, None, None, 2, 2, None, None, None, None, None, None),
       out_axes=3,
   )
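
Note: binding the new keyword arguments with functools.partial keeps api.vmap's in_axes unchanged, because keywords closed over by partial are not mapped. A minimal standalone sketch of this pattern (the function f and its arguments are illustrative, not part of the commit):

from functools import partial
import jax
import jax.numpy as jnp

def f(x, y, scale=1.0):
  return scale * (x + y)

# scale is bound via partial, so in_axes only covers the two positional args.
g = jax.vmap(partial(f, scale=2.0), in_axes=(0, None))
print(g(jnp.arange(3.0), 1.0))  # [2. 4. 6.]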
@@ -1056,6 +1062,8 @@ def dot_product_attention(
     local_window_size: int | tuple[int, int] | None = None,
     implementation: Literal['xla', 'cudnn'] | None = None,
     return_residual: bool = False,
+    out_sharding=None,
+    s_sharding=None,
 ):
   r"""Scaled dot product attention function.
 
@@ -1185,6 +1193,8 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
         kv_seqlen=key_value_seq_lengths,
         local_window_size=local_window_size,
         return_residual=return_residual,
+        out_sharding=out_sharding,
+        s_sharding=s_sharding,
       )
     case 'cudnn':
       use_padding = (
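
For context, a hedged usage sketch of the new keywords on the public jax.nn.dot_product_attention API. The 2-device mesh, the axis names 'data' and 'model', and passing a NamedSharding (rather than a bare PartitionSpec) are assumptions for illustration; the commit itself only shows that out_sharding and s_sharding are forwarded to the XLA implementation.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Hypothetical mesh: 2 devices on 'data', 1 on 'model'.
mesh = jax.make_mesh((2, 1), ('data', 'model'))

B, T, N, H = 8, 128, 4, 64
q = jnp.zeros((B, T, N, H), jnp.bfloat16)
k = jnp.zeros((B, T, N, H), jnp.bfloat16)
v = jnp.zeros((B, T, N, H), jnp.bfloat16)

# Request a BTNH output layout: batch over 'data', heads over 'model'.
out = jax.nn.dot_product_attention(
    q, k, v,
    implementation='xla',
    out_sharding=NamedSharding(mesh, P('data', None, 'model', None)),
    s_sharding=None,  # no sharding requested over the key/value sequence axis
)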
@@ -1490,3 +1500,45 @@ def log1mexp(x: ArrayLike) -> Array:
   )
 
 log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))
+
+def permute_shard(outer_t, in_spec: str, out_spec: str, **bindings):
+  """
+  Permute a tuple according to einops-like dimension specifications.
+
+  Args:
+    outer_t: Input sharding, partition spec, or tuple to permute.
+    in_spec: Input specification string (e.g., "BTNH").
+    out_spec: Output specification string (e.g., "BNTH").
+    **bindings: Values for dimensions in out_spec not present in in_spec.
+
+  Returns:
+    Permuted tuple according to out_spec.
+
+  Examples:
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ACB")
+    PartitionSpec('A', 'C', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "XC", "CX")
+    PartitionSpec('C', 'A', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ANZ", N="D")
+    PartitionSpec('A', 'D', None)
+  """
+  if outer_t is None:
+    return None
+  t = outer_t.spec if isinstance(outer_t, NamedSharding) else outer_t
+  if t is None:
+    return outer_t
+  d = len(t) - len(in_spec)
+  assert d >= 0, f"Tuple length {len(t)} is less than named dimensions {len(in_spec)}"
+  np_t = np.array(t)
+  dim_map: dict = {dim: [i + d] for (i, dim) in enumerate(in_spec)}
+  dim_map[in_spec[0]] = range(0, d + 1)
+  new_spec = P(*list(itertools.chain.from_iterable(
+      np_t[dim_map[dim_name]] if dim_name in dim_map else [bindings.get(dim_name, None)]
+      for dim_name in out_spec
+  )))
+  if isinstance(outer_t, NamedSharding):
+    return NamedSharding(outer_t.mesh, new_spec)
+  else:
+    return new_spec
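
As called from _dot_product_attention_core above, permute_shard rearranges the requested BTNH output spec into the BNTS layout of the attention logits, binding the extra S axis from s_sharding. A small illustration; the axis names 'data' and 'model' are hypothetical, and the import path is the private module this commit edits:

from jax._src.nn.functions import permute_shard  # private path, per this commit
from jax.sharding import PartitionSpec as P

out_spec = P('data', None, 'model', None)                      # BTNH
logits_spec = permute_shard(out_spec, "BTNH", "BNTS", S=None)  # BNTS
print(logits_spec)  # PartitionSpec('data', 'model', None, None)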
