     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils import is_torch_version


 if TYPE_CHECKING:
@@ -1107,6 +1108,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:


 class SeqAllToAllDim(torch.autograd.Function):
+    """
+    All-to-all operation for unified sequence parallelism.
+    Wraps _all_to_all_dim_exchange; see that function for details.
+    """
     @staticmethod
     def forward(ctx, group, input, scatter_id=2, gather_id=1):
         ctx.group = group
@@ -1186,7 +1191,10 @@ def forward(
         out = out.to(torch.float32)
         lse = lse.to(torch.float32)

-        lse = lse.unsqueeze(-1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)
         if prev_out is not None:
             out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
             lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1342,7 +1350,7 @@ def backward(
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None

 def TemplatedUnifiedAttention(
     query: torch.Tensor,
@@ -1366,8 +1374,6 @@ def TemplatedUnifiedAttention(
     """
     ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
     ulysses_group = ulysses_mesh.get_group()
-    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
-    ring_group = ring_mesh.get_group()

     query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
     key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
@@ -1390,18 +1396,20 @@ def TemplatedUnifiedAttention(
         context_layer, lse, *_ = out
     else:
         context_layer = out
-    # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D)
+    # context_layer is of shape (B, S, H_LOCAL, D)
     output = SeqAllToAllDim.apply(
         ulysses_group,
         context_layer,
         gather_idx,
         scatter_idx,
     )
     if return_lse:
-        # not sure if this is correct: Assuming (based on forward ops in ringAttention)
-        # the lse is of shape (B, S, H_LOCAL)
-        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
-        lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
+        # lse has shape (B, S, H_LOCAL, 1) on torch >= 2.9.0 and (B, S, H_LOCAL) on older versions.
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
         lse = lse.squeeze(-1)
         return (output, lse)
     return output
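
The new docstring points at _all_to_all_dim_exchange, which is not shown in this diff. For orientation only, a generic Ulysses-style all-to-all dimension exchange looks roughly like the sketch below; the function name, argument order, and any reshaping the diffusers helper performs internally are assumptions, not the actual implementation.

import torch
import torch.distributed as dist


def all_to_all_dim_exchange(x: torch.Tensor, group, scatter_idx: int = 2, gather_idx: int = 1) -> torch.Tensor:
    # Split x into world_size chunks along scatter_idx, exchange one chunk with
    # every rank in the group, then concatenate the received chunks along gather_idx.
    # E.g. query (B, S_LOCAL, H, D) with scatter_idx=2, gather_idx=1 becomes
    # (B, S, H_LOCAL, D): heads are scattered, the sequence is gathered.
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_idx)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_idx)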
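
The version gate in the hunks above hinges on the shape of the log-sum-exp tensor: per the diff comments, backends on torch >= 2.9.0 already return lse with a trailing singleton dimension, while older versions return (B, S, H_LOCAL) and need the unsqueeze before the merge. Below is a minimal, self-contained sketch of that pattern; _normalize_lse and _merge_partial_attention are hypothetical names, the merge formula is copied from the forward hunk, and packaging is used in place of diffusers' is_torch_version helper.

import torch
from packaging import version

# Assumption taken from the diff comments: pre-2.9 backends return lse as
# (B, S, H_LOCAL); 2.9+ already returns (B, S, H_LOCAL, 1).
_TORCH_PRE_2_9 = version.parse(torch.__version__) < version.parse("2.9.0")


def _normalize_lse(lse: torch.Tensor) -> torch.Tensor:
    # Ensure lse always has shape (B, S, H_LOCAL, 1) before merging.
    if _TORCH_PRE_2_9:
        lse = lse.unsqueeze(-1)
    return lse


def _merge_partial_attention(out, lse, prev_out=None, prev_lse=None):
    # Same numerically stable merge as in the ring-attention forward hunk,
    # carried out in float32.
    out = out.to(torch.float32)
    lse = _normalize_lse(lse.to(torch.float32))
    if prev_out is not None:
        out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
        lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
    return out, lse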