
Commit aec0804

initial scheme of unified-sp
1 parent d5da453 commit aec0804

File tree

1 file changed: +58 -0 lines changed


src/diffusers/models/attention_dispatch.py

Lines changed: 58 additions & 0 deletions
@@ -1021,6 +1021,14 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
    x = _wait_tensor(x)
    return x

+def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+    pass
+
+
+class SeqAllToAllDouble(torch.autograd.Function):
+    pass
+
+

class TemplatedRingAttention(torch.autograd.Function):
    @staticmethod
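
Both additions in this hunk are still stubs (pass bodies). For orientation, here is a minimal sketch of the DeepSpeed-Ulysses-style exchange they point toward: scatter one dimension (heads, index 2 by default) across the process group while gathering another (sequence, index 1), with the autograd wrapper's backward simply swapping the two roles. This is an assumption, not the committed implementation; the _sketch names below are hypothetical, and the wrapper's argument order follows the SeqAllToAllDouble.apply(group, tensor, scatter_idx, gather_idx) call sites in the second hunk.

import torch
import torch.distributed as dist


def _all_to_all_double_sketch(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    # Hypothetical sketch: re-partition a [batch, seq/P, heads, dim] shard into
    # [batch, seq, heads/P, dim] (or back) by scattering `scatter_idx` and
    # gathering `gather_idx` across the P ranks of `group`. Assumes the size of
    # `scatter_idx` is divisible by the group size.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return x
    # Line the per-rank chunks up on a leading dimension so a single
    # all_to_all_single exchanges exactly one chunk with every rank.
    inp = torch.stack(x.chunk(world_size, dim=scatter_idx))
    out = torch.empty_like(inp)
    dist.all_to_all_single(out, inp, group=group)
    # Each received chunk carries this rank's slice of `scatter_idx` for a
    # different slice of `gather_idx`; concatenate along the gather dimension.
    return torch.cat(out.unbind(0), dim=gather_idx)


class SeqAllToAllDoubleSketch(torch.autograd.Function):
    # Hypothetical autograd wrapper matching the apply(group, tensor,
    # scatter_idx, gather_idx) call order used later in this commit. The
    # backward pass is the same exchange with the scatter/gather roles swapped.
    @staticmethod
    def forward(ctx, group, x, scatter_idx: int, gather_idx: int):
        ctx.group, ctx.scatter_idx, ctx.gather_idx = group, scatter_idx, gather_idx
        return _all_to_all_double_sketch(x, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad = _all_to_all_double_sketch(grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.group)
        return None, grad, None, None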
@@ -1240,6 +1248,56 @@ def backward(

        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None

+class TemplatedUnifiedAttention(torch.nn.Module):
+    @staticmethod
+    def forward(ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
+        dropout_p: float,
+        is_causal: bool,
+        scale: Optional[float],
+        enable_gqa: bool,
+        return_lse: bool,
+        forward_op,
+        backward_op,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ):
+        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+        ulysses_group = ulysses_mesh.get_group()
+        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+        ring_group = ring_mesh.get_group()
+        scatter_idx = 2
+        gather_idx = 1
+
+        query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx)
+        key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx)
+        value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx)
+        out = TemplatedRingAttention.apply(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+        if return_lse:
+            context_layer, lse, *_ = out
+        else:
+            context_layer = out
+        output = SeqAllToAllDouble.apply(
+            ulysses_group,
+            context_layer,
+            gather_idx,
+            scatter_idx,
+        )

def _templated_context_parallel_attention(
    query: torch.Tensor,
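
Taken together, the new forward sketches the unified-sp scheme named in the commit message: a Ulysses-style SeqAllToAllDouble trades each rank's sequence shard for a head shard within the Ulysses group, TemplatedRingAttention handles the sequence split that remains across the ring group, and a second SeqAllToAllDouble with the indices swapped restores the original layout. As committed, the method stops after that final exchange without returning anything, ring_group and lse are assigned but unused, and the class derives from torch.nn.Module even though its forward takes a FunctionCtx in the autograd.Function style used elsewhere in the file. Purely as an assumption about where this is headed, a plausible tail for the method might be:

        # Hypothetical continuation of TemplatedUnifiedAttention.forward above;
        # the committed stub ends without a return statement. How the
        # head-sharded LSE should be re-gathered across the Ulysses group is
        # left open here.
        if return_lse:
            return output, lse
        return output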
