Skip to content

Commit fd4c32b

addressing comments
1 parent 3ac17e9 commit fd4c32b

File tree

2 files changed, +45 -23 lines changed


src/diffusers/models/_modeling_parallel.py

Lines changed: 0 additions & 4 deletions
@@ -90,10 +90,6 @@ def __post_init__(self):
             )
         if self.ring_degree < 1 or self.ulysses_degree < 1:
             raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-        # if self.ring_degree > 1 and self.ulysses_degree > 1:
-        #     raise ValueError(
-        #         "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-        #     )
         if self.rotate_method != "allgather":
             raise NotImplementedError(
                 f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."

src/diffusers/models/attention_dispatch.py

Lines changed: 45 additions & 19 deletions
@@ -1041,43 +1041,66 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     return x


 def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+    """
+    Perform dimension sharding/reassembly across processes using `_all_to_all_single`.
+
+    This utility reshapes and redistributes the tensor `x` across the given process group, along either the
+    sequence dimension or the head dimension, depending on `scatter_idx` and `gather_idx`.
+
+    Args:
+        x (torch.Tensor):
+            Input tensor. Expected shapes:
+            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
+            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
+        scatter_idx (int):
+            Dimension along which the tensor is partitioned before all-to-all.
+        gather_idx (int):
+            Dimension along which the output is reassembled after all-to-all.
+        group:
+            Distributed process group for the Ulysses group.
+
+    Returns:
+        torch.Tensor: Tensor with globally exchanged dimensions.
+            - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
+            - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
+    """
     group_world_size = torch.distributed.get_world_size(group)

     if scatter_idx == 2 and gather_idx == 1:
-        B, S_LOCAL, H, D = x.shape
-        S = S_LOCAL * group_world_size
-        H_LOCAL = H // group_world_size
+        # Used before Ulysses sequence parallel (SP) attention: gathers the sequence
+        # dimension and scatters the head dimension.
+        batch_size, seq_len_local, num_heads, head_dim = x.shape
+        seq_len = seq_len_local * group_world_size
+        num_heads_local = num_heads // group_world_size

         # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
-        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+        x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous()


         if group_world_size >1:
-            #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
             out = _all_to_all_single(x_temp, group=group)
-            #out = _wait_tensor(out)
         else:
             out = x_temp
         # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
-        out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
-        out = out.reshape(B, S, H_LOCAL, D)
+        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
         return out
     elif scatter_idx == 1 and gather_idx == 2:
-        B, S, H_LOCAL, D = x.shape
-        H = H_LOCAL * group_world_size
-        S_LOCAL = S // group_world_size
+        # Used after Ulysses sequence parallel attention in unified SP: gathers the head
+        # dimension and scatters back the sequence dimension.
+        batch_size, seq_len, num_heads_local, head_dim = x.shape
+        num_heads = num_heads_local * group_world_size
+        seq_len_local = seq_len // group_world_size

         #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)
+        x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)

         if group_world_size >1:
-            #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
             output = _all_to_all_single(x_temp, group)
-            #output = _wait_tensor(output)
         else:
             output = x_temp
-        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
-        output = output.reshape(B, S_LOCAL, H, D)
+        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
+        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
         return output
     else:
         raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
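
To make the new docstring's shape contract concrete, here is a shape-only sketch of the scatter_idx=2, gather_idx=1 path on a single process. The sizes are hypothetical and the collective exchange is skipped, as in the function's group_world_size == 1 branch; only the reshape/transpose steps mirror the code above.

import torch

# Hypothetical sizes for illustration only.
group_world_size = 4
batch_size, seq_len_local, num_heads, head_dim = 2, 8, 16, 64
num_heads_local = num_heads // group_world_size
seq_len = seq_len_local * group_world_size

x = torch.randn(batch_size, seq_len_local, num_heads, head_dim)

# Split the heads into group_world_size chunks and move the chunk axis to the
# front, matching the reshape/transpose in _all_to_all_dim_exchange.
x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous()
assert x_temp.shape == (group_world_size, seq_len_local, batch_size, num_heads_local, head_dim)

# In the real function, _all_to_all_single exchanges the leading chunks across
# ranks at this point; afterwards that axis indexes sequence chunks gathered from
# every rank, so flattening it with seq_len_local yields the full sequence length.
out = x_temp.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
assert out.shape == (batch_size, seq_len, num_heads_local, head_dim)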
@@ -1334,14 +1357,17 @@ def TemplatedUnifiedAttention(
     forward_op,
     backward_op,
     _parallel_config: Optional["ParallelConfig"] = None,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
 ):
+    """
+    Unified Sequence Parallelism attention combining Ulysses and ring attention.
+    See: https://arxiv.org/abs/2405.07719
+    """
     ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
     ulysses_group = ulysses_mesh.get_group()
     ring_mesh = _parallel_config.context_parallel_config._ring_mesh
     ring_group = ring_mesh.get_group()
-    #hardcoded for now
-    scatter_idx = 2
-    gather_idx = 1

     query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
     key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
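
For context, the new scatter_idx/gather_idx parameters default to the pre-attention exchange (gather sequence, scatter heads); the scatter_idx=1, gather_idx=2 branch of _all_to_all_dim_exchange performs the inverse afterwards. Below is a hedged sketch of the overall unified SP flow described in the paper linked in the docstring; the ring_attention call and the function name are hypothetical placeholders, and the closing inverse exchange is implied by the second branch rather than shown in this hunk. Only the leading SeqAllToAllDim calls mirror the diff.

# Hedged sketch of the unified Ulysses + ring pattern (https://arxiv.org/abs/2405.07719).
# `ring_attention` is a hypothetical placeholder for the inner ring-attention step.
def unified_sp_attention_sketch(query, key, value, ulysses_group, ring_group,
                                scatter_idx: int = 2, gather_idx: int = 1):
    # Ulysses step: after the exchange, each rank holds the full sequence for a
    # subset of heads (scatter_idx=2 splits heads, gather_idx=1 gathers sequence).
    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)

    # Ring step runs on the head-sharded, sequence-complete tensors.
    out = ring_attention(query, key, value, group=ring_group)  # hypothetical

    # Inverse exchange: gather heads back and re-shard the sequence dimension
    # (the scatter_idx=1, gather_idx=2 branch of _all_to_all_dim_exchange).
    return SeqAllToAllDim.apply(ulysses_group, out, gather_idx, scatter_idx)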
