60 changes: 56 additions & 4 deletions jax/_src/nn/functions.py
@@ -19,6 +19,10 @@
from collections.abc import Sequence
from functools import partial
import operator
import itertools
# Assumed needed by permute_shard below, if not already imported in this module.
from jax._src.partition_spec import PartitionSpec as P
from jax._src.sharding_impls import NamedSharding
import numpy as np
from typing import Any, Literal, overload
import warnings
@@ -903,7 +904,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,

def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
scale, q_seqlen, kv_seqlen, local_window_size,
return_residual):
return_residual, out_sharding=None, s_sharding=None):
logits_dtype = jnp.promote_types(query.dtype, np.float32)

# If the query and logits dtypes are different, then the default precision
@@ -920,13 +921,15 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
# Explicit precision will fail on platforms that don't support it. For example,
# some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32.
# Use the default precision as a fallback in these cases.
# Derive the BNTS logits sharding from the requested BTNH output sharding;
# the key/value sequence (S) axis sharding comes from s_sharding.
out_sharding1 = permute_shard(out_sharding, "BTNH", "BNTS", S=s_sharding)
try:
logits = jnp_einsum.einsum(
"BTNH,BSNH->BNTS",
query,
key,
precision=precision,
preferred_element_type=logits_dtype,
out_sharding=out_sharding1
)
except: # pylint: disable=bare-except
logits = jnp_einsum.einsum(
@@ -935,6 +938,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
key,
precision=None,
preferred_element_type=logits_dtype,
out_sharding=out_sharding1
)

logits *= jnp.array(scale, dtype=logits.dtype)
@@ -949,7 +953,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
padded_logits = padded_logits.astype(np.float32)
probs = softmax(padded_logits, axis=-1).astype(key.dtype)

encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value)
encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value, out_sharding=out_sharding)
if q_seqlen is not None:
mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
encoded *= mask.astype(encoded.dtype)
@@ -972,7 +976,9 @@ def _dot_product_attention_xla(
q_seqlen: Array | None,
kv_seqlen: Array | None,
local_window_size: tuple[int, int] | None,
return_residual: bool = False):
return_residual: bool = False,
out_sharding: Any = None,
s_sharding: Any = None):

B, T, N, H = query.shape
_, S, K, _ = key.shape
@@ -991,7 +997,7 @@ def _reshape_to_grouped(t):
bias = _reshape_to_grouped(bias)
mask = _reshape_to_grouped(mask)
vmapped_fn = api.vmap(
_dot_product_attention_core,
partial(_dot_product_attention_core, out_sharding=out_sharding, s_sharding=s_sharding),
in_axes=(3, None, None, 2, 2, None, None, None, None, None, None),
out_axes=3,
)
@@ -1056,6 +1062,8 @@ def dot_product_attention(
local_window_size: int | tuple[int, int] | None = None,
implementation: Literal['xla', 'cudnn'] | None = None,
return_residual: bool = False,
out_sharding=None,
s_sharding=None,
):
r"""Scaled dot product attention function.

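A hedged usage sketch of the new parameters (mesh shape and axis names are hypothetical, and this assumes the PR is applied; depending on JAX version, sharded einsum outputs may additionally require an explicit-sharding mesh context):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2-axis mesh; reshape the available devices to fit.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "model"))
B, T, S, N, H = 8, 128, 128, 4, 64
q = jnp.zeros((B, T, N, H), jnp.bfloat16)
k = jnp.zeros((B, S, N, H), jnp.bfloat16)
v = jnp.zeros((B, S, N, H), jnp.bfloat16)

out = jax.nn.dot_product_attention(
    q, k, v,
    implementation="xla",
    # Requested sharding of the BTNH output.
    out_sharding=NamedSharding(mesh, P("data", None, "model", None)),
    # Sharding for the key/value sequence (S) axis of the intermediate
    # BNTS logits; None leaves it unsharded.
    s_sharding=None,
)
```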
@@ -1185,6 +1193,8 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
kv_seqlen=key_value_seq_lengths,
local_window_size=local_window_size,
return_residual=return_residual,
out_sharding=out_sharding,
s_sharding=s_sharding,
)
case 'cudnn':
use_padding = (
@@ -1490,3 +1500,45 @@ def log1mexp(x: ArrayLike) -> Array:
)

log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))

def permute_shard(outer_t, in_spec: str, out_spec: str, **bindings):
"""
Permute a tuple according to einops-like dimension specifications.

Args:
outer_t: Input NamedSharding, PartitionSpec, or tuple to permute
in_spec: Input specification string (e.g., "BTNH")
out_spec: Output specification string (e.g., "BNTH")
**bindings: Values for dimensions in out_spec not present in in_spec

Returns:
Permuted sharding according to out_spec; a NamedSharding input returns a NamedSharding on the same mesh

Examples:
>>> permute_shard(P('A', 'B', 'C'), "ABC", "ACB")
PartitionSpec('A', 'C', 'B')

>>> permute_shard(P('A', 'B', 'C'), "XC", "CX")
PartitionSpec('C', 'A', 'B')

>>> permute_shard(P('A', 'B', 'C'), "ABC", "ANZ", N="D")
PartitionSpec('A', 'D', None)
"""
if outer_t is None:
return None
t = outer_t.spec if isinstance(outer_t, NamedSharding) else outer_t
if t is None:
return outer_t
# Leading unnamed dimensions (when the spec is longer than in_spec) are
# grouped together with the first named dimension.
d = len(t) - len(in_spec)
assert d >= 0, f"Tuple length {len(t)} is less than named dimensions {len(in_spec)}"
np_t = np.array(t)
dim_map: dict = {dim: [i + d] for i, dim in enumerate(in_spec)}
dim_map[in_spec[0]] = range(0, d + 1)
# Dimensions named in in_spec keep their spec entries; any other output
# dimension takes its value from bindings, defaulting to None.
new_spec = P(*[str(a) if a else None for a in itertools.chain.from_iterable(
np_t[dim_map[dim_name]] if dim_name in dim_map else [bindings.get(dim_name, None)]
for dim_name in out_spec
)])
if isinstance(outer_t, NamedSharding):
return NamedSharding(outer_t.mesh, new_spec)
else:
return new_spec
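For reference, a minimal sketch of how the attention core uses this helper to derive the BNTS logits sharding from a BTNH output sharding (mesh and axis names are hypothetical; permute_shard is module-private, so this assumes it is in scope):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mesh: batch on "data", heads on "model".
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "model"))
out_sharding = NamedSharding(mesh, P("data", None, "model", None))  # BTNH

# The same permutation _dot_product_attention_core applies for its logits:
logits_sharding = permute_shard(out_sharding, "BTNH", "BNTS", S=None)
# -> NamedSharding(mesh, PartitionSpec('data', 'model', None, None))  # BNTS
```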