Description
System Info
Hi, while we were updating the PyTorch transformers pin to v5.2.0, our regression tests caught a numerics mismatch between eager and compiled execution. The difference is substantial: 3.3, versus the typically accepted difference on the order of 1e-4. Digging into it (pytorch/pytorch#175274 (comment)), we found the cause to be in these lines (added in #41265):
transformers/src/transformers/masking_utils.py, lines 490 to 491 in 147b7aa:

    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask, kv_length, local_size):
        return None

We set allow_is_bidirectional_skip=True in a few places:
transformers/src/transformers/masking_utils.py, lines 996 to 997 in 147b7aa:

    # Allow skipping the mask creation except we have additional masking operators (and/or masks)
    allow_is_bidirectional_skip = True

And in _ignore_bidirectional_mask_sdpa, we branch logic on whether we compile or not:
transformers/src/transformers/masking_utils.py, lines 324 to 332 in 147b7aa:

    if (
        not is_tracing(padding_mask)
        and (padding_mask is None or padding_mask.all())
        # in this case we need to add special patterns to the mask so cannot be skipped otherwise
        and (local_attention_size is None or kv_length < local_attention_size)
    ):
        return True
    return False

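To make the divergence concrete, here is a minimal sketch of that condition in plain Python. It is a simplified stand-in, not the library code: the function name `ignore_mask_sketch` is hypothetical, the padding mask is a list of booleans instead of a tensor, and the boolean `tracing` parameter replaces the real `is_tracing(padding_mask)` call. It only shows that identical inputs take different branches; it does not reproduce the numerical difference itself.

```python
from typing import Optional, Sequence


def ignore_mask_sketch(
    padding_mask: Optional[Sequence[bool]],
    kv_length: int,
    local_attention_size: Optional[int],
    tracing: bool,
) -> bool:
    """Simplified stand-in for the quoted condition.

    `tracing` is an assumption standing in for is_tracing(padding_mask).
    """
    return (
        not tracing
        and (padding_mask is None or all(padding_mask))
        and (local_attention_size is None or kv_length < local_attention_size)
    )


# Same inputs, no padding tokens, no local attention window.
padding_mask = [True, True, True, True]

eager_skips = ignore_mask_sketch(
    padding_mask, kv_length=4, local_attention_size=None, tracing=False
)
compiled_skips = ignore_mask_sketch(
    padding_mask, kv_length=4, local_attention_size=None, tracing=True
)

# Eager skips the mask (the caller returns None); the traced path builds it.
print(eager_skips, compiled_skips)
```

With these inputs the eager call reports that the mask can be skipped while the traced call does not, so the two execution modes hand different mask arguments to the attention implementation.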
This issue was found on BERT but it seems like it would affect other models too.
We've also verified that removing the branching fixes the numerical difference. I'm creating this issue to ask about the best way forward. From the PR that added it, it looks like the branch was needed specifically for executorch, but the resulting difference in algorithm also affects every other API that falls under is_tracing. Can we restrict the check?
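One possible shape for a narrower check, as a sketch only: the `mode` parameter and the function name `ignore_mask_restricted` below are hypothetical and do not exist in transformers. The assumption is that only the export path needs to always build the mask, so eager and compiled execution would agree on the branch. Whether a data-dependent check such as `padding_mask.all()` is acceptable under compilation is a separate question (it may introduce graph breaks) and is not addressed here.

```python
from typing import Optional, Sequence


def ignore_mask_restricted(
    padding_mask: Optional[Sequence[bool]],
    kv_length: int,
    local_attention_size: Optional[int],
    mode: str,
) -> bool:
    """Hypothetical variant: only `mode == "export"` forces mask creation.

    `mode` is an illustrative assumption with values "eager", "compile",
    or "export"; it is not a real transformers argument.
    """
    if mode == "export":
        # Export always builds the mask, as the original branch intended.
        return False
    return (
        (padding_mask is None or all(padding_mask))
        and (local_attention_size is None or kv_length < local_attention_size)
    )


padding_mask = [True, True, True, True]
decisions = {
    mode: ignore_mask_restricted(padding_mask, 4, None, mode)
    for mode in ("eager", "compile", "export")
}
print(decisions)
```

Under this assumption the eager and compiled paths make the same skip decision, and only export keeps the always-build behaviour.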
Who can help?
@vasqu @ArthurZucker @Cyrilvallez
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I believe the description is enough, but I can provide a simpler repro on request.
Expected behavior
transformers users probably shouldn't run into large numeric differences when compiling, at least not by default