From 123eb686ee3de15853b0a89fd7b432714aca4fdb Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 09:45:44 +0100 Subject: [PATCH 01/16] initial scheme of unified-sp --- src/diffusers/models/attention_dispatch.py | 58 ++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ffad94cc7f27..236b4aa90113 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1040,6 +1040,14 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: x = _wait_tensor(x) return x +def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: + pass + + +class SeqAllToAllDouble(torch.autograd.Function): + pass + + class TemplatedRingAttention(torch.autograd.Function): @staticmethod @@ -1259,6 +1267,56 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None +class TemplatedUnifiedAttention(torch.nn.Module): + @staticmethod + def forward(ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + ulysses_group = ulysses_mesh.get_group() + ring_mesh = _parallel_config.context_parallel_config._ring_mesh + ring_group = ring_mesh.get_group() + scatter_idx = 2 + gather_idx = 1 + + query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx) + out = TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + if return_lse: + context_layer, lse, *_ = out + else: + context_layer = out + output = SeqAllToAllDouble.apply( + ulysses_group, + context_layer, + gather_idx, + scatter_idx, + ) def _templated_context_parallel_attention( query: torch.Tensor, From 2a16bcc56b407f5b8a20f92c79cdd7752b98efb4 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 10:57:55 +0100 Subject: [PATCH 02/16] initial all_to_all_double --- src/diffusers/models/attention_dispatch.py | 47 +++++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 236b4aa90113..aaba91692273 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1041,11 +1041,54 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: - pass + group_world_size = funcol.get_world_size(group) + #dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + B, S_LOCAL, H, D = x.shape + S = S_LOCAL * group_world_size + H_LOCAL = H // group_world_size + + x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D) + .permute(0, 2, 1, 3, 4).contiguous() + ) + + out = torch.empty_like(x_temp) + if group_world_size >1: + 
funcol.all_to_all_single(out, x_temp, None, None, group) + else: + out = x_temp + out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous() + out = out.reshape(B, S, H_LOCAL, D) + return out + elif scatter_idx == 1 and gather_idx == 2: + B, S, H_LOCAL, D = x.shape + H = H_LOCAL * group_world_size + S_LOCAL = S // group_world_size + + # + x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D) + .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) + output = torch.empty_like(x_temp) + if group_world_size >1: + funcol.all_to_all_single(output, x_temp, None, None, group) + else: + output = x_temp + output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous() + output = output.reshape(B, S_LOCAL, H, D) + return output + else: + raise ValueError("Invalid scatter/gather indices for all_to_all_double.") class SeqAllToAllDouble(torch.autograd.Function): - pass + @staticmethod + def forward(): + pass + + @staticmethod + def backward(): + pass From ec17a1af86682d96ae7ce7b8b4dbfe5e90357c6b Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 12:15:01 +0100 Subject: [PATCH 03/16] bug fixes, added cmnts --- src/diffusers/models/attention_dispatch.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index aaba91692273..dcdee466e078 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1040,22 +1040,23 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: x = _wait_tensor(x) return x -def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: +def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: group_world_size = funcol.get_world_size(group) - #dist.get_world_size(group) if scatter_idx == 2 and gather_idx == 1: B, S_LOCAL, H, D = x.shape S = S_LOCAL * group_world_size H_LOCAL = H // group_world_size + # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D) - .permute(0, 2, 1, 3, 4).contiguous() + .transpose(0, 2).contiguous() ) - out = torch.empty_like(x_temp) if group_world_size >1: - funcol.all_to_all_single(out, x_temp, None, None, group) + #maybe here need to use the _all_to_all_single helper to avoid contiguity issues + out = funcol.all_to_all_single(x_temp, None, None, group=group) + out = _wait_tensor(out) else: out = x_temp out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous() @@ -1069,19 +1070,20 @@ def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = # x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D) .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) - output = torch.empty_like(x_temp) + if group_world_size >1: - funcol.all_to_all_single(output, x_temp, None, None, group) + output = funcol.all_to_all_single(x_temp, None, None, group) + output = _wait_tensor(output) else: output = x_temp output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous() output = output.reshape(B, S_LOCAL, H, D) return output else: - raise ValueError("Invalid scatter/gather indices for all_to_all_double.") + raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.") -class SeqAllToAllDouble(torch.autograd.Function): +class SeqAllToAllDim(torch.autograd.Function): 
@staticmethod def forward(): pass From bbc65b80d01a4cb3c1a4f379b2148ee82b963a46 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 14:36:42 +0100 Subject: [PATCH 04/16] unified attention prototype done --- src/diffusers/models/attention_dispatch.py | 27 +++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index dcdee466e078..a164df4c4a2c 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1059,6 +1059,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: out = _wait_tensor(out) else: out = x_temp + # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous() out = out.reshape(B, S, H_LOCAL, D) return out @@ -1072,6 +1073,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) if group_world_size >1: + #maybe here need to use the _all_to_all_single helper to avoid contiguity issues output = funcol.all_to_all_single(x_temp, None, None, group) output = _wait_tensor(output) else: @@ -1085,12 +1087,15 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: class SeqAllToAllDim(torch.autograd.Function): @staticmethod - def forward(): - pass + def forward(ctx, group, input, scatter_id=2, gather_id=1): + ctx.group = group + ctx.scatter_id = scatter_id + ctx.gather_id = gather_id + return _all_to_all_dim_exchange(input, scatter_id, gather_id, group) @staticmethod - def backward(): - pass + def backward(ctx, *grad_outputs): + return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None) @@ -1332,12 +1337,13 @@ def forward(ctx: torch.autograd.function.FunctionCtx, ulysses_group = ulysses_mesh.get_group() ring_mesh = _parallel_config.context_parallel_config._ring_mesh ring_group = ring_mesh.get_group() + #hardcoded for now scatter_idx = 2 gather_idx = 1 - query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx) - key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx) - value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx) + query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) out = TemplatedRingAttention.apply( query, key, @@ -1356,12 +1362,17 @@ def forward(ctx: torch.autograd.function.FunctionCtx, context_layer, lse, *_ = out else: context_layer = out - output = SeqAllToAllDouble.apply( + output = SeqAllToAllDim.apply( ulysses_group, context_layer, gather_idx, scatter_idx, ) + if return_lse: + # not sure if this is correct + lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) + return (output, lse) + return output def _templated_context_parallel_attention( query: torch.Tensor, From 32940ea9f1516a3018fd58fefd63ae04d69dbe59 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 15:46:42 +0100 Subject: [PATCH 05/16] remove raising value error in contextParallelConfig to enable unified attention --- src/diffusers/models/_modeling_parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/_modeling_parallel.py 
b/src/diffusers/models/_modeling_parallel.py index 2a4eb520c796..8d9e4193616c 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -90,10 +90,10 @@ def __post_init__(self): ) if self.ring_degree < 1 or self.ulysses_degree < 1: raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - if self.ring_degree > 1 and self.ulysses_degree > 1: - raise ValueError( - "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - ) + # if self.ring_degree > 1 and self.ulysses_degree > 1: + # raise ValueError( + # "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." + # ) if self.rotate_method != "allgather": raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." From 326c3dbbb63641ee358d7ed860a29923331efc39 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Fri, 21 Nov 2025 09:15:50 +0100 Subject: [PATCH 06/16] bug fix --- src/diffusers/models/attention_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index a164df4c4a2c..9817c2426195 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1041,7 +1041,7 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: - group_world_size = funcol.get_world_size(group) + group_world_size = torch.distributed.get_world_size(group) if scatter_idx == 2 and gather_idx == 1: B, S_LOCAL, H, D = x.shape From 93162212decd61fdfef84bf806bda4bf4abccdd3 Mon Sep 17 00:00:00 2001 From: KarthikSundar2002 Date: Fri, 21 Nov 2025 11:52:22 +0000 Subject: [PATCH 07/16] feat: Adds Test for Unified SP Attention and Fixes a bug in Template Ring Attention --- src/diffusers/models/attention_dispatch.py | 2 +- tests/others/test_unified_sp_attention.py | 131 +++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 tests/others/test_unified_sp_attention.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 9817c2426195..700ffa61359e 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1220,7 +1220,7 @@ def backward( grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None class TemplatedUlyssesAttention(torch.autograd.Function): diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py new file mode 100644 index 000000000000..00c4403bf3d2 --- /dev/null +++ b/tests/others/test_unified_sp_attention.py @@ -0,0 +1,131 @@ +import math +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from diffusers.models.attention_dispatch import TemplatedUnifiedAttention +import os + +def run(rank, world_size): + dist.init_process_group( + backend="gloo", + rank=rank, + world_size=world_size + ) + + torch.manual_seed(0) + + B, S, H, D = 2, 8, 4, 16 # small toy + q = torch.randn(B, S, H, D) + k = 
torch.randn(B, S, H, D) + v = torch.randn(B, S, H, D) + + q.requires_grad_(True) + + from diffusers.models._modeling_parallel import ( + ParallelConfig, + ContextParallelConfig + ) + + pc = ParallelConfig( + context_parallel_config=ContextParallelConfig( + ring_degree=2, + ulysses_degree=2, + ) + ) + + pc.context_parallel_config.setup( + rank=rank, + world_size=world_size, + device=torch.device("cpu"), + mesh=dist.device_mesh.init_device_mesh("cpu", + (2,2), + mesh_dim_names=["ring", "ulysses"], + ) + ) + + def dummy_forward_op( + ctx, + q, + k, + v, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + *, + _save_ctx=True, + _parallel_config=None, + ): + head_scale = math.sqrt(D) + attn = (q @ k.transpose(-1, -2)) / head_scale + out = attn @ v + lse = torch.logsumexp(attn, dim=-1) + + if _save_ctx: + ctx.save_for_backward(q, k, v) + ctx._cached_qkv = [] + ctx._cached_iter = 0 + + if not hasattr(ctx, "_cached_qkv"): + ctx._cached_qkv = [] + + ctx._cached_qkv.append((q.detach(), k.detach(), v.detach())) + + return (out, lse) if return_lse else out + + def dummy_backward_op(ctx, grad_out, *args, **kwargs): + if not hasattr(ctx, "_cached_qkv"): + raise RuntimeError("No cached tensors for backward.") + + if not hasattr(ctx, "_cached_iter"): + ctx._cached_iter = 0 + + if ctx._cached_iter >= len(ctx._cached_qkv): + raise RuntimeError("Backward called more times than cached forwards.") + + q, k, v = ctx._cached_qkv[ctx._cached_iter] + ctx._cached_iter += 1 + + head_scale = math.sqrt(D) + attn = (q @ k.transpose(-1, -2)) / head_scale + + grad_v = attn.transpose(-1, -2) @ grad_out + grad_attn = grad_out @ v.transpose(-1, -2) + grad_q = (grad_attn @ k) / head_scale + grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale + + return ( + grad_q, + grad_k, + grad_v, + ) + + attn = TemplatedUnifiedAttention() + + out = attn( + None, + q, k, v, None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + return_lse=False, + forward_op=dummy_forward_op, + backward_op=dummy_backward_op, + _parallel_config=pc, + ) + + print(f"[RANK {rank}] output:", out.shape) + + out.sum().backward() + print(f"[RANK {rank}] grad:", q.grad.shape) + + dist.destroy_process_group() + +if __name__ == "__main__": + world_size = 4 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + mp.spawn(run, args=(world_size,), nprocs=world_size) \ No newline at end of file From 3ac17e97f77440e72601ed6de7d4bb5eb52e7573 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Sun, 23 Nov 2025 12:34:44 +0100 Subject: [PATCH 08/16] bug fix, lse calculation, testing bug fixes, lse calculation - switched to _all_to_all_single helper in _all_to_all_dim_exchange due contiguity issues bug fix bug fix bug fix --- src/diffusers/models/attention_dispatch.py | 155 ++++++++++++--------- tests/others/test_unified_sp_attention.py | 4 +- 2 files changed, 89 insertions(+), 70 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 700ffa61359e..a7d1535c726e 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1049,14 +1049,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: H_LOCAL = H // group_world_size # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D - x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D) - .transpose(0, 2).contiguous() - ) + x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, 
D).transpose(0, 2).contiguous() + if group_world_size >1: #maybe here need to use the _all_to_all_single helper to avoid contiguity issues - out = funcol.all_to_all_single(x_temp, None, None, group=group) - out = _wait_tensor(out) + out = _all_to_all_single(x_temp, group=group) + #out = _wait_tensor(out) else: out = x_temp # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D @@ -1068,14 +1067,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: H = H_LOCAL * group_world_size S_LOCAL = S // group_world_size - # - x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D) - .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) + #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D + x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D) if group_world_size >1: #maybe here need to use the _all_to_all_single helper to avoid contiguity issues - output = funcol.all_to_all_single(x_temp, None, None, group) - output = _wait_tensor(output) + output = _all_to_all_single(x_temp, group) + #output = _wait_tensor(output) else: output = x_temp output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous() @@ -1094,8 +1092,14 @@ def forward(ctx, group, input, scatter_id=2, gather_id=1): return _all_to_all_dim_exchange(input, scatter_id, gather_id, group) @staticmethod - def backward(ctx, *grad_outputs): - return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None) + def backward(ctx, grad_outputs): + grad_input = SeqAllToAllDim.apply( + ctx.group, + grad_outputs, + ctx.gather_id, # reversed + ctx.scatter_id, # reversed + ) + return (None, grad_input, None, None) @@ -1317,62 +1321,64 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None -class TemplatedUnifiedAttention(torch.nn.Module): - @staticmethod - def forward(ctx: torch.autograd.function.FunctionCtx, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor], - dropout_p: float, - is_causal: bool, - scale: Optional[float], - enable_gqa: bool, - return_lse: bool, +def TemplatedUnifiedAttention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + ulysses_group = ulysses_mesh.get_group() + ring_mesh = _parallel_config.context_parallel_config._ring_mesh + ring_group = ring_mesh.get_group() + #hardcoded for now + scatter_idx = 2 + gather_idx = 1 + + query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) + out = TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, forward_op, backward_op, - _parallel_config: Optional["ParallelConfig"] = None, - ): - ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh - ulysses_group = ulysses_mesh.get_group() - ring_mesh = _parallel_config.context_parallel_config._ring_mesh - ring_group = ring_mesh.get_group() - #hardcoded for now - 
scatter_idx = 2 - gather_idx = 1 - - query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) - key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) - value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) - out = TemplatedRingAttention.apply( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - ) - if return_lse: - context_layer, lse, *_ = out - else: - context_layer = out - output = SeqAllToAllDim.apply( - ulysses_group, - context_layer, - gather_idx, - scatter_idx, - ) - if return_lse: - # not sure if this is correct - lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) - return (output, lse) - return output + _parallel_config, + ) + if return_lse: + context_layer, lse, *_ = out + else: + context_layer = out + # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D) + output = SeqAllToAllDim.apply( + ulysses_group, + context_layer, + gather_idx, + scatter_idx, + ) + if return_lse: + # not sure if this is correct: Assuming (based on forward ops in ringAttention) + # the lse is of shape (B, S, H_LOCAL) + lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) + lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1) + lse = lse.squeeze(-1) + return (output, lse) + return output def _templated_context_parallel_attention( query: torch.Tensor, @@ -1397,7 +1403,22 @@ def _templated_context_parallel_attention( raise ValueError("GQA is not yet supported for templated attention.") # TODO: add support for unified attention with ring/ulysses degree both being > 1 - if _parallel_config.context_parallel_config.ring_degree > 1: + if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1: + return TemplatedUnifiedAttention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + elif _parallel_config.context_parallel_config.ring_degree > 1: return TemplatedRingAttention.apply( query, key, diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py index 00c4403bf3d2..4c0621999bd0 100644 --- a/tests/others/test_unified_sp_attention.py +++ b/tests/others/test_unified_sp_attention.py @@ -102,10 +102,8 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs): grad_v, ) - attn = TemplatedUnifiedAttention() - out = attn( - None, + out = TemplatedUnifiedAttention( q, k, v, None, dropout_p=0.0, is_causal=False, From fd4c32b1762ad8f5243709e66a262506210086f7 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Mon, 8 Dec 2025 11:20:45 +0100 Subject: [PATCH 09/16] addressing comments --- src/diffusers/models/_modeling_parallel.py | 4 -- src/diffusers/models/attention_dispatch.py | 64 +++++++++++++++------- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 8d9e4193616c..1c7703a13c52 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -90,10 +90,6 @@ def __post_init__(self): ) if self.ring_degree < 1 or self.ulysses_degree < 1: raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - # if self.ring_degree > 1 and self.ulysses_degree > 1: - # raise ValueError( - # "Unified Ulysses-Ring 
attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - # ) if self.rotate_method != "allgather": raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index a7d1535c726e..aaa45c757329 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1041,43 +1041,66 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: + """ + Perform dimension sharding / reassembly across processes using _all_to_all_single. + + This utility reshapes and redistributes tensor `x` across the given process group, + across sequence dimension or head dimension flexibly by accepting scatter_idx and gather_idx. + + Args: + x (torch.Tensor): + Input tensor. Expected shapes: + - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim) + - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim) + scatter_idx (int) : + Dimension along which the tensor is partitioned before all-to-all. + gather_idx (int): + Dimension along which the output is reassembled after all-to-all. + group : + Distributed process group for the Ulysses group. + + Returns: + torch.Tensor: Tensor with globally exchanged dimensions. + - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim) + - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim) + """ group_world_size = torch.distributed.get_world_size(group) if scatter_idx == 2 and gather_idx == 1: - B, S_LOCAL, H, D = x.shape - S = S_LOCAL * group_world_size - H_LOCAL = H // group_world_size + #Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence + #dimension and scatters head dimension + batch_size, seq_len_local, num_heads, head_dim = x.shape + seq_len = seq_len_local * group_world_size + num_heads_local = num_heads // group_world_size # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D - x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous() + x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous() if group_world_size >1: - #maybe here need to use the _all_to_all_single helper to avoid contiguity issues out = _all_to_all_single(x_temp, group=group) - #out = _wait_tensor(out) else: out = x_temp # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D - out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous() - out = out.reshape(B, S, H_LOCAL, D) + out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous() + out = out.reshape(batch_size, seq_len, num_heads_local, head_dim) return out elif scatter_idx == 1 and gather_idx == 2: - B, S, H_LOCAL, D = x.shape - H = H_LOCAL * group_world_size - S_LOCAL = S // group_world_size + #Used after ulysses sequence parallel in unified SP. gathers the head dimension + #scatters back the sequence dimension. 
+ batch_size, seq_len, num_heads_local, head_dim = x.shape + num_heads = num_heads_local * group_world_size + seq_len_local = seq_len // group_world_size #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D - x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D) + x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) if group_world_size >1: - #maybe here need to use the _all_to_all_single helper to avoid contiguity issues output = _all_to_all_single(x_temp, group) - #output = _wait_tensor(output) else: output = x_temp - output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous() - output = output.reshape(B, S_LOCAL, H, D) + output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous() + output = output.reshape(batch_size, seq_len_local, num_heads, head_dim) return output else: raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.") @@ -1334,14 +1357,17 @@ def TemplatedUnifiedAttention( forward_op, backward_op, _parallel_config: Optional["ParallelConfig"] = None, + scatter_idx: int =2, + gather_idx: int =1, ): + """ + Unified Sequence Parallelism attention combining Ulysses and ring attention. + See: https://arxiv.org/abs/2405.07719 + """ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh ulysses_group = ulysses_mesh.get_group() ring_mesh = _parallel_config.context_parallel_config._ring_mesh ring_group = ring_mesh.get_group() - #hardcoded for now - scatter_idx = 2 - gather_idx = 1 query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) From 628f72d7846b93ee59ff69b1be8595d952d08e7b Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Tue, 9 Dec 2025 11:24:38 +0100 Subject: [PATCH 10/16] sequence parallelsim bug fixes --- src/diffusers/models/attention_dispatch.py | 26 ++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index aaa45c757329..464edd1169c3 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -44,6 +44,7 @@ is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from ..utils import is_torch_version if TYPE_CHECKING: @@ -1107,6 +1108,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: class SeqAllToAllDim(torch.autograd.Function): + """ + all_to_all operation for unified sequence parallelism. + uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange for more info. 
+ """ @staticmethod def forward(ctx, group, input, scatter_id=2, gather_id=1): ctx.group = group @@ -1186,7 +1191,10 @@ def forward( out = out.to(torch.float32) lse = lse.to(torch.float32) - lse = lse.unsqueeze(-1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) if prev_out is not None: out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) @@ -1342,7 +1350,7 @@ def backward( x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) ) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None def TemplatedUnifiedAttention( query: torch.Tensor, @@ -1366,8 +1374,6 @@ def TemplatedUnifiedAttention( """ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh ulysses_group = ulysses_mesh.get_group() - ring_mesh = _parallel_config.context_parallel_config._ring_mesh - ring_group = ring_mesh.get_group() query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) @@ -1390,7 +1396,7 @@ def TemplatedUnifiedAttention( context_layer, lse, *_ = out else: context_layer = out - # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D) + #context_layer is of shape (B, S, H_LOCAL, D) output = SeqAllToAllDim.apply( ulysses_group, context_layer, @@ -1398,10 +1404,12 @@ def TemplatedUnifiedAttention( scatter_idx, ) if return_lse: - # not sure if this is correct: Assuming (based on forward ops in ringAttention) - # the lse is of shape (B, S, H_LOCAL) - lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) - lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1) + #lse is of shape (B, S, H_LOCAL, 1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) + lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) lse = lse.squeeze(-1) return (output, lse) return output From 3a08cf4c76e4152ee2216d7bd0949209bed06cd6 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Tue, 9 Dec 2025 11:55:19 +0100 Subject: [PATCH 11/16] code format fixes --- src/diffusers/models/attention_dispatch.py | 5 ++--- tests/others/test_unified_sp_attention.py | 17 ++++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 464edd1169c3..31bd0638573c 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -44,7 +44,6 @@ is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS -from ..utils import is_torch_version if TYPE_CHECKING: @@ -1076,7 +1075,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous() - + if group_world_size >1: out = _all_to_all_single(x_temp, group=group) @@ -1365,7 +1364,7 @@ def TemplatedUnifiedAttention( forward_op, backward_op, _parallel_config: 
Optional["ParallelConfig"] = None, - scatter_idx: int =2, + scatter_idx: int =2, gather_idx: int =1, ): """ diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py index 4c0621999bd0..37caad46e8c0 100644 --- a/tests/others/test_unified_sp_attention.py +++ b/tests/others/test_unified_sp_attention.py @@ -1,13 +1,19 @@ import math +import os + import torch import torch.distributed as dist import torch.multiprocessing as mp + from diffusers.models.attention_dispatch import TemplatedUnifiedAttention -import os +from diffusers.models._modeling_parallel import ( + ParallelConfig, + ContextParallelConfig + ) def run(rank, world_size): dist.init_process_group( - backend="gloo", + backend="gloo", rank=rank, world_size=world_size ) @@ -21,10 +27,7 @@ def run(rank, world_size): q.requires_grad_(True) - from diffusers.models._modeling_parallel import ( - ParallelConfig, - ContextParallelConfig - ) + pc = ParallelConfig( context_parallel_config=ContextParallelConfig( @@ -126,4 +129,4 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs): world_size = 4 os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" - mp.spawn(run, args=(world_size,), nprocs=world_size) \ No newline at end of file + mp.spawn(run, args=(world_size,), nprocs=world_size) From 5579cf163f6980ea55aa3e5517e3b6c78c7db46e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 11 Dec 2025 08:22:38 +0000 Subject: [PATCH 12/16] Apply style fixes --- src/diffusers/models/attention_dispatch.py | 62 +++++++++++++--------- tests/others/test_unified_sp_attention.py | 28 +++++----- 2 files changed, 49 insertions(+), 41 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 31bd0638573c..a0d1e9a520eb 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1040,12 +1040,13 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: x = _wait_tensor(x) return x + def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: """ Perform dimension sharding / reassembly across processes using _all_to_all_single. - This utility reshapes and redistributes tensor `x` across the given process group, - across sequence dimension or head dimension flexibly by accepting scatter_idx and gather_idx. + This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or + head dimension flexibly by accepting scatter_idx and gather_idx. Args: x (torch.Tensor): @@ -1067,17 +1068,20 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: group_world_size = torch.distributed.get_world_size(group) if scatter_idx == 2 and gather_idx == 1: - #Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence - #dimension and scatters head dimension + # Used before Ulysses sequence parallel (SP) attention. 
Scatters the gathers sequence + # dimension and scatters head dimension batch_size, seq_len_local, num_heads, head_dim = x.shape seq_len = seq_len_local * group_world_size num_heads_local = num_heads // group_world_size # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D - x_temp = x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim).transpose(0, 2).contiguous() - + x_temp = ( + x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim) + .transpose(0, 2) + .contiguous() + ) - if group_world_size >1: + if group_world_size > 1: out = _all_to_all_single(x_temp, group=group) else: out = x_temp @@ -1086,16 +1090,20 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: out = out.reshape(batch_size, seq_len, num_heads_local, head_dim) return out elif scatter_idx == 1 and gather_idx == 2: - #Used after ulysses sequence parallel in unified SP. gathers the head dimension - #scatters back the sequence dimension. + # Used after ulysses sequence parallel in unified SP. gathers the head dimension + # scatters back the sequence dimension. batch_size, seq_len, num_heads_local, head_dim = x.shape num_heads = num_heads_local * group_world_size seq_len_local = seq_len // group_world_size - #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D - x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) - - if group_world_size >1: + # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D + x_temp = ( + x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim) + .permute(1, 3, 2, 0, 4) + .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) + ) + + if group_world_size > 1: output = _all_to_all_single(x_temp, group) else: output = x_temp @@ -1108,9 +1116,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: class SeqAllToAllDim(torch.autograd.Function): """ - all_to_all operation for unified sequence parallelism. - uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange for more info. + all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange + for more info. """ + @staticmethod def forward(ctx, group, input, scatter_id=2, gather_id=1): ctx.group = group @@ -1123,13 +1132,12 @@ def backward(ctx, grad_outputs): grad_input = SeqAllToAllDim.apply( ctx.group, grad_outputs, - ctx.gather_id, # reversed + ctx.gather_id, # reversed ctx.scatter_id, # reversed ) return (None, grad_input, None, None) - class TemplatedRingAttention(torch.autograd.Function): @staticmethod def forward( @@ -1351,6 +1359,7 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None + def TemplatedUnifiedAttention( query: torch.Tensor, key: torch.Tensor, @@ -1364,12 +1373,11 @@ def TemplatedUnifiedAttention( forward_op, backward_op, _parallel_config: Optional["ParallelConfig"] = None, - scatter_idx: int =2, - gather_idx: int =1, - ): + scatter_idx: int = 2, + gather_idx: int = 1, +): """ - Unified Sequence Parallelism attention combining Ulysses and ring attention. - See: https://arxiv.org/abs/2405.07719 + Unified Sequence Parallelism attention combining Ulysses and ring attention. 
See: https://arxiv.org/abs/2405.07719 """ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh ulysses_group = ulysses_mesh.get_group() @@ -1395,7 +1403,7 @@ def TemplatedUnifiedAttention( context_layer, lse, *_ = out else: context_layer = out - #context_layer is of shape (B, S, H_LOCAL, D) + # context_layer is of shape (B, S, H_LOCAL, D) output = SeqAllToAllDim.apply( ulysses_group, context_layer, @@ -1403,7 +1411,7 @@ def TemplatedUnifiedAttention( scatter_idx, ) if return_lse: - #lse is of shape (B, S, H_LOCAL, 1) + # lse is of shape (B, S, H_LOCAL, 1) # Refer to: # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 if is_torch_version("<", "2.9.0"): @@ -1413,6 +1421,7 @@ def TemplatedUnifiedAttention( return (output, lse) return output + def _templated_context_parallel_attention( query: torch.Tensor, key: torch.Tensor, @@ -1436,7 +1445,10 @@ def _templated_context_parallel_attention( raise ValueError("GQA is not yet supported for templated attention.") # TODO: add support for unified attention with ring/ulysses degree both being > 1 - if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1: + if ( + _parallel_config.context_parallel_config.ring_degree > 1 + and _parallel_config.context_parallel_config.ulysses_degree > 1 + ): return TemplatedUnifiedAttention( query, key, diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py index 37caad46e8c0..106f050ffd69 100644 --- a/tests/others/test_unified_sp_attention.py +++ b/tests/others/test_unified_sp_attention.py @@ -5,18 +5,12 @@ import torch.distributed as dist import torch.multiprocessing as mp +from diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig from diffusers.models.attention_dispatch import TemplatedUnifiedAttention -from diffusers.models._modeling_parallel import ( - ParallelConfig, - ContextParallelConfig - ) + def run(rank, world_size): - dist.init_process_group( - backend="gloo", - rank=rank, - world_size=world_size - ) + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) torch.manual_seed(0) @@ -27,8 +21,6 @@ def run(rank, world_size): q.requires_grad_(True) - - pc = ParallelConfig( context_parallel_config=ContextParallelConfig( ring_degree=2, @@ -40,10 +32,11 @@ def run(rank, world_size): rank=rank, world_size=world_size, device=torch.device("cpu"), - mesh=dist.device_mesh.init_device_mesh("cpu", - (2,2), + mesh=dist.device_mesh.init_device_mesh( + "cpu", + (2, 2), mesh_dim_names=["ring", "ulysses"], - ) + ), ) def dummy_forward_op( @@ -105,9 +98,11 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs): grad_v, ) - out = TemplatedUnifiedAttention( - q, k, v, None, + q, + k, + v, + None, dropout_p=0.0, is_causal=False, scale=None, @@ -125,6 +120,7 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs): dist.destroy_process_group() + if __name__ == "__main__": world_size = 4 os.environ["MASTER_ADDR"] = "localhost" From da748347e1b525624cabed3d701e3be47f4cd49e Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 11 Dec 2025 09:32:56 +0100 Subject: [PATCH 13/16] code formatting fix --- src/diffusers/models/attention_dispatch.py | 10 +++------- tests/others/test_unified_sp_attention.py | 2 ++ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index a0d1e9a520eb..49d9b8628f58 100644 --- 
a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1096,14 +1096,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: num_heads = num_heads_local * group_world_size seq_len_local = seq_len // group_world_size - # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D - x_temp = ( - x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim) - .permute(1, 3, 2, 0, 4) - .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) - ) + #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D + x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) - if group_world_size > 1: + if group_world_size >1: output = _all_to_all_single(x_temp, group) else: output = x_temp diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py index 106f050ffd69..ad8ea4471ac4 100644 --- a/tests/others/test_unified_sp_attention.py +++ b/tests/others/test_unified_sp_attention.py @@ -21,6 +21,8 @@ def run(rank, world_size): q.requires_grad_(True) + + pc = ParallelConfig( context_parallel_config=ContextParallelConfig( ring_degree=2, From e681fe4b2c4d0b17813f828ca42acfd52a723ee5 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 11 Dec 2025 10:22:16 +0100 Subject: [PATCH 14/16] added unified attention docs and removed test file --- .../en/training/distributed_inference.md | 19 ++- tests/others/test_unified_sp_attention.py | 130 ------------------ 2 files changed, 18 insertions(+), 131 deletions(-) delete mode 100644 tests/others/test_unified_sp_attention.py diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 534124cb93ec..859f42456ec5 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -332,4 +332,21 @@ transformer = AutoModel.from_pretrained( pipeline = DiffusionPipeline.from_pretrained( CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16, ).to(device) -``` \ No newline at end of file +``` +### Unified Attention + +[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout. + +This hybrid approach leverages the strengths of both methods: +- **Ulysses Attention** efficiently parallelizes across attention heads +- **Ring Attention** handles very long sequences with minimal memory overhead +- Together, they enable 2D parallelization across both heads and sequence dimensions + +[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping). +Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`]. 
+ +```py +pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)) +``` + +Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices). diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py deleted file mode 100644 index ad8ea4471ac4..000000000000 --- a/tests/others/test_unified_sp_attention.py +++ /dev/null @@ -1,130 +0,0 @@ -import math -import os - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig -from diffusers.models.attention_dispatch import TemplatedUnifiedAttention - - -def run(rank, world_size): - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) - - torch.manual_seed(0) - - B, S, H, D = 2, 8, 4, 16 # small toy - q = torch.randn(B, S, H, D) - k = torch.randn(B, S, H, D) - v = torch.randn(B, S, H, D) - - q.requires_grad_(True) - - - - pc = ParallelConfig( - context_parallel_config=ContextParallelConfig( - ring_degree=2, - ulysses_degree=2, - ) - ) - - pc.context_parallel_config.setup( - rank=rank, - world_size=world_size, - device=torch.device("cpu"), - mesh=dist.device_mesh.init_device_mesh( - "cpu", - (2, 2), - mesh_dim_names=["ring", "ulysses"], - ), - ) - - def dummy_forward_op( - ctx, - q, - k, - v, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - *, - _save_ctx=True, - _parallel_config=None, - ): - head_scale = math.sqrt(D) - attn = (q @ k.transpose(-1, -2)) / head_scale - out = attn @ v - lse = torch.logsumexp(attn, dim=-1) - - if _save_ctx: - ctx.save_for_backward(q, k, v) - ctx._cached_qkv = [] - ctx._cached_iter = 0 - - if not hasattr(ctx, "_cached_qkv"): - ctx._cached_qkv = [] - - ctx._cached_qkv.append((q.detach(), k.detach(), v.detach())) - - return (out, lse) if return_lse else out - - def dummy_backward_op(ctx, grad_out, *args, **kwargs): - if not hasattr(ctx, "_cached_qkv"): - raise RuntimeError("No cached tensors for backward.") - - if not hasattr(ctx, "_cached_iter"): - ctx._cached_iter = 0 - - if ctx._cached_iter >= len(ctx._cached_qkv): - raise RuntimeError("Backward called more times than cached forwards.") - - q, k, v = ctx._cached_qkv[ctx._cached_iter] - ctx._cached_iter += 1 - - head_scale = math.sqrt(D) - attn = (q @ k.transpose(-1, -2)) / head_scale - - grad_v = attn.transpose(-1, -2) @ grad_out - grad_attn = grad_out @ v.transpose(-1, -2) - grad_q = (grad_attn @ k) / head_scale - grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale - - return ( - grad_q, - grad_k, - grad_v, - ) - - out = TemplatedUnifiedAttention( - q, - k, - v, - None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, - return_lse=False, - forward_op=dummy_forward_op, - backward_op=dummy_backward_op, - _parallel_config=pc, - ) - - print(f"[RANK {rank}] output:", out.shape) - - out.sum().backward() - print(f"[RANK {rank}] grad:", q.grad.shape) - - dist.destroy_process_group() - - -if __name__ == "__main__": - world_size = 4 - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - mp.spawn(run, args=(world_size,), nprocs=world_size) From 5643bc93a90516a8d1083cd4f9f9c42b38e267cd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 13 Dec 2025 03:42:26 +0000 Subject: [PATCH 15/16] Apply style fixes --- src/diffusers/models/attention_dispatch.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git 
a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 2fb47b4aba4a..37e821a26f95 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1187,10 +1187,14 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: num_heads = num_heads_local * group_world_size seq_len_local = seq_len // group_world_size - #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D - x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) + # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D + x_temp = ( + x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim) + .permute(1, 3, 2, 0, 4) + .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim) + ) - if group_world_size >1: + if group_world_size > 1: output = _all_to_all_single(x_temp, group) else: output = x_temp From 27ec85f1690cf3b203fc4d34b50b4ff14552b6a2 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri <66717082+Bissmella@users.noreply.github.com> Date: Sat, 13 Dec 2025 21:37:10 +0100 Subject: [PATCH 16/16] tip for unified attention in docs at distributed_inference.md Co-authored-by: Sayak Paul --- docs/source/en/training/distributed_inference.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index a449fe64658d..673c9c4bc574 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -349,4 +349,5 @@ Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)) ``` -Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices). +> [!TIP] +> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).
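
As a quick end-to-end illustration of the unified path exercised by this series, the sketch below mirrors the context-parallel example earlier in `distributed_inference.md`. It is a minimal sketch, not part of the patches: the checkpoint ID, the prompt, the `torchrun` launch line, and the output filename are placeholder assumptions; only `ContextParallelConfig(ulysses_degree=2, ring_degree=2)` and `pipeline.transformer.enable_parallelism(...)` come from the docs added above, and the `ContextParallelConfig` import path is the one used by the test file earlier in the series.

```py
# Minimal sketch of running unified (Ulysses x Ring) attention on a 2x2 device grid.
# Assumptions: 4 GPUs, a placeholder Flux checkpoint, and the import path used by the
# test file in this series; adapt them to your setup.
# Launch with: torchrun --nproc_per_node=4 unified_sp_demo.py
import torch
import torch.distributed as dist

from diffusers import DiffusionPipeline
from diffusers.models._modeling_parallel import ContextParallelConfig

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
).to(device)

# ulysses_degree * ring_degree must equal the number of participating devices (2 * 2 = 4).
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

# Use the same prompt and seed on every rank so all shards of the sequence agree.
generator = torch.Generator(device=device).manual_seed(0)
image = pipeline("A photo of a cat", generator=generator).images[0]

if rank == 0:
    image.save("unified_sp_output.png")
dist.destroy_process_group()
```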