Commit c2e3cc2
Fix chunked attention mask with left-padding (#40324)
* add fix
* add test
* raise proper warning for older versions
* fix
* fix and add 2nd test
* fix for flex and torch 2.5
1 parent 242bb2c commit c2e3cc2
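In short, the fix makes the chunk index of each token count from the first real (non-padded) token of its row instead of from absolute position 0. Below is a minimal, self-contained sketch of that arithmetic in plain PyTorch (not the library code itself); `pad_tokens` and the printed values are illustrative.

```python
# Minimal sketch of the idea behind this fix: with left padding, chunk boundaries
# must be counted from the first real token, not from absolute position 0.
import torch

chunk_size = 3
pad_tokens = 4                     # left-padding tokens in this (hypothetical) row
positions = torch.arange(8)

legacy_chunks = positions // chunk_size                   # pre-fix behavior: ignores padding
shifted_chunks = (positions - pad_tokens) // chunk_size   # post-fix behavior

print(legacy_chunks)   # tensor([0, 0, 0, 1, 1, 1, 2, 2])        -> first "chunk" is mostly padding
print(shifted_chunks)  # tensor([-2, -1, -1, -1,  0,  0,  0,  1]) -> real tokens 4-6 share chunk 0
# Negative chunk ids only ever land on padded positions, which the padding mask removes anyway.
```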

File tree

2 files changed: +154 -10 lines changed

* src/transformers/masking_utils.py
* tests/utils/test_masking_utils.py

src/transformers/masking_utils.py

Lines changed: 43 additions & 8 deletions
@@ -20,7 +20,7 @@
 
 from .cache_utils import Cache
 from .configuration_utils import PretrainedConfig
-from .utils import is_torch_xpu_available
+from .utils import is_torch_xpu_available, logging
 from .utils.generic import GeneralInterface
 from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_torchdynamo_compiling

@@ -40,6 +40,9 @@
 from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
 
 
+logger = logging.get_logger(__name__)
+
+
 def and_masks(*mask_functions: list[Callable]) -> Callable:
     """Returns a mask function that is the intersection of provided mask functions"""
     if not all(callable(arg) for arg in mask_functions):
@@ -87,12 +90,24 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
     return inner_mask
 
 
-def chunked_overlay(chunk_size: int) -> Callable:
+def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
     """
     This is an overlay depicting a chunked attention pattern. Add it on top of a causal mask for a proper chunked
     attention mask.
     """
 
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return (kv_idx - left_padding[batch_idx]) // chunk_size == (q_idx - left_padding[batch_idx]) // chunk_size
+
+    return inner_mask
+
+
+def _legacy_chunked_overlay(chunk_size: int) -> Callable:
+    """
+    Same as the above function, but does not correctly account for left-padding tokens.
+    Only kept for compatibility with older torch versions (< 2.6).
+    """
+
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         return kv_idx // chunk_size == q_idx // chunk_size
 
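To see the difference between the two overlays on concrete indices, here is a self-contained sketch that re-declares both closures as they appear in the hunk above, so it runs without importing transformers; the index values are illustrative.

```python
# Re-declaration of the two overlays from the hunk above, for a standalone comparison.
import torch

def chunked_overlay(chunk_size, left_padding):
    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        return (kv_idx - left_padding[batch_idx]) // chunk_size == (q_idx - left_padding[batch_idx]) // chunk_size
    return inner_mask

def _legacy_chunked_overlay(chunk_size):
    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        return kv_idx // chunk_size == q_idx // chunk_size
    return inner_mask

left_padding = torch.tensor([4])             # 4 left-padding tokens in batch row 0
new_mask = chunked_overlay(3, left_padding)
old_mask = _legacy_chunked_overlay(3)

# q_idx=4 is the first real token, kv_idx=6 the third one: they belong to the same chunk of 3.
print(bool(new_mask(0, 0, 4, 6)))  # True  -> chunk counted from index 4
print(bool(old_mask(0, 0, 4, 6)))  # False -> 4 // 3 == 1 but 6 // 3 == 2
```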

@@ -106,11 +121,13 @@ def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
     return and_masks(sliding_window_overlay(sliding_window), causal_mask_function)
 
 
-def chunked_causal_mask_function(chunk_size: int) -> Callable:
+def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) -> Callable:
     """
     This return the mask_function function to create a chunked attention mask.
     """
-    return and_masks(chunked_overlay(chunk_size), causal_mask_function)
+    if not _is_torch_greater_or_equal_than_2_6:
+        return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function)
+    return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function)
 
 
 def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
@@ -298,7 +315,7 @@ def sdpa_mask_recent_torch(
     You can do
 
     ```python
-    >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
+    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
     >>> tensor([[[[ True, False, False, False, False],
                   [ True,  True, False, False, False],
                   [ True,  True,  True, False, False],
@@ -319,7 +336,7 @@ def sdpa_mask_recent_torch(
     You can do
 
     ```python
-    >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
+    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
     >>> tensor([[[[ True, False, False, False, False],
                   [ True,  True, False, False, False],
                   [ True,  True,  True, False, False],
@@ -340,7 +357,7 @@ def sdpa_mask_recent_torch(
     You can do
 
    ```python
-    >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3))
+    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
     >>> tensor([[[[ True, False, False, False, False],
                   [ True,  True, False, False, False],
                   [ True,  True,  True, False, False],
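The `torch.zeros(1, dtype=int)` argument in the updated docstring is simply the "no left padding" case for a batch of one. A small illustration of how a non-zero value would shift the documented pattern, computed locally rather than via `sdpa_mask` (the value 2 is an arbitrary example):

```python
# The second argument encodes per-row left padding; zeros reproduce the documented pattern.
import torch

cache_position = torch.arange(5)
print(cache_position // 3)        # tensor([0, 0, 0, 1, 1])   -> chunks [0-2] and [3-4], as in the docstring
print((cache_position - 2) // 3)  # tensor([-1, -1, 0, 0, 0]) -> with 2 left-pad tokens, the three
                                  #    real tokens 2-4 form a single chunk instead of being split
```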
@@ -973,7 +990,25 @@ def create_chunked_causal_mask(
     )
 
     batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
-    mask_factory_function = chunked_causal_mask_function(chunk_size)
+    # For chunked attention and batched inputs, we need to take the number of left padding tokens into account
+    # to start the chunk from the actual start of the sequence for the padded sequence
+    if attention_mask is not None:
+        # Only count the left padding tokens, not all of them
+        left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
+    else:
+        left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int)
+    # Raise a warning for older versions if the problematic left-padding situation arises
+    if (
+        not _is_torch_greater_or_equal_than_2_6
+        and kv_length + kv_offset > chunk_size
+        and (left_padding_tokens > 0).any()
+    ):
+        logger.warning_once(
+            "Due to limitations of your current torch version, we cannot correctly account for the left-padding "
+            "when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded "
+            "sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue."
+        )
+    mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
     mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
 
     # Do not allow skip if we are compiling (this is to match BC)
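The cumulative-sum trick above counts only the leading run of padding tokens: the cumsum stays at zero exactly while we are still inside the left padding, so zeros that appear later (e.g. right padding) do not contribute. A quick, self-contained check with made-up masks:

```python
# Standalone check of the left-padding count used in the hunk above.
import torch

attention_mask = torch.tensor([
    [0, 0, 1, 1, 0],   # 2 left-padding tokens; the trailing 0 is NOT left padding
    [1, 1, 1, 1, 1],   # no padding at all
])
left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
print(left_padding_tokens)  # tensor([2, 0])
```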

tests/utils/test_masking_utils.py

Lines changed: 111 additions & 2 deletions
@@ -21,8 +21,9 @@
 import torch
 from torch.nn.attention.flex_attention import create_block_mask
 
-from transformers import LlamaConfig
-from transformers.masking_utils import create_causal_mask, find_packed_sequence_indices
+from transformers import DynamicCache, LlamaConfig
+from transformers.cache_utils import DynamicSlidingWindowLayer
+from transformers.masking_utils import create_causal_mask, create_chunked_causal_mask, find_packed_sequence_indices
 
 
 # fmt: off
@@ -135,3 +136,111 @@ def test_find_packed_sequence_indices(self):
         position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
         EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
         self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
+
+    def test_chunked_mask_with_left_padding_and_large_prefill(self):
+        # Make sure we have an attention_chunk_size in the config
+        config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")
+
+        batch_size = 2
+        sequence_length = 8
+        pad_tokens = 4
+
+        input_ids = torch.randint(100, 200, (batch_size, sequence_length))
+        attention_mask = torch.tensor(
+            [[0 if i < pad_tokens else 1 for i in range(sequence_length)], [1] * sequence_length]
+        )
+        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
+        cache_position = torch.arange(sequence_length)
+        position_ids = torch.empty(batch_size, sequence_length, dtype=cache_position.dtype)
+        position_ids[0, :pad_tokens] = 1
+        position_ids[0, pad_tokens:] = torch.arange(sequence_length - pad_tokens)
+        position_ids[1, :] = cache_position
+
+        chunked_attention_mask = create_chunked_causal_mask(
+            config=config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+
+        # fmt: off
+        EXPECTED_CHUNKED_MASK = torch.tensor(
+            # Here, for the padded sequence, the chunks should correctly start at index 4 (otherwise, with 4 padding
+            # tokens and chunk_size=3, the chunks would wrongly span indices 0-2, 3-5, and so on)
+            [[[[False, False, False, False, False, False, False, False],
+               [False, False, False, False, False, False, False, False],
+               [False, False, False, False, False, False, False, False],
+               [False, False, False, False, False, False, False, False],
+               [False, False, False, False,  True, False, False, False],
+               [False, False, False, False,  True,  True, False, False],
+               [False, False, False, False,  True,  True,  True, False],
+               [False, False, False, False, False, False, False,  True]]],
+
+
+            [[[ True, False, False, False, False, False, False, False],
+              [ True,  True, False, False, False, False, False, False],
+              [ True,  True,  True, False, False, False, False, False],
+              [False, False, False,  True, False, False, False, False],
+              [False, False, False,  True,  True, False, False, False],
+              [False, False, False,  True,  True,  True, False, False],
+              [False, False, False, False, False, False,  True, False],
+              [False, False, False, False, False, False,  True,  True]]]],
+            dtype=torch.bool)
+        # fmt: on
+
+        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
+
+    def test_chunked_mask_with_left_padding_decoding(self):
+        # Make sure we have an attention_chunk_size in the config
+        config = LlamaConfig(attention_chunk_size=4, attn_implementation="sdpa", num_hidden_layers=1)
+
+        cache = DynamicCache(config=config)
+        # Sanity check
+        self.assertEqual(len(cache), 1)
+        self.assertTrue(isinstance(cache.layers[0], DynamicSlidingWindowLayer))
+
+        # Fill in the Cache (sequence length is bigger than the chunk size here)
+        batch_size = 2
+        prefill_size = 8
+        pad_tokens = 7
+        fake_kv = torch.rand(batch_size, 32, prefill_size, 32)
+        cache.update(fake_kv, fake_kv, 0, torch.arange(prefill_size))
+
+        # Create a new input after the prefill
+        input_ids = torch.randint(100, 200, (batch_size, 1))
+        attention_mask = torch.tensor(
+            [[0 if i < pad_tokens else 1 for i in range(prefill_size + 1)], [1] * (prefill_size + 1)]
+        )
+        inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
+        cache_position = torch.tensor([prefill_size], dtype=int)
+        position_ids = torch.tensor([[prefill_size - pad_tokens], [prefill_size]])
+
+        chunked_attention_mask = create_chunked_causal_mask(
+            config=config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=cache,
+            position_ids=position_ids,
+        )
+
+        # To understand the expected mask below, here is the full 2d mask, where the "|" characters are the chunk
+        # separators (where the tokens should stop seeing each other)
+        # [0, 0, 0, 0, 0, 0, 0, | 1, 1], -> due to left padding, the first chunk only starts after the padding tokens
+        # [| 1, 1, 1, 1, | 1, 1, 1, 1, | 1]]) -> easy case, each 4 tokens is a new chunk
+
+        # fmt: off
+        EXPECTED_CHUNKED_MASK = torch.tensor(
+            # Here, for the padded sequence, the chunks should correctly start at index 7 (the first unpadded
+            # index), so only indices 7 and 8 should be True
+            [[[[False, False, True, True]]],
+
+            # Here, for the unpadded sequence, the chunks start at index 0. Since we have 9 tokens in total, the last
+            # token (index 8) will only see itself (we have 2 full chunks before)
+            [[[False, False, False, True]]]],
+            dtype=torch.bool)
+        # fmt: on
+
+        self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
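The expected tensors in the first test can be rebuilt by hand from the rule the commit implements: a query attends to a key iff the pair is causal, the two positions fall in the same padding-shifted chunk, and the key is a real token. A hedged broadcasting sketch for the padded row (chunk_size=3, 4 padding tokens, 8 positions), independent of `create_chunked_causal_mask`:

```python
# Rebuild the padded row of EXPECTED_CHUNKED_MASK from first principles.
import torch

chunk_size, pad_tokens, seq_len = 3, 4, 8
pos = torch.arange(seq_len)

chunk_idx = (pos - pad_tokens) // chunk_size            # chunk id counted from the first real token
same_chunk = chunk_idx[:, None] == chunk_idx[None, :]   # query and key share a chunk
causal = pos[:, None] >= pos[None, :]                   # queries only look backwards
kv_is_real = (pos >= pad_tokens)[None, :]               # padded keys are never attended to

mask = causal & same_chunk & kv_is_real
print(mask.int())
# Rows 0-3 are all zeros (padded queries see no real keys), row 4 sees only itself,
# rows 5-6 see back to index 4, and row 7 opens a new chunk and sees only itself.
```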

0 commit comments