Commit eebe119

unified attention prototype done
1 parent: 3743558

File tree: 1 file changed (+19 lines, -8 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 19 additions & 8 deletions
@@ -1040,6 +1040,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
             out = _wait_tensor(out)
         else:
             out = x_temp
+        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
         out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
         out = out.reshape(B, S, H_LOCAL, D)
         return out
@@ -1053,6 +1054,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
                   .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
 
         if group_world_size > 1:
+            # maybe the _all_to_all_single helper should be used here to avoid contiguity issues
             output = funcol.all_to_all_single(x_temp, None, None, group)
             output = _wait_tensor(output)
         else:
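For orientation, here is a single-process sketch of the layout stitching that the gather-sequence path performs. The starting layout [group_world_size, S_LOCAL, B, H_LOCAL, D] is taken from the comment added in the first hunk; the concrete shape values and the check itself are illustrative and not part of the commit.

import torch

# Hedged sketch, not part of the commit: the chunks delivered by the all-to-all are
# assumed to arrive as [group_world_size, S_LOCAL, B, H_LOCAL, D]; the trailing
# reshape/permute stitches them into [B, S, H_LOCAL, D], with the gathered sequence
# ordered as (source rank, local position).
group_world_size, S_LOCAL, B, H_LOCAL, D = 4, 8, 2, 3, 16
S = group_world_size * S_LOCAL

out = torch.randn(group_world_size, S_LOCAL, B, H_LOCAL, D)
merged = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous().reshape(B, S, H_LOCAL, D)

# The first S_LOCAL gathered positions are exactly the block contributed by rank 0.
assert torch.equal(merged[:, :S_LOCAL], out[0].permute(1, 0, 2, 3))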
@@ -1066,12 +1068,15 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
 
 class SeqAllToAllDim(torch.autograd.Function):
     @staticmethod
-    def forward():
-        pass
+    def forward(ctx, group, input, scatter_id=2, gather_id=1):
+        ctx.group = group
+        ctx.scatter_id = scatter_id
+        ctx.gather_id = gather_id
+        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
 
     @staticmethod
-    def backward():
-        pass
+    def backward(ctx, *grad_outputs):
+        return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None)
 
 
 
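The autograd wiring follows the usual pattern for collective ops: the backward of an all-to-all that scatters dim scatter_id and gathers dim gather_id is the same exchange with the two roles swapped, and None is returned for the non-tensor inputs (group, scatter_id, gather_id). Below is a minimal usage sketch on a single-rank gloo group, where the exchange reduces to reshapes; it assumes SeqAllToAllDim from this module is in scope and is not part of the commit.

import os
import torch
import torch.distributed as dist

# Hedged sketch: a world-size-1 group means no real communication happens, so this only
# exercises the autograd plumbing and the output/gradient shapes.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

B, S, H, D = 2, 8, 4, 16
x = torch.randn(B, S, H, D, requires_grad=True)

# Scatter heads (dim 2), gather sequence (dim 1); backward runs the swapped exchange.
y = SeqAllToAllDim.apply(dist.group.WORLD, x, 2, 1)
y.sum().backward()
assert x.grad is not None and x.grad.shape == x.shape

dist.destroy_process_group()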
@@ -1313,12 +1318,13 @@ def forward(ctx: torch.autograd.function.FunctionCtx,
         ulysses_group = ulysses_mesh.get_group()
         ring_mesh = _parallel_config.context_parallel_config._ring_mesh
         ring_group = ring_mesh.get_group()
+        # hardcoded for now
         scatter_idx = 2
         gather_idx = 1
 
-        query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx)
-        key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx)
-        value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx)
+        query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+        key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+        value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
         out = TemplatedRingAttention.apply(
             query,
             key,
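Read as a shape walk-through, the point of the three exchanges is to trade head shards for sequence shards before the ring attention runs. The layouts below are assumptions based on scatter_idx=2 / gather_idx=1 over a [B, S, H, D] tensor; the diff itself does not annotate them.

# Hedged shape sketch, assuming a [B, S, H, D] layout and a context-parallel world split
# into ring and ulysses sub-groups (degree names are illustrative, not from the diff):
#
#   per-rank q/k/v before the exchange:        [B, S / (ring * ulysses), H, D]
#   after SeqAllToAllDim(scatter=2, gather=1)
#   over ulysses_group:                        [B, S / ring, H / ulysses, D]
#
# TemplatedRingAttention then circulates K/V blocks across ring_group, so each local
# head slice attends over the full sequence.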
@@ -1337,12 +1343,17 @@ def forward(ctx: torch.autograd.function.FunctionCtx,
             context_layer, lse, *_ = out
         else:
             context_layer = out
-        output = SeqAllToAllDouble.apply(
+        output = SeqAllToAllDim.apply(
             ulysses_group,
             context_layer,
             gather_idx,
             scatter_idx,
         )
+        if return_lse:
+            # not sure if this is correct
+            lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
+            return (output, lse)
+        return output
 
 def _templated_context_parallel_attention(
     query: torch.Tensor,
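The output path relies on the swapped-index exchange inverting the forward one, which returns the attention output to its original head-complete, sequence-sharded layout; whether lse can reuse the same call unchanged depends on its layout coming out of TemplatedRingAttention, which the added comment already flags. A single-rank round-trip check of that inversion property, assuming _all_to_all_dim_exchange from this module is in scope (not part of the commit):

import os
import torch
import torch.distributed as dist

# Hedged sketch: with a single rank the exchange is pure data movement, so the
# swapped-index call should restore the original tensor exactly.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
group = dist.group.WORLD

x = torch.randn(2, 8, 4, 16)                  # assumed [B, S, H, D]
y = _all_to_all_dim_exchange(x, 2, 1, group)  # scatter heads, gather sequence
z = _all_to_all_dim_exchange(y, 1, 2, group)  # swapped roles undo the exchange
assert z.shape == x.shape and torch.equal(z, x)

dist.destroy_process_group()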
