Commit 5579cf1

github-actions[bot] authored and Bissmella committed
Apply style fixes
1 parent 3a08cf4 commit 5579cf1

File tree: 2 files changed (+49, -41 lines)
src/diffusers/models/attention_dispatch.py

Lines changed: 37 additions & 25 deletions
@@ -1040,12 +1040,13 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     x = _wait_tensor(x)
     return x

+
 def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
     """
     Perform dimension sharding / reassembly across processes using _all_to_all_single.

-    This utility reshapes and redistributes tensor `x` across the given process group,
-    across sequence dimension or head dimension flexibly by accepting scatter_idx and gather_idx.
+    This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or
+    head dimension flexibly by accepting scatter_idx and gather_idx.

     Args:
         x (torch.Tensor):
@@ -1067,17 +1068,20 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
     group_world_size = torch.distributed.get_world_size(group)

     if scatter_idx == 2 and gather_idx == 1:
-        #Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
-        #dimension and scatters head dimension
+        # Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
+        # dimension and scatters head dimension
         batch_size, seq_len_local, num_heads, head_dim = x.shape
         seq_len = seq_len_local * group_world_size
         num_heads_local = num_heads // group_world_size

         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous()
-
+        x_temp = (
+            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
+            .transpose(0, 2)
+            .contiguous()
+        )

-        if group_world_size >1:
+        if group_world_size > 1:
             out = _all_to_all_single(x_temp, group=group)
         else:
             out = x_temp
@@ -1086,16 +1090,20 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
         out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
         return out
     elif scatter_idx == 1 and gather_idx == 2:
-        #Used after ulysses sequence parallel in unified SP. gathers the head dimension
-        #scatters back the sequence dimension.
+        # Used after ulysses sequence parallel in unified SP. gathers the head dimension
+        # scatters back the sequence dimension.
         batch_size, seq_len, num_heads_local, head_dim = x.shape
         num_heads = num_heads_local * group_world_size
         seq_len_local = seq_len // group_world_size

-        #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-        x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
-
-        if group_world_size >1:
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = (
+            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+            .permute(1, 3, 2, 0, 4)
+            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+        )
+
+        if group_world_size > 1:
             output = _all_to_all_single(x_temp, group)
         else:
             output = x_temp
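
Apart from whitespace, the two branches above only re-wrap the chained reshape / transpose / permute calls, so behaviour is unchanged. For readers following the scatter_idx / gather_idx logic, here is a minimal single-process sketch that emulates the scatter_idx=2, gather_idx=1 exchange and checks the resulting shapes. The in-memory emulation of _all_to_all_single and the post-communication permute are assumptions of this sketch, not code from the commit.

import torch

world_size = 4
B, S, H, D = 2, 8, 8, 16
S_LOCAL, H_LOCAL = S // world_size, H // world_size

x_global = torch.randn(B, S, H, D)
# What each simulated rank holds before the exchange: a shard of the sequence dimension.
inputs = list(x_global.chunk(world_size, dim=1))

# Pre-communication reshape, mirroring the diff:
# (B, S_LOCAL, H, D) -> (world_size, S_LOCAL, B, H_LOCAL, D)
pre = [
    x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).transpose(0, 2).contiguous()
    for x in inputs
]

# Emulate all_to_all_single with equal splits: rank r receives chunk r (along dim 0)
# from every source rank, concatenated in source-rank order.
post = [
    torch.cat([pre[src].chunk(world_size, dim=0)[r] for src in range(world_size)], dim=0)
    for r in range(world_size)
]

# Post-communication rearrangement back to (B, S, H_LOCAL, D) (assumed layout).
outs = [p.permute(2, 0, 1, 3, 4).reshape(B, S, H_LOCAL, D) for p in post]

# Each rank now sees the full sequence but only its own slice of the heads.
for r in range(world_size):
    assert torch.allclose(outs[r], x_global[:, :, r * H_LOCAL : (r + 1) * H_LOCAL])
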
@@ -1108,9 +1116,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:

 class SeqAllToAllDim(torch.autograd.Function):
     """
-    all_to_all operation for unified sequence parallelism.
-    uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange for more info.
+    all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange
+    for more info.
     """
+
     @staticmethod
     def forward(ctx, group, input, scatter_id=2, gather_id=1):
         ctx.group = group
@@ -1123,13 +1132,12 @@ def backward(ctx, grad_outputs):
         grad_input = SeqAllToAllDim.apply(
             ctx.group,
             grad_outputs,
-            ctx.gather_id, # reversed
+            ctx.gather_id,  # reversed
             ctx.scatter_id,  # reversed
         )
         return (None, grad_input, None, None)


-
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
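
The SeqAllToAllDim change above is again cosmetic (comment spacing, a blank line after the docstring), but the reversed-ids trick in backward is worth spelling out: undoing the exchange is the same exchange applied with scatter and gather swapped. Below is a self-contained toy version of that autograd pattern; the local _exchange transpose stands in for _all_to_all_dim_exchange so the sketch runs without a process group, and all names here are illustrative rather than the library's API.

import torch


def _exchange(x: torch.Tensor, scatter_idx: int, gather_idx: int) -> torch.Tensor:
    # Stand-in for the real exchange: just swap the two dimensions locally.
    return x.transpose(scatter_idx, gather_idx).contiguous()


class ToySeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scatter_idx=2, gather_idx=1):
        ctx.scatter_idx, ctx.gather_idx = scatter_idx, gather_idx
        return _exchange(x, scatter_idx, gather_idx)

    @staticmethod
    def backward(ctx, grad_out):
        # Reversed ids undo the forward exchange, mirroring SeqAllToAllDim.backward.
        grad_in = _exchange(grad_out, ctx.gather_idx, ctx.scatter_idx)
        return grad_in, None, None


x = torch.randn(2, 4, 8, 16, requires_grad=True)
y = ToySeqAllToAll.apply(x, 2, 1)
y.sum().backward()
assert x.grad.shape == x.shape
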
@@ -1351,6 +1359,7 @@ def backward(

         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None

+
 def TemplatedUnifiedAttention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -1364,12 +1373,11 @@ def TemplatedUnifiedAttention(
     forward_op,
     backward_op,
     _parallel_config: Optional["ParallelConfig"] = None,
-    scatter_idx: int =2,
-    gather_idx: int =1,
-    ):
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
     """
-    Unified Sequence Parallelism attention combining Ulysses and ring attention.
-    See: https://arxiv.org/abs/2405.07719
+    Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
     """
     ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
     ulysses_group = ulysses_mesh.get_group()
@@ -1395,15 +1403,15 @@ def TemplatedUnifiedAttention(
         context_layer, lse, *_ = out
     else:
         context_layer = out
-    #context_layer is of shape (B, S, H_LOCAL, D)
+    # context_layer is of shape (B, S, H_LOCAL, D)
     output = SeqAllToAllDim.apply(
         ulysses_group,
         context_layer,
         gather_idx,
         scatter_idx,
     )
     if return_lse:
-        #lse is of shape (B, S, H_LOCAL, 1)
+        # lse is of shape (B, S, H_LOCAL, 1)
         # Refer to:
         # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
         if is_torch_version("<", "2.9.0"):
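
From what the visible hunks show, the unified path sends the ring-attention output back through SeqAllToAllDim with the indices reversed, so the (B, S, H_LOCAL, D) context layer returns to the caller's per-rank layout, and the optional lse of shape (B, S, H_LOCAL, 1) gets extra handling on torch < 2.9.0. A hedged shape-bookkeeping sketch of that round trip (the degree and sizes are made up; the intermediate layouts are taken from the comments in the diff):

ulysses_degree = 2
B, S_LOCAL, H, D = 1, 16, 8, 64

# Per-rank layout entering TemplatedUnifiedAttention.
q_shape = (B, S_LOCAL, H, D)
# After SeqAllToAllDim(scatter_idx=2, gather_idx=1): full sequence, local heads.
# Ring attention keeps this layout, so context_layer is (B, S, H_LOCAL, D).
context_shape = (B, S_LOCAL * ulysses_degree, H // ulysses_degree, D)
# The optional lse is (B, S, H_LOCAL, 1) before any further handling.
lse_shape = (B, S_LOCAL * ulysses_degree, H // ulysses_degree, 1)
# After SeqAllToAllDim(gather_idx, scatter_idx) on the output: original layout restored.
out_shape = (B, S_LOCAL, H, D)

print(q_shape, "->", context_shape, "->", out_shape, "| lse:", lse_shape)
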
@@ -1413,6 +1421,7 @@ def TemplatedUnifiedAttention(
         return (output, lse)
     return output

+
 def _templated_context_parallel_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -1436,7 +1445,10 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")

     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1:
+    if (
+        _parallel_config.context_parallel_config.ring_degree > 1
+        and _parallel_config.context_parallel_config.ulysses_degree > 1
+    ):
         return TemplatedUnifiedAttention(
             query,
             key,
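
The last hunk only re-wraps the long condition, but it is the dispatch rule that decides when the unified path runs. A tiny illustrative helper (my own naming, not part of the diff) spelling that rule out:

def uses_unified_attention(ring_degree: int, ulysses_degree: int) -> bool:
    # Unified SP attention is only selected when both degrees exceed one;
    # otherwise the existing pure-ring or pure-Ulysses paths handle the call.
    return ring_degree > 1 and ulysses_degree > 1


assert uses_unified_attention(2, 2)
assert not uses_unified_attention(4, 1)
assert not uses_unified_attention(1, 4)
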

tests/others/test_unified_sp_attention.py

Lines changed: 12 additions & 16 deletions
@@ -5,18 +5,12 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp

+from diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig
 from diffusers.models.attention_dispatch import TemplatedUnifiedAttention
-from diffusers.models._modeling_parallel import (
-    ParallelConfig,
-    ContextParallelConfig
-)
+


 def run(rank, world_size):
-    dist.init_process_group(
-        backend="gloo",
-        rank=rank,
-        world_size=world_size
-    )
+    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

     torch.manual_seed(0)

@@ -27,8 +21,6 @@ def run(rank, world_size):

     q.requires_grad_(True)

-
-
     pc = ParallelConfig(
         context_parallel_config=ContextParallelConfig(
             ring_degree=2,
@@ -40,10 +32,11 @@ def run(rank, world_size):
         rank=rank,
         world_size=world_size,
         device=torch.device("cpu"),
-        mesh=dist.device_mesh.init_device_mesh("cpu",
-            (2,2),
+        mesh=dist.device_mesh.init_device_mesh(
+            "cpu",
+            (2, 2),
             mesh_dim_names=["ring", "ulysses"],
-            )
+        ),
     )

     def dummy_forward_op(
@@ -105,9 +98,11 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs):
             grad_v,
         )

-
     out = TemplatedUnifiedAttention(
-        q, k, v, None,
+        q,
+        k,
+        v,
+        None,
         dropout_p=0.0,
         is_causal=False,
         scale=None,
@@ -125,6 +120,7 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs):

     dist.destroy_process_group()

+
 if __name__ == "__main__":
     world_size = 4
     os.environ["MASTER_ADDR"] = "localhost"
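
For context, a hedged sketch of how a test like this is presumably launched: four CPU processes over the gloo backend, arranged as a (2, 2) device mesh with "ring" and "ulysses" dimensions. The mp.spawn entry point mirrors the visible __main__ hunk; MASTER_PORT and the body of run() are assumptions here, the real values live in the test file.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    mesh = dist.device_mesh.init_device_mesh("cpu", (2, 2), mesh_dim_names=["ring", "ulysses"])
    # ... build ParallelConfig / ContextParallelConfig around `mesh` and call
    # TemplatedUnifiedAttention with the dummy forward/backward ops, as in the test ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"  # assumed; the visible hunk only shows MASTER_ADDR
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
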
