Add Unified Sequence Parallel attention #12693
@@ -1132,6 +1132,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     return x

+def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+    """
+    Perform dimension sharding / reassembly across processes using _all_to_all_single.
+
+    This utility reshapes and redistributes tensor `x` across the given process group, exchanging the sequence and
+    head dimensions as selected by `scatter_idx` and `gather_idx`.
+
+    Args:
+        x (torch.Tensor):
+            Input tensor. Expected shapes:
+            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
+            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
+        scatter_idx (int):
+            Dimension along which the tensor is partitioned before all-to-all.
+        gather_idx (int):
+            Dimension along which the output is reassembled after all-to-all.
+        group:
+            Distributed process group for the Ulysses group.
+
+    Returns:
+        torch.Tensor: Tensor with globally exchanged dimensions.
+            - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
+            - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
+    """
+    group_world_size = torch.distributed.get_world_size(group)
+
+    if scatter_idx == 2 and gather_idx == 1:
+        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
+        # dimension and scatters the head dimension.
+        batch_size, seq_len_local, num_heads, head_dim = x.shape
+        seq_len = seq_len_local * group_world_size
+        num_heads_local = num_heads // group_world_size
+
+        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
+        x_temp = (
+            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
+            .transpose(0, 2)
+            .contiguous()
+        )
+
+        if group_world_size > 1:
+            out = _all_to_all_single(x_temp, group=group)
+        else:
+            out = x_temp
+        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
+        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        # Used after Ulysses sequence parallel attention in unified SP. Gathers the head dimension
+        # and scatters back the sequence dimension.
+        batch_size, seq_len, num_heads_local, head_dim = x.shape
+        num_heads = num_heads_local * group_world_size
+        seq_len_local = seq_len // group_world_size
+
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = (
+            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+            .permute(1, 3, 2, 0, 4)
+            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+        )
+
+        if group_world_size > 1:
+            output = _all_to_all_single(x_temp, group)
+        else:
+            output = x_temp
+        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
+        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
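A minimal single-process sketch (not part of the diff) of the shape bookkeeping in the scatter_idx=2, gather_idx=1 branch above, with a hypothetical world size `W` simulated locally instead of a real process group, so only the shapes are meaningful:

```python
import torch

B, S_LOCAL, H, D = 2, 8, 4, 16
W = 2  # hypothetical group_world_size
H_LOCAL, S = H // W, S_LOCAL * W

x = torch.randn(B, S_LOCAL, H, D)

# Same reshape/transpose applied before _all_to_all_single:
# split heads into W chunks and move the group dimension to the front.
x_temp = x.reshape(B, S_LOCAL, W, H_LOCAL, D).transpose(0, 2).contiguous()
print(x_temp.shape)  # torch.Size([2, 8, 2, 2, 16]) == (W, S_LOCAL, B, H_LOCAL, D)

# After the real all-to-all each rank would hold all W sequence shards for its
# H_LOCAL heads; the trailing reshapes recover (B, S, H_LOCAL, D).
out = x_temp.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
print(out.shape)  # torch.Size([2, 16, 2, 16]) == (B, S, H_LOCAL, D)
```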
+class SeqAllToAllDim(torch.autograd.Function):
+    """
+    All-to-all operation for unified sequence parallelism. Uses _all_to_all_dim_exchange; see its docstring for more
+    info.
+    """
+
+    @staticmethod
+    def forward(ctx, group, input, scatter_id=2, gather_id=1):
+        ctx.group = group
+        ctx.scatter_id = scatter_id
+        ctx.gather_id = gather_id
+        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
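A hypothetical single-process smoke test for this autograd wrapper (assuming `SeqAllToAllDim` is importable from the modified module): with world size 1 the all-to-all is skipped, so it only exercises the shape plumbing and the reversed backward.

```python
import os
import torch
import torch.distributed as dist
# from diffusers.models.attention_dispatch import SeqAllToAllDim  # assumed import path for this PR

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
group = dist.group.WORLD

x = torch.randn(2, 8, 4, 16, requires_grad=True)  # (B, S_LOCAL, H, D)
y = SeqAllToAllDim.apply(group, x, 2, 1)          # (B, S, H_LOCAL, D); identity when world size is 1
y.sum().backward()                                 # backward runs the exchange with indices reversed
assert x.grad.shape == x.shape

dist.destroy_process_group()
```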
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -1192,7 +1289,10 @@ def forward(
         out = out.to(torch.float32)
         lse = lse.to(torch.float32)

-        lse = lse.unsqueeze(-1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)
         if prev_out is not None:
             out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
             lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
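The sigmoid/logsigmoid update above is the usual log-sum-exp merge of partial attention results. A small numerical check (not from the PR) of that equivalence:

```python
import torch

torch.manual_seed(0)
prev_out, out = torch.randn(2, 4, 8, 1), torch.randn(2, 4, 8, 1)
prev_lse, lse = torch.randn(2, 4, 8, 1), torch.randn(2, 4, 8, 1)

# Update used in TemplatedRingAttention.forward
merged_out = prev_out - torch.sigmoid(lse - prev_lse) * (prev_out - out)
merged_lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)

# Explicit softmax-style combine of the two partial results
w_prev, w_new = prev_lse.exp(), lse.exp()
expected_out = (w_prev * prev_out + w_new * out) / (w_prev + w_new)
expected_lse = torch.logaddexp(prev_lse, lse)

assert torch.allclose(merged_out, expected_out, atol=1e-6)
assert torch.allclose(merged_lse, expected_lse, atol=1e-6)
```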
@@ -1253,7 +1353,7 @@ def backward(
         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
Member
Why the change here?

Author
The forward function takes 12 inputs (excluding ctx), but backward was returning only 11 outputs; the two must match. While testing I was getting an error like: "RuntimeError: function backward returned an incorrect number of gradients (expected 12, got 11)".

Member
Do you have a reproducer?

Author
Yes, it can be reproduced in this notebook (it happens only during the backward): https://colab.research.google.com/drive/1Ac4nVSVjKHrPpcSRlX0E3NzY0mDEmkMx?usp=sharing
 class TemplatedUlyssesAttention(torch.autograd.Function):

@@ -1348,7 +1448,69 @@ def backward(
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    """
+    Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
+    """
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
+        forward_op,
+        backward_op,
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # context_layer is of shape (B, S, H_LOCAL, D)
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,
+        scatter_idx,
+    )
+    if return_lse:
+        # lse is of shape (B, S, H_LOCAL, 1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
+        lse = lse.squeeze(-1)
+        return (output, lse)
+    return output
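For intuition, a hypothetical shape walkthrough of the unified path (numbers invented, not from the PR): the sequence starts sharded across all context-parallel ranks, the Ulysses all-to-all trades sequence shards for head shards within the Ulysses group, and ring attention then covers the remaining sequence shards held by the ring group.

```python
ring_degree, ulysses_degree = 2, 4
world_size = ring_degree * ulysses_degree       # 8 context-parallel ranks
seq_len, num_heads, head_dim = 8192, 32, 128

seq_local = seq_len // world_size               # 1024 tokens per rank at the input
# Ulysses all-to-all (scatter_idx=2, gather_idx=1): gather sequence over the
# Ulysses group, scatter heads.
seq_after_ulysses = seq_local * ulysses_degree  # 4096 tokens per rank
heads_local = num_heads // ulysses_degree       # 8 heads per rank
# Ring attention then rotates K/V across the ring group, so those 4096 queries
# attend to the full 8192-token sequence in ring_degree steps.
print(seq_after_ulysses, heads_local)           # 4096 8
```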
 def _templated_context_parallel_attention(
@@ -1374,7 +1536,25 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")

-    # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if (
+        _parallel_config.context_parallel_config.ring_degree > 1
+        and _parallel_config.context_parallel_config.ulysses_degree > 1
+    ):
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,
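A small sketch (hypothetical helper, not from the PR) summarizing the dispatch this branch implements, assuming the existing Ulysses-only branch follows the ring-only one:

```python
def pick_attention_path(ring_degree: int, ulysses_degree: int) -> str:
    # Mirrors the branch order above: unified SP requires both degrees > 1.
    if ring_degree > 1 and ulysses_degree > 1:
        return "TemplatedUnifiedAttention"
    elif ring_degree > 1:
        return "TemplatedRingAttention"
    elif ulysses_degree > 1:
        return "TemplatedUlyssesAttention"
    return "local attention"

assert pick_attention_path(2, 4) == "TemplatedUnifiedAttention"
assert pick_attention_path(4, 1) == "TemplatedRingAttention"
assert pick_attention_path(1, 4) == "TemplatedUlyssesAttention"
```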
If you have a visual the users could refer to (external link is fine), feel free to add that here.