- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.4k
[tests] Test attention backends #12388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|  | ||
|  | ||
| FORWARD_CASES = [ | ||
| ("flash_hub", None), | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will add this once #12387 is merged.
| ] | ||
|  | ||
| COMPILE_CASES = [ | ||
| ("flash_hub", None, True), | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will add add the test slices after #12387 is merged.
| Hi! I'm implementing a new attention backend and in preparation for that I tried the unit tests from this PR. I was working in an environment with a different torch version from nightly  There were differences also in the eager mode tests. Is it expected that the values diverge between versions? Could there be a better way to test than comparing numerical accuracy if the values are expected to vary between versions? | 
| The hardware might matter too. I ran the tests on an H100, actually. The CUDA version could matter, too. I will try to swap out exact assertion with cosine similarity-based checks which are a little more reliable and robust. | 
| ) | ||
| def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile): | ||
| if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"): | ||
| pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.") | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.") | |
| pytest.xfail(f"Test with {backend_name} is compatible with a higher version of torch.") | 
nit: typo?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM :)
What does this PR do?
Adds a lightweight test suite for popular attention backends. By default this won't be run on our CI.