
Commit 3488d06

Add padding/unpadding helpers for QKV attention
Introduces utilities to unpad/repad tensors and compute indices/cumulative seqlens for ragged batches, reusing mask-derived metadata across Q/K/V to reduce overhead. Handles static KV caches longer than the mask by safe slicing to avoid incorrect attention scores, and supports left-padded sequences and single-token decoding. Improves performance and correctness for attention paths that operate on variable-length inputs.
1 parent 83fb0f7 commit 3488d06

File tree

1 file changed: +170 -0 lines changed

flash_dmattn/utils/padding.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F


def index_first_axis(tensor, indices):
    """
    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
    after flattening the first two dimensions of the tensor.
    """
    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
    # two dimensions to get (total_tokens, ...) before indexing.
    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
    return reshaped_tensor[indices]


def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.

    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        index_first_axis(hidden_states, indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.

    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)


def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # NOTE: Similar to the `.item()` in prepare_fdma_from_position_ids, with torch compile,
    # this might cause a graph break
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    unpad_input_func,
):
    """
    Unpads query, key, and value tensors, using a single dimension for all tokens even though they belong to different batches.
    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
    tensors for query, key, value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.
        unpad_input_func:
            The function to use for unpadding the input tensors.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[torch.Tensor]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask)

    # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
    # It's a bit of an anti-pattern, but otherwise we silently compute wrong attention scores
    if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
        key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]

    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = index_first_axis(key_layer, indices_k)
    value_layer = index_first_axis(value_layer, indices_k)
    if query_length == kv_seq_len:
        query_layer = index_first_axis(query_layer, indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
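
For reference, a minimal usage sketch of the unpad/pad round trip these helpers enable. The shapes, the example mask, and the import path `flash_dmattn.utils.padding` are illustrative assumptions, not something this commit pins down.

import torch

from flash_dmattn.utils.padding import pad_input, unpad_input

# Hypothetical right-padded batch: two sequences of lengths 3 and 1, padded to seqlen 4.
batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]], dtype=torch.int32)

# Drop padding tokens: (batch, seqlen, hidden) -> (total_nnz, hidden) with total_nnz = 4.
unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
print(unpadded.shape)  # torch.Size([4, 8])
print(cu_seqlens)      # tensor([0, 3, 4], dtype=torch.int32)

# Scatter the ragged tokens back into a zero-filled padded layout.
repadded = pad_input(unpadded, indices, batch, seqlen)
assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])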

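And a sketch of where `upad_input` might sit in an attention path that handles variable-length inputs. `varlen_attention` is a stand-in callable for whatever ragged-attention kernel the caller uses (flash-attention style, taking cu_seqlens/max_seqlen metadata); the wrapper itself is hypothetical and not part of this commit.

import torch

from flash_dmattn.utils.padding import pad_input, unpad_input, upad_input


def attention_with_padding(query, key, value, attention_mask, varlen_attention):
    """Illustrative wrapper: unpad Q/K/V once, run a varlen kernel, then repad the output.

    query:          (batch, q_len, num_heads, head_dim)
    key / value:    (batch, kv_len, num_kv_heads, head_dim)
    attention_mask: (batch, kv_len), 1 = valid token, 0 = padding
    """
    batch, q_len = query.shape[0], query.shape[1]

    # Mask-derived indices/cu_seqlens are computed once and reused across K and V (and Q when q_len == kv_len).
    q, k, v, indices_q, (cu_q, cu_k), (max_q, max_k) = upad_input(
        query, key, value, attention_mask, q_len, unpad_input
    )

    # The kernel sees all valid tokens of the batch concatenated along a single dimension.
    out = varlen_attention(q, k, v, cu_q, cu_k, max_q, max_k)  # (total_q, num_heads, head_dim)

    # Scatter the ragged output back to the padded (batch, q_len, num_heads, head_dim) layout.
    return pad_input(out, indices_q, batch, q_len)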