Introduces utilities to unpad/repad tensors and compute indices/cumulative seqlens for ragged batches, reusing mask-derived metadata across Q/K/V to reduce overhead.
Handles static KV caches longer than the mask by safe slicing to avoid incorrect attention scores, and supports left-padded sequences and single-token decoding.
Improves performance and correctness for attention paths that operate on variable-length inputs.
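The "safe slicing" for static KV caches mentioned above can be sketched as follows. This is a hedged illustration, not the repository's code: the tensor names and shapes are assumptions chosen to show that a cache longer than the attention mask is trimmed to the mask's sequence length before any unpadding, so stale cache slots cannot leak into attention scores.

```python
# Hypothetical sketch (names are illustrative, not from the source): a static
# KV cache is pre-allocated longer than the current attention mask, so keys
# and values are sliced down to the mask's sequence length first.
import torch

mask = torch.tensor([[1, 1, 1, 0]])       # (batch, seq_len=4), 1 = valid token
key_cache = torch.randn(1, 16, 2, 8)      # static cache: (batch, max_cache_len=16, kv_heads, head_dim)

seq_len = mask.shape[-1]
key = key_cache[:, :seq_len]              # slice cache to mask length before unpadding

print(key.shape)                          # torch.Size([1, 4, 2, 8])
```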
Unpads the query, key, and value tensors, flattening all tokens into a single dimension even though they belong to different batches.

This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid recomputing the same intermediary tensors for the query, key, and value tensors.

Arguments:
    query_layer (`torch.Tensor`):
        Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
    key_layer (`torch.Tensor`):
        Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
    value_layer (`torch.Tensor`):
        Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
    attention_mask (`torch.Tensor`):
        Boolean or int tensor of shape (batch_size, sequence_length), where 1 means valid and 0 means not valid.
    query_length (`int`):
        Target length.
    unpad_input_func:
        The function to use for unpadding the input tensors.

Return:
    query_layer (`torch.Tensor`):
        Query state without padding. Shape: (total_target_length, num_heads, head_dim).
    key_layer (`torch.Tensor`):
        Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
    value_layer (`torch.Tensor`):
        Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
    indices_q (`torch.Tensor`):
        The indices of the non-masked tokens from the flattened input target sequence.
    (cu_seqlens_q, cu_seqlens_k) (`tuple[torch.Tensor]`):
        The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. Each `cu_seqlens` tensor has shape (batch_size + 1,).
    (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
        The maximum sequence lengths in the batch (`max_seqlen_in_batch_q` for the target sequence, i.e. query; `max_seqlen_in_batch_k` for the source sequence, i.e. key/value).
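The mask-derived metadata described above (flat token indices, cumulative sequence lengths, and the maximum sequence length) can be sketched in a few lines. This is a minimal illustration under the assumption that PyTorch is available; the helper names `get_unpad_data` and `unpad` are chosen here for clarity and are not necessarily the repository's identifiers.

```python
# Sketch of the unpadding idea: derive indices and cumulative sequence lengths
# from the attention mask once, then reuse them to flatten padded Q/K/V tensors
# into ragged (unpadded) ones. Helper names are illustrative assumptions.
import torch

def get_unpad_data(attention_mask: torch.Tensor):
    # Number of valid (non-padding) tokens per batch row.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of valid tokens in the flattened (batch * seq) dimension.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative sequence lengths, shape (batch_size + 1,), starting at 0.
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, max_seqlen_in_batch

def unpad(hidden: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # (batch, seq, heads, head_dim) -> (total_tokens, heads, head_dim)
    batch, seq = hidden.shape[:2]
    return hidden.reshape(batch * seq, *hidden.shape[2:])[indices]

# Two sequences of lengths 2 and 3, right-padded to a common length of 4.
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
indices_q, cu_seqlens_q, max_seqlen_q = get_unpad_data(mask)

q = torch.randn(2, 4, 8, 16)          # (batch, seq, num_heads, head_dim)
q_unpadded = unpad(q, indices_q)

print(q_unpadded.shape)                # torch.Size([5, 8, 16])
print(cu_seqlens_q.tolist())           # [0, 2, 5]
print(max_seqlen_q)                    # 3
```

Because the same mask governs query, key, and value, these indices and `cu_seqlens` can be computed once and applied to all three tensors, which is the recomputation this helper avoids.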