
Commit 628f72d

sequence parallelism bug fixes
1 parent fd4c32b commit 628f72d

File tree

1 file changed (+17, -9 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 17 additions & 9 deletions
@@ -44,6 +44,7 @@
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils import is_torch_version
 
 
 if TYPE_CHECKING:
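
The newly imported `is_torch_version` is used in the hunks below to branch on the installed PyTorch release. A minimal sketch of an equivalent gate, assuming only that the helper compares `torch.__version__` against a reference version with the given operator (the helper name and body here are illustrative, not the diffusers implementation):

```python
# Illustrative stand-in for a version gate like is_torch_version (assumption:
# it compares the installed torch release against a reference with the given operator).
import operator

import torch
from packaging import version

_OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq, ">=": operator.ge, ">": operator.gt}


def torch_version_is(op: str, ref: str) -> bool:
    current = version.parse(version.parse(torch.__version__).base_version)
    return _OPS[op](current, version.parse(ref))


# Example: the gate used by this commit.
needs_manual_unsqueeze = torch_version_is("<", "2.9.0")
```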
@@ -1107,6 +1108,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
 
 
 class SeqAllToAllDim(torch.autograd.Function):
+    """
+    all_to_all operation for unified sequence parallelism.
+    uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange for more info.
+    """
     @staticmethod
     def forward(ctx, group, input, scatter_id=2, gather_id=1):
         ctx.group = group
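
`SeqAllToAllDim` wraps `_all_to_all_dim_exchange`, which trades a sequence shard for a head shard across the Ulysses group. A single-process sketch of the shape bookkeeping that exchange is assumed to perform, with the `P` ranks simulated by plain slicing instead of a real `torch.distributed` collective:

```python
# Single-process sketch of the (B, S/P, H, D) -> (B, S, H/P, D) exchange that an
# Ulysses-style all_to_all performs; the P ranks are simulated by slicing tensors.
import torch

B, S, H, D, P = 2, 8, 4, 16, 2  # batch, sequence, heads, head_dim, group size

# Each "rank" holds a sequence shard with all heads: (B, S/P, H, D)
shards = [torch.randn(B, S // P, H, D) for _ in range(P)]

# After the all_to_all with scatter_idx=2 (heads) and gather_idx=1 (sequence),
# each rank holds the full sequence but only its slice of heads: (B, S, H/P, D)
exchanged = []
for rank in range(P):
    head_slice = slice(rank * H // P, (rank + 1) * H // P)
    pieces = [shard[:, :, head_slice, :] for shard in shards]  # one piece from every rank
    exchanged.append(torch.cat(pieces, dim=1))                 # gather along the sequence dim

assert exchanged[0].shape == (B, S, H // P, D)
```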
@@ -1186,7 +1191,10 @@ def forward(
                 out = out.to(torch.float32)
                 lse = lse.to(torch.float32)
 
-            lse = lse.unsqueeze(-1)
+            # Refer to:
+            # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+            if is_torch_version("<", "2.9.0"):
+                lse = lse.unsqueeze(-1)
             if prev_out is not None:
                 out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                 lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
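
For context, the surrounding lines merge partial attention results: `sigmoid(lse - prev_lse)` is the softmax weight of the newer block, and `prev_lse - logsigmoid(prev_lse - lse)` equals `logaddexp(prev_lse, lse)`. A small self-contained check of that identity (tensors and shapes here are arbitrary, not taken from the file above):

```python
# Numeric check: merging two partial softmax-attention results with the
# sigmoid/logsigmoid update reproduces attention over the concatenated keys.
import torch

torch.manual_seed(0)
q = torch.randn(1, 3, 64)                      # (B, S_q, D)
k1, v1 = torch.randn(1, 5, 64), torch.randn(1, 5, 64)
k2, v2 = torch.randn(1, 7, 64), torch.randn(1, 7, 64)


def partial(q, k, v):
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    lse = scores.logsumexp(dim=-1)             # (B, S_q)
    out = scores.softmax(dim=-1) @ v           # (B, S_q, D)
    return out, lse


prev_out, prev_lse = partial(q, k1, v1)
out, lse = partial(q, k2, v2)

prev_lse_, lse_ = prev_lse.unsqueeze(-1), lse.unsqueeze(-1)
merged_out = prev_out - torch.sigmoid(lse_ - prev_lse_) * (prev_out - out)
merged_lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)

ref_out, ref_lse = partial(q, torch.cat([k1, k2], dim=1), torch.cat([v1, v2], dim=1))
assert torch.allclose(merged_out, ref_out, atol=1e-4)
assert torch.allclose(merged_lse, ref_lse, atol=1e-4)
```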
@@ -1342,7 +1350,7 @@ def backward(
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )
 
-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
 
 def TemplatedUnifiedAttention(
     query: torch.Tensor,
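
The extra trailing `None` is needed because `torch.autograd.Function.backward` must return exactly one value per argument of the matching `forward`, including non-tensor arguments such as process groups or index positions. A minimal, unrelated illustration of that rule:

```python
# Minimal custom autograd.Function: backward returns one value per forward argument,
# with None as the placeholder for arguments that do not receive gradients.
import torch


class ScaledIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # two forward args (x, scale) -> exactly two return values
        return grad_output * ctx.scale, None


x = torch.randn(4, requires_grad=True)
ScaledIdentity.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```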
@@ -1366,8 +1374,6 @@ def TemplatedUnifiedAttention(
     """
     ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
     ulysses_group = ulysses_mesh.get_group()
-    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
-    ring_group = ring_mesh.get_group()
 
     query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
     key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
@@ -1390,18 +1396,20 @@ def TemplatedUnifiedAttention(
         context_layer, lse, *_ = out
     else:
         context_layer = out
-    # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D)
+    # context_layer is of shape (B, S, H_LOCAL, D)
     output = SeqAllToAllDim.apply(
         ulysses_group,
         context_layer,
         gather_idx,
         scatter_idx,
     )
     if return_lse:
-        # not sure if this is correct: Assuming (based on forward ops in ringAttention)
-        # the lse is of shape (B, S, H_LOCAL)
-        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
-        lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
+        # lse is of shape (B, S, H_LOCAL, 1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
         lse = lse.squeeze(-1)
         return (output, lse)
     return output
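
After the attention call, the output goes back through the same all-to-all with the index arguments swapped (`gather_idx, scatter_idx`), turning per-rank `(B, S, H_LOCAL, D)` back into `(B, S_LOCAL, H, D)`; on torch < 2.9.0 the lse first needs its trailing singleton dimension restored so it can take the same path. A single-process sketch of that reverse exchange, again simulating the group by slicing (lse shapes assumed from the comments in this diff):

```python
# Single-process sketch of the reverse exchange used for the attention output:
# (B, S, H/P, D) per "rank" -> (B, S/P, H, D) per "rank", i.e. heads are gathered
# back together while the sequence is re-sharded.
import torch

B, S, H, D, P = 2, 8, 4, 16, 2

# Each "rank" holds the full sequence with its head slice: (B, S, H/P, D)
per_rank = [torch.randn(B, S, H // P, D) for _ in range(P)]

reverse = []
for rank in range(P):
    seq_slice = slice(rank * S // P, (rank + 1) * S // P)
    pieces = [x[:, seq_slice, :, :] for x in per_rank]   # one sequence piece from every rank
    reverse.append(torch.cat(pieces, dim=2))             # gather along the head dim

assert reverse[0].shape == (B, S // P, H, D)

# lse path on torch < 2.9.0 (assumed shape (B, S, H_LOCAL)): add the trailing dim so it
# can go through the same exchange, then squeeze it back afterwards.
lse = torch.randn(B, S, H // P)
lse = lse.unsqueeze(-1)   # (B, S, H_LOCAL, 1)
# ... exchange as above ...
lse = lse.squeeze(-1)
```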
