diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index 4505c143f6bd..653d789550ec 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -19,6 +19,7 @@
 from collections.abc import Sequence
 from functools import partial
 import operator
+import itertools
 import numpy as np
 from typing import Any, Literal, overload
 import warnings
@@ -903,7 +904,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
 
 def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
                                 scale, q_seqlen, kv_seqlen, local_window_size,
-                                return_residual):
+                                return_residual, out_sharding=None, s_sharding=None):
   logits_dtype = jnp.promote_types(query.dtype, np.float32)
 
   # If the query and logits dtypes are different, then the default precision
@@ -920,6 +921,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   # Explicit precision will fail on platforms that don't support it. For example,
   # some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32.
   # Use the default precision as a fallback in these cases.
+  out_sharding1 = permute_shard(out_sharding, "BTNH", "BNTS", S=s_sharding)
   try:
     logits = jnp_einsum.einsum(
         "BTNH,BSNH->BNTS",
@@ -927,6 +929,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
         key,
         precision=precision,
         preferred_element_type=logits_dtype,
+        out_sharding=out_sharding1,
     )
   except:  # pylint: disable=bare-except
     logits = jnp_einsum.einsum(
@@ -935,6 +938,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
         key,
         precision=None,
         preferred_element_type=logits_dtype,
+        out_sharding=out_sharding1,
     )
 
   logits *= jnp.array(scale, dtype=logits.dtype)
@@ -949,7 +953,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   padded_logits = padded_logits.astype(np.float32)
   probs = softmax(padded_logits, axis=-1).astype(key.dtype)
 
-  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value)
+  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value, out_sharding=out_sharding)
   if q_seqlen is not None:
     mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
     encoded *= mask.astype(encoded.dtype)
@@ -972,7 +976,9 @@ def _dot_product_attention_xla(
     q_seqlen: Array | None,
     kv_seqlen: Array | None,
     local_window_size: tuple[int, int] | None,
-    return_residual: bool = False):
+    return_residual: bool = False,
+    out_sharding: Any = None,
+    s_sharding: Any = None):
 
   B, T, N, H = query.shape
   _, S, K, _ = key.shape
@@ -991,7 +997,7 @@ def _reshape_to_grouped(t):
     bias = _reshape_to_grouped(bias)
     mask = _reshape_to_grouped(mask)
   vmapped_fn = api.vmap(
-      _dot_product_attention_core,
+      partial(_dot_product_attention_core, out_sharding=out_sharding, s_sharding=s_sharding),
       in_axes=(3, None, None, 2, 2, None, None, None, None, None, None),
       out_axes=3,
   )
@@ -1056,6 +1062,8 @@ def dot_product_attention(
     local_window_size: int | tuple[int, int] | None = None,
     implementation: Literal['xla', 'cudnn'] | None = None,
     return_residual: bool = False,
+    out_sharding=None,
+    s_sharding=None,
 ):
   r"""Scaled dot product attention function.
 
@@ -1185,6 +1193,8 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
           kv_seqlen=key_value_seq_lengths,
           local_window_size=local_window_size,
           return_residual=return_residual,
+          out_sharding=out_sharding,
+          s_sharding=s_sharding,
       )
     case 'cudnn':
       use_padding = (
@@ -1490,3 +1500,45 @@ def log1mexp(x: ArrayLike) -> Array:
   )
 
 log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))
+
+def permute_shard(outer_t, in_spec: str, out_spec: str, **bindings):
+  """
+  Permute a tuple according to einops-like dimension specifications.
+
+  Args:
+    outer_t: Input sharding, partition spec, or tuple to permute
+    in_spec: Input specification string (e.g., "BTNH")
+    out_spec: Output specification string (e.g., "BNTH")
+    **bindings: Values for dimensions in out_spec not present in in_spec
+
+  Returns:
+    Permuted sharding, spec, or tuple according to out_spec
+
+  Examples:
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ACB")
+    PartitionSpec('A', 'C', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "XC", "CX")
+    PartitionSpec('C', 'A', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ANZ", N="D")
+    PartitionSpec('A', 'D', None)
+  """
+  if outer_t is None:
+    return None
+  t = outer_t.spec if isinstance(outer_t, NamedSharding) else outer_t
+  if t is None:
+    return outer_t
+  d = len(t) - len(in_spec)
+  assert d >= 0, f"Tuple length {len(t)} is less than named dimensions {len(in_spec)}"
+  np_t = np.array(t)
+  dim_map: dict = {dim: [i + d] for (i, dim) in enumerate(in_spec)}
+  dim_map[in_spec[0]] = range(0, d + 1)
+  new_spec = P(*[str(a) if a else None for a in itertools.chain.from_iterable(
+      np_t[dim_map[dim_name]] if dim_name in dim_map else [bindings.get(dim_name, None)]
+      for dim_name in out_spec
+  )])
+  if isinstance(outer_t, NamedSharding):
+    return NamedSharding(outer_t.mesh, new_spec)
+  else:
+    return new_spec
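
Reviewer sketch (not part of the patch): the snippet below shows how the new out_sharding / s_sharding keywords might be exercised once this change lands. The mesh shape, the axis names "data" and "model", and the concrete PartitionSpec are illustrative assumptions; it also assumes a recent JAX where jax.make_mesh accepts axis_types, jax.sharding.AxisType.Explicit and jax.sharding.use_mesh exist, and jnp.einsum honors out_sharding for explicit-mode mesh axes.

# Illustrative usage of the new keywords; names and shapes are assumptions.
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P

# A 1x1 mesh so the sketch also runs on a single device; real workloads would
# use larger axis sizes. Explicit axes let einsum's out_sharding take effect.
mesh = jax.make_mesh((1, 1), ("data", "model"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))

B, T, S, N, H = 2, 16, 16, 4, 8
q = jnp.zeros((B, T, N, H), jnp.bfloat16)
k = jnp.zeros((B, S, N, H), jnp.bfloat16)
v = jnp.zeros((B, S, N, H), jnp.bfloat16)

with jax.sharding.use_mesh(mesh):
  # out_sharding describes the BTNH-shaped output; permute_shard turns it into
  # a BNTS spec for the logits (S taken from s_sharding), i.e. here the logits
  # would be constrained to P('data', 'model', None, None).
  out = jax.nn.dot_product_attention(
      q, k, v,
      implementation="xla",
      out_sharding=P("data", None, "model", None),  # B, T, N, H
      s_sharding=None,
  )
  print(out.shape)  # (2, 16, 4, 8)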