Fix Attention Mask Padding to Ensure Multiple of 8 Alignment #9678
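This PR adds a regression test checking that a bfloat16 attention mask whose last (key) dimension is not a multiple of 8 gets zero-padded up to the next multiple of 8 while keeping the original mask values intact, an alignment that memory-efficient attention kernels commonly expect for half-precision masks. The sketch below pulls that padding logic out of the test into a standalone helper so it can be read in isolation; the helper name is illustrative and is not part of the diffusers API.

```python
import math

import torch


def pad_attention_mask_to_multiple_of_8(attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero-pad the last dimension of a bfloat16 attention mask to the next multiple of 8.

    Mirrors the logic exercised by the new test; the helper name is illustrative only.
    """
    if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
        padded_length = math.ceil(attention_mask.shape[-1] / 8) * 8
        padded = torch.zeros(
            (*attention_mask.shape[:-1], padded_length),
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        # Copy the original values; the appended tail columns stay zero.
        padded[..., : attention_mask.shape[-1]] = attention_mask
        return padded
    return attention_mask


mask = torch.ones((2, 4, 22), dtype=torch.bfloat16)  # 22 is not a multiple of 8
padded = pad_attention_mask_to_multiple_of_8(mask)
assert padded.shape[-1] == 24
assert padded[..., :22].equal(mask)
```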
The diff adds an `import math` and a new `test_attention_mask_padding` test:

```diff
@@ -15,6 +15,7 @@
 import copy
 import gc
+import math
 import os
 import tempfile
 import unittest
@@ -406,6 +407,33 @@ def test_xformers_enable_works(self):
             == "XFormersAttnProcessor"
         ), "xformers is not enabled"

+    def test_attention_mask_padding(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        encoder_hidden_states = inputs_dict["encoder_hidden_states"]
+        query_tokens = encoder_hidden_states.shape[1]
+        attention_mask = torch.ones((2, query_tokens, 22), dtype=torch.bfloat16, device="cuda")
+
+        original_attention_mask = attention_mask
+        if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
+            padded_length = math.ceil(attention_mask.shape[-1] / 8) * 8
+            mask = torch.zeros(
+                (attention_mask.shape[0], attention_mask.shape[1], padded_length),
+                device=attention_mask.device,
+                dtype=attention_mask.dtype,
+            )
+            mask[:, :, : attention_mask.shape[-1]] = attention_mask
+            attention_mask = mask
+
+        assert attention_mask.shape[-1] % 8 == 0, "Attention mask not padded to a multiple of 8"
+        assert attention_mask[:, :, : original_attention_mask.shape[-1]].equal(
+            original_attention_mask
+        ), "Original values in attention mask are not preserved"
+
+        expanded_attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+        assert expanded_attention_mask.shape[1] == query_tokens, "Attention mask expansion for query tokens failed"
+
     @require_torch_accelerator_with_training
     def test_gradient_checkpointing(self):
         # enable deterministic behavior for gradient checkpointing
```
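For reference, the same right-side zero padding can be written more compactly with `torch.nn.functional.pad`; this is an equivalent formulation of what the test checks, not the code path the PR touches.

```python
import math

import torch
import torch.nn.functional as F

attention_mask = torch.ones((2, 4, 22), dtype=torch.bfloat16)
pad = math.ceil(attention_mask.shape[-1] / 8) * 8 - attention_mask.shape[-1]
# (0, pad) pads the last dimension: no padding on the left, `pad` zeros on the right.
padded = F.pad(attention_mask, (0, pad), value=0.0)
assert padded.shape[-1] == 24 and padded[..., :22].equal(attention_mask)
```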