Diverging attention kernels due to allow_is_bidirectional_skip branching on torch.compile #44188

@xmfan

Description

System Info

Hi, while updating the PyTorch transformers pin to v5.2.0, our regression tests caught a numerics issue between eager and compiled execution. The difference is very substantial (3.3, versus the typical accepted difference on the order of 1e-4). Digging into it (pytorch/pytorch#175274 (comment)), we found the cause to be in these lines (added in #41265):

if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask, kv_length, local_size):
    return None

We set allow_is_bidirectional_skip=True in a few places:

# Allow skipping the mask creation except we have additional masking operators (and/or masks)
allow_is_bidirectional_skip = True

And in _ignore_bidirectional_mask_sdpa, the logic branches on whether we are tracing (for example under torch.compile) or not:

if (
    not is_tracing(padding_mask)
    and (padding_mask is None or padding_mask.all())
    # in this case we need to add special patterns to the mask so cannot be skipped otherwise
    and (local_attention_size is None or kv_length < local_attention_size)
):
    return True
return False
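
To make the divergence concrete, here is a simplified pure-Python sketch of the check above. It is not the transformers implementation: the `tracing` flag stands in for `is_tracing(padding_mask)`, the padding mask is a plain list of booleans rather than a tensor, and the name `ignore_mask_sketch` is hypothetical.

```python
from typing import Optional, Sequence


def ignore_mask_sketch(
    padding_mask: Optional[Sequence[bool]],
    kv_length: int,
    local_attention_size: Optional[int],
    tracing: bool,
) -> bool:
    """Return True when mask creation would be skipped (illustrative only)."""
    # Under tracing the skip is never taken, so a mask is always built.
    if tracing:
        return False
    # Stand-in for `padding_mask is None or padding_mask.all()`.
    no_padding = padding_mask is None or all(padding_mask)
    # Stand-in for the local attention window condition.
    fits_window = local_attention_size is None or kv_length < local_attention_size
    return no_padding and fits_window


# Identical inputs, different decision depending only on the tracing flag.
mask = [True, True, True, True]
eager_skips = ignore_mask_sketch(mask, 4, None, tracing=False)
compiled_skips = ignore_mask_sketch(mask, 4, None, tracing=True)
print(eager_skips, compiled_skips)
```

With an all-True padding mask, the eager call skips the mask while the traced call does not, so the two paths hand different arguments to the attention kernel.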

This issue was found on BERT but it seems like it would affect other models too.

We've also verified that removing the branching fixes the numerical difference. I'm creating this issue to ask about the best way forward here. From the PR that added it, it looks like this was necessary specifically for executorch, but the algorithm difference also affects every other API that falls under is_tracing. Can we restrict the check?
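
One possible shape for such a restriction, sketched in plain Python. This is an assumption about how a narrower check might look, not a proposed patch: the `exporting` parameter and the function name `ignore_mask_restricted` are hypothetical, and the padding mask is again a plain list of booleans. The sketch also does not address whether the data-dependent `padding_mask.all()` check is acceptable under torch.compile.

```python
from typing import Optional, Sequence


def ignore_mask_restricted(
    padding_mask: Optional[Sequence[bool]],
    kv_length: int,
    local_attention_size: Optional[int],
    exporting: bool,
) -> bool:
    """Skip decision that only changes on the export path (hypothetical)."""
    # Only the export path (the executorch use case mentioned above)
    # forces mask creation; other traced paths behave like eager.
    if exporting:
        return False
    no_padding = padding_mask is None or all(padding_mask)
    fits_window = local_attention_size is None or kv_length < local_attention_size
    return no_padding and fits_window


mask = [True, True, True]
# Eager and compiled (non-export) calls would now share the same decision.
non_export_decision = ignore_mask_restricted(mask, 3, None, exporting=False)
# The export path still always builds the mask.
export_decision = ignore_mask_restricted(mask, 3, None, exporting=True)
print(non_export_decision, export_decision)
```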

Who can help?

@vasqu @ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I believe the description is enough, but I can provide a simpler repro on request

Expected behavior

transformers users probably shouldn't run into large numeric differences between eager and compiled execution, at least not by default.
