@@ -1041,43 +1041,66 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     return x
 
 def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+    """
+    Perform dimension sharding / reassembly across processes using _all_to_all_single.
+
+    This utility reshapes and redistributes the tensor `x` across the given process group,
+    along either the sequence dimension or the head dimension, selected via `scatter_idx`
+    and `gather_idx`.
+
+    Args:
+        x (torch.Tensor):
+            Input tensor. Expected shapes:
+            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
+            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
+        scatter_idx (int):
+            Dimension along which the tensor is partitioned before the all-to-all.
+        gather_idx (int):
+            Dimension along which the output is reassembled after the all-to-all.
+        group:
+            Distributed process group for the Ulysses group.
+
+    Returns:
+        torch.Tensor: Tensor with globally exchanged dimensions.
+            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len, num_heads_local, head_dim)
+            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len_local, num_heads, head_dim)
+    """
     group_world_size = torch.distributed.get_world_size(group)
 
     if scatter_idx == 2 and gather_idx == 1:
-        B, S_LOCAL, H, D = x.shape
-        S = S_LOCAL * group_world_size
-        H_LOCAL = H // group_world_size
+        # Used before Ulysses sequence-parallel (SP) attention: gathers the sequence
+        # dimension and scatters the head dimension.
+        batch_size, seq_len_local, num_heads, head_dim = x.shape
+        seq_len = seq_len_local * group_world_size
+        num_heads_local = num_heads // group_world_size
 
         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+        x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous()
 
 
         if group_world_size > 1:
-            #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
             out = _all_to_all_single(x_temp, group=group)
-            #out = _wait_tensor(out)
         else:
             out = x_temp
         # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
-        out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
-        out = out.reshape(B, S, H_LOCAL, D)
+        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
         return out
     elif scatter_idx == 1 and gather_idx == 2:
-        B, S, H_LOCAL, D = x.shape
-        H = H_LOCAL * group_world_size
-        S_LOCAL = S // group_world_size
+        # Used after Ulysses sequence-parallel attention in unified SP: gathers the head
+        # dimension and scatters the sequence dimension back.
+        batch_size, seq_len, num_heads_local, head_dim = x.shape
+        num_heads = num_heads_local * group_world_size
+        seq_len_local = seq_len // group_world_size
 
         #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)
+        x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
 
         if group_world_size > 1:
-            #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
             output = _all_to_all_single(x_temp, group)
-            #output = _wait_tensor(output)
         else:
             output = x_temp
-        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
-        output = output.reshape(B, S_LOCAL, H, D)
+        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
+        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
         return output
     else:
         raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
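
The shape contract of the scatter_idx=2 / gather_idx=1 branch can be checked with a small single-process sketch (no process group involved; world_size and the tensor sizes below are made-up values): each rank enters with its local sequence shard and all of the heads, and leaves with the full sequence but only its local slice of heads.

import torch

# Hypothetical sizes, for illustration only.
world_size = 4
batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64
seq_len_local = seq_len // world_size
num_heads_local = num_heads // world_size

full = torch.randn(batch_size, seq_len, num_heads, head_dim)
# Per-rank inputs: shards of the sequence dimension, each (batch_size, seq_len_local, num_heads, head_dim).
shards = full.chunk(world_size, dim=1)
# Emulate the all-to-all locally: rank r ends up with head slice r of every peer's
# sequence shard, concatenated along the sequence dimension.
outputs = [
    torch.cat([s[:, :, r * num_heads_local:(r + 1) * num_heads_local] for s in shards], dim=1)
    for r in range(world_size)
]
for out in outputs:
    assert out.shape == (batch_size, seq_len, num_heads_local, head_dim)
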
@@ -1334,14 +1357,17 @@ def TemplatedUnifiedAttention(
     forward_op,
     backward_op,
     _parallel_config: Optional["ParallelConfig"] = None,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
 ):
+    """
+    Unified Sequence Parallelism attention combining Ulysses and ring attention.
+    See: https://arxiv.org/abs/2405.07719
+    """
     ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
     ulysses_group = ulysses_mesh.get_group()
     ring_mesh = _parallel_config.context_parallel_config._ring_mesh
     ring_group = ring_mesh.get_group()
-    #hardcoded for now
-    scatter_idx = 2
-    gather_idx = 1
 
     query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
     key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
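
For orientation, a shape-bookkeeping sketch of what the Ulysses exchange at the start of TemplatedUnifiedAttention does to each rank's tensors. The mesh degrees and tensor sizes below are made-up values, and the inverse exchange applied to the attention output (scatter_idx=1, gather_idx=2) is only noted in the comments.

# Hypothetical mesh degrees and tensor sizes, for illustration only.
ulysses_degree, ring_degree = 2, 4
batch_size, seq_len, num_heads, head_dim = 2, 32, 16, 64

# Each rank enters with 1 / (ulysses_degree * ring_degree) of the sequence and all heads.
seq_len_local = seq_len // (ulysses_degree * ring_degree)    # 4

# SeqAllToAllDim with scatter_idx=2, gather_idx=1 over the Ulysses group:
seq_len_ring_local = seq_len_local * ulysses_degree          # 8: sequence gathered across Ulysses ranks
num_heads_local = num_heads // ulysses_degree                # 8: heads scattered across Ulysses ranks
assert (seq_len_ring_local, num_heads_local) == (8, 8)

# Ring attention then iterates over ring_degree blocks of length seq_len_ring_local, and the
# inverse exchange (scatter_idx=1, gather_idx=2) restores (seq_len_local, num_heads) afterwards.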