@@ -1040,12 +1040,13 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     x = _wait_tensor(x)
     return x
 
+
 def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
     """
     Perform dimension sharding / reassembly across processes using _all_to_all_single.
 
-    This utility reshapes and redistributes tensor `x` across the given process group,
-    across sequence dimension or head dimension flexibly by accepting scatter_idx and gather_idx.
+    This utility reshapes and redistributes tensor `x` across the given process group, along either the sequence
+    dimension or the head dimension, as selected by scatter_idx and gather_idx.
 
     Args:
         x (torch.Tensor):
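The contract documented above can be smoke-tested locally: on a single-rank gloo group the exchange reduces to pure reshapes, so input and output shapes are easy to verify. A minimal sketch, assuming the helper is importable from diffusers' attention_dispatch module (the import path is an assumption, not part of this PR):

import os

import torch
import torch.distributed as dist

# Assumed import path; point this at wherever attention_dispatch lives in your checkout.
from diffusers.models.attention_dispatch import _all_to_all_dim_exchange

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(2, 16, 8, 64)  # (B, S_LOCAL, H, D)
out = _all_to_all_dim_exchange(x, scatter_idx=2, gather_idx=1, group=dist.group.WORLD)
# With world_size == 1, S = S_LOCAL * 1 and H_LOCAL = H // 1, so the shape is unchanged.
assert out.shape == (2, 16, 8, 64)  # (B, S, H_LOCAL, D)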
@@ -1067,17 +1068,20 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
     group_world_size = torch.distributed.get_world_size(group)
 
     if scatter_idx == 2 and gather_idx == 1:
-        #Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
-        #dimension and scatters head dimension
+        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
+        # dimension and scatters the head dimension.
         batch_size, seq_len_local, num_heads, head_dim = x.shape
         seq_len = seq_len_local * group_world_size
         num_heads_local = num_heads // group_world_size
 
         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous()
-
+        x_temp = (
+            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
+            .transpose(0, 2)
+            .contiguous()
+        )
 
-        if group_world_size > 1:
+        if group_world_size > 1:
             out = _all_to_all_single(x_temp, group=group)
         else:
             out = x_temp
@@ -1086,16 +1090,20 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
         return out
     elif scatter_idx == 1 and gather_idx == 2:
-        #Used after ulysses sequence parallel in unified SP. gathers the head dimension
-        #scatters back the sequence dimension.
+        # Used after Ulysses sequence parallel in unified SP. Gathers the head dimension
+        # and scatters back the sequence dimension.
         batch_size, seq_len, num_heads_local, head_dim = x.shape
         num_heads = num_heads_local * group_world_size
         seq_len_local = seq_len // group_world_size
 
-        #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-        x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
-
-        if group_world_size > 1:
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = (
+            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+            .permute(1, 3, 2, 0, 4)
+            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+        )
+
+        if group_world_size > 1:
             output = _all_to_all_single(x_temp, group)
         else:
             output = x_temp
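As a shape walkthrough of the two branches (illustrative numbers chosen here, not taken from the PR):

B, S_LOCAL, H, D, G = 2, 128, 16, 64, 4  # made-up sizes; G = group_world_size

# scatter_idx == 2, gather_idx == 1 (before Ulysses attention):
#   x                      (B, S_LOCAL, H, D)          = (2, 128, 16, 64)
#   reshape + transpose -> (G, S_LOCAL, B, H // G, D)  = (4, 128, 2, 4, 64)
#   all-to-all on dim 0 -> same shape; each rank now holds every rank's sequence
#                          shard, but only its own H // G heads
#   final reshape       -> (B, S_LOCAL * G, H // G, D) = (2, 512, 4, 64)
assert (B * S_LOCAL * H * D) == (B * (S_LOCAL * G) * (H // G) * D)  # elements are conserved

# scatter_idx == 1, gather_idx == 2 (after Ulysses attention) inverts the mapping:
#   (B, S, H_LOCAL, D) = (2, 512, 4, 64)  ->  (B, S_LOCAL, H, D) = (2, 128, 16, 64)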
@@ -1108,9 +1116,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
 
 class SeqAllToAllDim(torch.autograd.Function):
     """
-    all_to_all operation for unified sequence parallelism.
-    uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange for more info.
+    All-to-all operation for unified sequence parallelism. Uses _all_to_all_dim_exchange; see that function for more
+    info.
     """
+
     @staticmethod
     def forward(ctx, group, input, scatter_id=2, gather_id=1):
         ctx.group = group
@@ -1123,13 +1132,12 @@ def backward(ctx, grad_outputs):
         grad_input = SeqAllToAllDim.apply(
             ctx.group,
             grad_outputs,
-            ctx.gather_id, # reversed
+            ctx.gather_id,  # reversed
             ctx.scatter_id,  # reversed
         )
         return (None, grad_input, None, None)
 
 
-
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
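The backward above works because the exchange is its own inverse once scatter_id and gather_id are swapped: the all-to-all that scatters heads and gathers sequence is undone by the one that scatters sequence and gathers heads. A hypothetical round trip, reusing the single-rank gloo group from the earlier sketch:

x = torch.randn(2, 16, 8, 64, requires_grad=True)   # (B, S_LOCAL, H, D)
y = SeqAllToAllDim.apply(dist.group.WORLD, x, 2, 1)  # forward exchange -> (B, S, H_LOCAL, D)
y.sum().backward()
# Gradients flow back through the swapped exchange (scatter_id=1, gather_id=2),
# so they arrive in the original input layout.
assert x.grad.shape == x.shape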
@@ -1351,6 +1359,7 @@ def backward(
 
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
 
+
 def TemplatedUnifiedAttention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -1364,12 +1373,11 @@ def TemplatedUnifiedAttention(
     forward_op,
     backward_op,
     _parallel_config: Optional["ParallelConfig"] = None,
-    scatter_idx: int = 2,
-    gather_idx: int = 1,
-):
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
     """
-    Unified Sequence Parallelism attention combining Ulysses and ring attention.
-    See: https://arxiv.org/abs/2405.07719
+    Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
     """
     ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
     ulysses_group = ulysses_mesh.get_group()
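For orientation, in unified SP the context-parallel world factors into ring_degree * ulysses_degree: the Ulysses all-to-all (SeqAllToAllDim) trades heads for sequence within the ulysses group, and ring attention then streams K/V shards around the ring group. A rough sharding calculation with made-up sizes (illustrative assumptions, not values from this PR):

ring_degree, ulysses_degree = 2, 4              # context-parallel world size = 8 ranks
B, S, H, D = 1, 8192, 32, 128                   # global batch, sequence, heads, head dim
S_local = S // (ring_degree * ulysses_degree)   # 1024 tokens held per rank going in
H_local = H // ulysses_degree                   # 8 heads per rank after the all-to-all
# Ulysses all-to-all:  (B, S_local, H, D) -> (B, S_local * ulysses_degree, H_local, D)
# Ring attention then iterates over the ring_degree K/V shards at that layout, and the
# closing all-to-all (gather_idx/scatter_idx swapped) restores (B, S_local, H, D).
print(S_local, H_local)  # 1024 8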
@@ -1395,15 +1403,15 @@ def TemplatedUnifiedAttention(
         context_layer, lse, *_ = out
     else:
         context_layer = out
-    #context_layer is of shape (B, S, H_LOCAL, D)
+    # context_layer is of shape (B, S, H_LOCAL, D)
     output = SeqAllToAllDim.apply(
         ulysses_group,
         context_layer,
         gather_idx,
         scatter_idx,
     )
     if return_lse:
-        #lse is of shape (B, S, H_LOCAL, 1)
+        # lse is of shape (B, S, H_LOCAL, 1)
        # Refer to:
        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
         if is_torch_version("<", "2.9.0"):
@@ -1413,6 +1421,7 @@ def TemplatedUnifiedAttention(
         return (output, lse)
     return output
 
+
 def _templated_context_parallel_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -1436,7 +1445,10 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")
 
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+    if (
+        _parallel_config.context_parallel_config.ring_degree > 1
+        and _parallel_config.context_parallel_config.ulysses_degree > 1
+    ):
         return TemplatedUnifiedAttention(
             query,
             key,