 import torch
 from torch.nn.attention.flex_attention import create_block_mask
 
-from transformers import LlamaConfig
-from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices
+from transformers import DynamicCache, LlamaConfig
+from transformers.cache_utils import DynamicSlidingWindowLayer
+from transformers.masking_utils import create_causal_mask, create_chunked_causal_mask, find_packed_sequence_indices
 
 
 # fmt: off
@@ -135,3 +136,111 @@ def test_find_packed_sequence_indices(self):
         position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
         EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
         self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
+
+    def test_chunked_mask_with_left_padding_and_large_prefill(self):
+        # Make sure we have an attention_chunk_size in the config
+        config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")
+
+        batch_size = 2
+        sequence_length = 8
+        pad_tokens = 4
+
+        input_ids = torch.randint(100, 200, (batch_size, sequence_length))
+        attention_mask = torch.tensor(
+            [[0 if i < pad_tokens else 1 for i in range(sequence_length)], [1] * sequence_length]
+        )
+        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
+        cache_position = torch.arange(sequence_length)
+        position_ids = torch.empty(batch_size, sequence_length, dtype=cache_position.dtype)
+        position_ids[0, :pad_tokens] = 1
+        position_ids[0, pad_tokens:] = torch.arange(sequence_length - pad_tokens)
+        position_ids[1, :] = cache_position
+
+        chunked_attention_mask = create_chunked_causal_mask(
+            config=config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+
+        # fmt: off
+        EXPECTED_CHUNKED_MASK = torch.tensor(
+            # Here, for the padded sequence, the first chunk should correctly start at index 4 (otherwise, with 4 padding
+            # tokens and chunk_size=3, the chunks would span indices 0-2, then 3-5, if we did not account for the padding)
+            [[[[False, False, False, False, False, False, False, False],
+               [False, False, False, False, False, False, False, False],
+               [False, False, False, False, False, False, False, False],
+               [False, False, False, False, False, False, False, False],
+               [False, False, False, False, True, False, False, False],
+               [False, False, False, False, True, True, False, False],
+               [False, False, False, False, True, True, True, False],
+               [False, False, False, False, False, False, False, True]]],
+
+
+             [[[ True, False, False, False, False, False, False, False],
+              [ True, True, False, False, False, False, False, False],
+              [ True, True, True, False, False, False, False, False],
+              [False, False, False, True, False, False, False, False],
+              [False, False, False, True, True, False, False, False],
+              [False, False, False, True, True, True, False, False],
+              [False, False, False, False, False, False, True, False],
+              [False, False, False, False, False, False, True, True]]]],
+            dtype=torch.bool)
+        # fmt: on
+
+        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
+
+    def test_chunked_mask_with_left_padding_decoding(self):
+        # Make sure we have an attention_chunk_size in the config
+        config = LlamaConfig(attention_chunk_size=4, attn_implementation="sdpa", num_hidden_layers=1)
+
+        cache = DynamicCache(config=config)
+        # Sanity check
+        self.assertEqual(len(cache), 1)
+        self.assertTrue(isinstance(cache.layers[0], DynamicSlidingWindowLayer))
+
+        # Fill in the cache (the sequence length is bigger than the chunk size here)
+        batch_size = 2
+        prefill_size = 8
+        pad_tokens = 7
+        fake_kv = torch.rand(batch_size, 32, prefill_size, 32)
+        cache.update(fake_kv, fake_kv, 0, torch.arange(prefill_size))
+
+        # Create a new input after the prefill
+        input_ids = torch.randint(100, 200, (batch_size, 1))
+        attention_mask = torch.tensor(
+            [[0 if i < pad_tokens else 1 for i in range(prefill_size + 1)], [1] * (prefill_size + 1)]
+        )
+        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
+        cache_position = torch.tensor([prefill_size], dtype=int)
+        position_ids = torch.tensor([[prefill_size - pad_tokens], [prefill_size]])
+
+        chunked_attention_mask = create_chunked_causal_mask(
+            config=config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=cache,
+            position_ids=position_ids,
+        )
+
+        # To better understand the following expected mask, here is the full 2D mask, where the "|" characters mark
+        # the chunk separators (where tokens should stop seeing each other)
+        # [0, 0, 0, 0, 0, 0, 0, | 1, 1], -> due to left padding, the first chunk only starts after the padding tokens
+        # [| 1, 1, 1, 1, | 1, 1, 1, 1, | 1]]) -> easy case, every 4 tokens starts a new chunk
+
+        # fmt: off
+        EXPECTED_CHUNKED_MASK = torch.tensor(
+            # Here, for the padded sequence, the first chunk should correctly start at index 7 (the first unpadded
+            # index), and so only indices 7 and 8 should be True
+            [[[[False, False, True, True]]],
+
+             # Here, for the unpadded sequence, the chunks start at index 0. Since we have 9 tokens in total, the last
+             # token (index 8) will only see itself (there are 2 full chunks before it)
+             [[[False, False, False, True]]]],
+            dtype=torch.bool)
+        # fmt: on
+
+        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
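
For reference, the masking rule these tests assert can be written down directly from `position_ids`: a key is attended to only if it is causal, not a padding token, and lies in the same block of `attention_chunk_size` unpadded positions as the query. The sketch below only illustrates that rule (the helper name `chunked_causal_mask_sketch` is made up here and this is not how `create_chunked_causal_mask` is implemented), but it reproduces the expected prefill mask above.

import torch

def chunked_causal_mask_sketch(position_ids, attention_mask, chunk_size):
    # position_ids: (batch, seq_len), restarting at 0 on the first non-padding token
    # attention_mask: (batch, seq_len), 0 on padding positions
    seq_len = position_ids.shape[1]
    q_idx = torch.arange(seq_len).view(1, seq_len, 1)
    kv_idx = torch.arange(seq_len).view(1, 1, seq_len)
    causal = kv_idx <= q_idx                                    # (1, q, kv)
    # Chunk boundaries follow the unpadded positions, so left padding shifts them:
    # chunk 0 starts on the first real token of each sequence
    chunk_ids = position_ids // chunk_size                      # (batch, seq_len)
    same_chunk = chunk_ids.unsqueeze(-1) == chunk_ids.unsqueeze(1)
    visible_keys = attention_mask.bool().unsqueeze(1)           # drop padded keys
    return (causal & same_chunk & visible_keys).unsqueeze(1)    # (batch, 1, q, kv)

# Same inputs as the prefill test above; `mask` equals its EXPECTED_CHUNKED_MASK
position_ids = torch.tensor([[1, 1, 1, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6, 7]])
attention_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [1] * 8])
mask = chunked_causal_mask_sketch(position_ids, attention_mask, chunk_size=3)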
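
The decoding test follows the same rule; its expected masks only have four key columns because, with `attention_chunk_size=4` and the `DynamicSlidingWindowLayer`, only the last four key positions appear in the mask (absolute indices 5 to 8, i.e. the three most recent cached tokens plus the new one). Slicing the full-sequence sketch accordingly gives the same two rows (again purely illustrative, reusing the hypothetical helper above):

# Padded sequence: 7 padding tokens, then 2 real tokens whose positions restart at 0
position_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 1, 2, 3, 4, 5, 6, 7, 8]])
attention_mask = torch.tensor([[0] * 7 + [1, 1], [1] * 9])
full_mask = chunked_causal_mask_sketch(position_ids, attention_mask, chunk_size=4)
# Last query row, last 4 key columns -> [[False, False, True, True]] and [[False, False, False, True]]
decode_mask = full_mask[:, :, -1:, -4:]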