 from __future__ import annotations
 
 from collections.abc import Sequence
+from collections import defaultdict
 from functools import partial
 import operator
+import itertools
 import numpy as np
 from typing import Any, Literal, overload
 import warnings
@@ -903,7 +905,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
 
 def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
                                 scale, q_seqlen, kv_seqlen, local_window_size,
-                                return_residual):
+                                return_residual, out_sharding=None, s_sharding=None):
   logits_dtype = jnp.promote_types(query.dtype, np.float32)
 
   # If the query and logits dtypes are different, then the default precision
@@ -920,13 +922,15 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   # Explicit precision will fail on platforms that don't support it. For example,
   # some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32.
   # Use the default precision as a fallback in these cases.
+  out_sharding1 = permute_shard(out_sharding, "*BTNH", "BNTS", S=s_sharding)
   try:
     logits = jnp_einsum.einsum(
         "BTNH,BSNH->BNTS",
         query,
         key,
         precision=precision,
         preferred_element_type=logits_dtype,
+        out_sharding=out_sharding1
     )
   except:  # pylint: disable=bare-except
     logits = jnp_einsum.einsum(
@@ -949,7 +953,7 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   padded_logits = padded_logits.astype(np.float32)
   probs = softmax(padded_logits, axis=-1).astype(key.dtype)
 
-  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value)
+  encoded = jnp_einsum.einsum('BNTS,BSNH->BTNH', probs, value, out_sharding=out_sharding)
   if q_seqlen is not None:
     mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
     encoded *= mask.astype(encoded.dtype)
@@ -972,7 +976,9 @@ def _dot_product_attention_xla(
     q_seqlen: Array | None,
     kv_seqlen: Array | None,
     local_window_size: tuple[int, int] | None,
-    return_residual: bool = False):
+    return_residual: bool = False,
+    out_sharding: Any = None,
+    s_sharding: Any = None):
 
   B, T, N, H = query.shape
   _, S, K, _ = key.shape
@@ -991,7 +997,7 @@ def _reshape_to_grouped(t):
     bias = _reshape_to_grouped(bias)
     mask = _reshape_to_grouped(mask)
   vmapped_fn = api.vmap(
-      _dot_product_attention_core,
+      partial(_dot_product_attention_core, out_sharding=out_sharding, s_sharding=s_sharding),
       in_axes=(3, None, None, 2, 2, None, None, None, None, None, None),
       out_axes=3,
   )
@@ -1056,6 +1062,8 @@ def dot_product_attention(
     local_window_size: int | tuple[int, int] | None = None,
     implementation: Literal['xla', 'cudnn'] | None = None,
     return_residual: bool = False,
+    out_sharding=None,
+    s_sharding=None,
 ):
   r"""Scaled dot product attention function.
 
@@ -1185,6 +1193,8 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
           kv_seqlen=key_value_seq_lengths,
           local_window_size=local_window_size,
           return_residual=return_residual,
+          out_sharding=out_sharding,
+          s_sharding=s_sharding,
       )
     case 'cudnn':
       use_padding = (
@@ -1490,3 +1500,41 @@ def log1mexp(x: ArrayLike) -> Array:
   )
 
 log1mexp.defjvps(lambda g, ans, x: g / jnp.expm1(x))
+
+def permute_shard(outer_t, in_spec: str, out_spec: str, **bindings):
+  """Permute a sharding or partition spec via einops-like dimension specs.
+
+  Args:
+    outer_t: Input sharding, partition spec, or tuple to permute.
+    in_spec: Input specification string (e.g., "BTNH").
+    out_spec: Output specification string (e.g., "BNTH").
+    **bindings: Values for dimensions in ``out_spec`` not present in ``in_spec``.
+
+  Returns:
+    The permuted spec according to ``out_spec``, as a ``PartitionSpec`` (or a
+    ``NamedSharding`` on the same mesh if the input was a ``NamedSharding``).
+
+  Examples:
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ACB")
+    PartitionSpec('A', 'C', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "XC", "CX")
+    PartitionSpec('C', 'A', 'B')
+
+    >>> permute_shard(P('A', 'B', 'C'), "ABC", "ANZ", N="D")
+    PartitionSpec('A', 'D', None)
+  """
+  # No sharding requested: nothing to permute.
+  if outer_t is None:
+    return None
+  t = outer_t.spec if isinstance(outer_t, NamedSharding) else outer_t
+  d = len(t) - len(in_spec)
+  assert d >= 0, f"Tuple length {len(t)} is less than named dimensions {len(in_spec)}"
+  # Map each named input dimension to the indices it covers in `t`; the first
+  # named dimension also absorbs the `d` extra leading entries.
+  dim_map = {dim: [i + d] for (i, dim) in enumerate(in_spec)}
+  dim_map[in_spec[0]] = range(0, d + 1)
+  # Dimensions absent from `in_spec` take their value from `bindings` (or None).
+  new_spec = P(*itertools.chain.from_iterable(
+      (t[i] for i in dim_map[dim_name]) if dim_name in dim_map
+      else [bindings.get(dim_name, None)]
+      for dim_name in out_spec))
+  if isinstance(outer_t, NamedSharding):
+    return NamedSharding(outer_t.mesh, new_spec)
+  else:
+    return new_spec
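
A minimal usage sketch of the new permute_shard helper (illustration only, not part of the diff). It assumes a small two-axis mesh whose axis names "data" and "model" are placeholders, uses the exact-rank "BTNH" form from the docstring examples rather than the "*BTNH" form at the einsum call site above, and relies on permute_shard as defined in this change. It shows how an output sharding laid out as BTNH is rearranged into the BNTS layout of the attention logits, with the extra S dimension bound explicitly.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A tiny 1x1 mesh; the axis names are placeholders for this sketch.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, ("data", "model"))

# Attention output is laid out as BTNH: shard batch over "data", heads over "model".
out_sharding = NamedSharding(mesh, P("data", None, "model", None))

# Rearrange into the BNTS layout of the logits; S has no counterpart in BTNH,
# so it is bound explicitly (unsharded here).
logits_sharding = permute_shard(out_sharding, "BTNH", "BNTS", S=None)
print(logits_sharding.spec)  # PartitionSpec('data', 'model', None, None)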