diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md
index 22e8a30427b9..673c9c4bc574 100644
--- a/docs/source/en/training/distributed_inference.md
+++ b/docs/source/en/training/distributed_inference.md
@@ -333,3 +333,21 @@ pipeline = DiffusionPipeline.from_pretrained(
     CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
 ).to(device)
 ```
+### Unified Attention
+
+[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
+
+This hybrid approach leverages the strengths of both methods:
+- **Ulysses Attention** efficiently parallelizes across attention heads
+- **Ring Attention** handles very long sequences with minimal memory overhead
+- Together, they enable 2D parallelization across both the head and sequence dimensions
+
+[`ContextParallelConfig`] supports Unified Attention when both `ulysses_degree` and `ring_degree` are specified. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where the Ulysses and Ring groups are orthogonal (non-overlapping).
+Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to values greater than 1 to [`~ModelMixin.enable_parallelism`].
+
+```py
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
+```
+
+> [!TIP]
+> Use Unified Attention when there are enough devices to arrange in a 2D grid (at least 4 devices).
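+
+The following is a minimal sketch of a complete script. It reuses the `CKPT_ID` placeholder from the examples above and assumes four visible GPUs with one process per GPU (for example, launched with `torchrun --nproc_per_node=4`); adapt the pipeline call to your model.
+
+```py
+import torch
+from diffusers import ContextParallelConfig, DiffusionPipeline
+
+# One process per GPU, e.g. launched with `torchrun --nproc_per_node=4 unified_attention.py`
+torch.distributed.init_process_group("nccl")
+rank = torch.distributed.get_rank()
+device = torch.device("cuda", rank % torch.cuda.device_count())
+torch.cuda.set_device(device)
+
+# CKPT_ID is the checkpoint id used in the examples above
+pipeline = DiffusionPipeline.from_pretrained(CKPT_ID, torch_dtype=torch.bfloat16).to(device)
+
+# 2-way Ulysses (heads) x 2-way Ring (sequence) = 4 devices arranged in a 2D grid
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
+
+# A text-to-image call; adjust the arguments to your pipeline
+image = pipeline("a photo of a cat", num_inference_steps=30).images[0]
+if rank == 0:
+    image.save("unified_attention.png")
+
+torch.distributed.destroy_process_group()
+```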
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index 2a4eb520c796..1c7703a13c52 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -90,10 +90,6 @@ def __post_init__(self):
             )
         if self.ring_degree < 1 or self.ulysses_degree < 1:
             raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-        if self.ring_degree > 1 and self.ulysses_degree > 1:
-            raise ValueError(
-                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-            )
         if self.rotate_method != "allgather":
             raise NotImplementedError(
                 f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 310c44457c27..37e821a26f95 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1132,6 +1132,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     return x
 
 
+def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+    """
+    Perform dimension sharding / reassembly across processes using `_all_to_all_single`.
+
+    This utility reshapes and redistributes tensor `x` across the given process group, exchanging the sequence and
+    head dimensions as selected by `scatter_idx` and `gather_idx`.
+
+    Args:
+        x (torch.Tensor):
+            Input tensor. Expected shapes:
+            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
+            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
+        scatter_idx (int):
+            Dimension along which the tensor is partitioned before all-to-all.
+        gather_idx (int):
+            Dimension along which the output is reassembled after all-to-all.
+        group:
+            Distributed process group for the Ulysses group.
+
+    Returns:
+        torch.Tensor: Tensor with globally exchanged dimensions.
+            - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
+            - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
+    """
+    group_world_size = torch.distributed.get_world_size(group)
+
+    if scatter_idx == 2 and gather_idx == 1:
+        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
+        # dimension and scatters the head dimension.
+        batch_size, seq_len_local, num_heads, head_dim = x.shape
+        seq_len = seq_len_local * group_world_size
+        num_heads_local = num_heads // group_world_size
+
+        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
+        x_temp = (
+            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
+            .transpose(0, 2)
+            .contiguous()
+        )
+
+        if group_world_size > 1:
+            out = _all_to_all_single(x_temp, group=group)
+        else:
+            out = x_temp
+        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
+        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        # Used after Ulysses sequence parallel (SP) attention in unified SP. Gathers the head dimension and
+        # scatters back the sequence dimension.
+        batch_size, seq_len, num_heads_local, head_dim = x.shape
+        num_heads = num_heads_local * group_world_size
+        seq_len_local = seq_len // group_world_size
+
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = (
+            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+            .permute(1, 3, 2, 0, 4)
+            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+        )
+
+        if group_world_size > 1:
+            output = _all_to_all_single(x_temp, group)
+        else:
+            output = x_temp
+        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
+        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
+
+
+class SeqAllToAllDim(torch.autograd.Function):
+    """
+    Autograd-aware all-to-all exchange for unified sequence parallelism, implemented with `_all_to_all_dim_exchange`.
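+
+    `forward` calls `_all_to_all_dim_exchange(input, scatter_id, gather_id, group)`; `backward` applies the same
+    exchange with `scatter_id` and `gather_id` swapped, so gradients are returned in the layout of the inputs.
+
+    Illustration with hypothetical sizes: for a Ulysses group of size 2 and an input of shape
+    (batch_size=1, seq_len_local=8, num_heads=16, head_dim=64), `SeqAllToAllDim.apply(group, x, 2, 1)` returns a
+    tensor of shape (1, 16, 8, 64); applying the exchange again with the indices swapped restores the original shape.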
+    """
+
+    @staticmethod
+    def forward(ctx, group, input, scatter_id=2, gather_id=1):
+        ctx.group = group
+        ctx.scatter_id = scatter_id
+        ctx.gather_id = gather_id
+        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -1192,7 +1289,10 @@ def forward(
             out = out.to(torch.float32)
             lse = lse.to(torch.float32)
 
-            lse = lse.unsqueeze(-1)
+            # Refer to:
+            # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+            if is_torch_version("<", "2.9.0"):
+                lse = lse.unsqueeze(-1)
             if prev_out is not None:
                 out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                 lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1253,7 +1353,7 @@ def backward(
         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
 
-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
 
 
 class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1348,7 +1448,69 @@ def backward(
         x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
     )
 
-    return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+    return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
+
+
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    """
+    Unified Sequence Parallelism attention combining Ulysses and Ring Attention.
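+
+    Sketch of the data flow (matching the implementation below): an all-to-all across the Ulysses group first
+    gathers the sequence shards so that each rank holds a local subset of heads, `TemplatedRingAttention` then
+    covers the remaining sequence shards across the ring group, and a reverse all-to-all restores the original
+    sequence-sharded layout.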
+    See: https://arxiv.org/abs/2405.07719
+    """
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
+        forward_op,
+        backward_op,
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # context_layer is of shape (B, S, H_LOCAL, D)
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,
+        scatter_idx,
+    )
+    if return_lse:
+        # lse is of shape (B, S, H_LOCAL, 1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
+        lse = lse.squeeze(-1)
+        return (output, lse)
+    return output
 
 
 def _templated_context_parallel_attention(
@@ -1374,7 +1536,25 @@
         raise ValueError("GQA is not yet supported for templated attention.")
 
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if (
+        _parallel_config.context_parallel_config.ring_degree > 1
+        and _parallel_config.context_parallel_config.ulysses_degree > 1
+    ):
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,