-
Notifications
You must be signed in to change notification settings - Fork 30.7k
[CB
] Refactors the way we access paged
#41370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2527,7 +2527,7 @@ def _check_and_adjust_attn_implementation( | |
# If FA not installed, do not fail but use kernels instead | ||
if ( | ||
attn_implementation is not None | ||
and attn_implementation.startswith("flash_attention") | ||
and "flash" in attn_implementation | ||
and self._supports_flash_attn | ||
and not (is_flash_attn_2_available() or is_flash_attn_3_available()) | ||
and is_kernels_available() | ||
|
@@ -2635,8 +2635,6 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]): | |
else attn_implementation.get("", self.config._attn_implementation) | ||
) | ||
|
||
# At this point, the model was already instantiated, so instead of crashing on bad value, let's simply | ||
# warn the user that the requested value is not working | ||
if requested_implementation != self.config._attn_implementation: | ||
# In this case, raise | ||
if not self._can_set_attn_implementation(): | ||
|
@@ -5904,10 +5902,10 @@ class AttentionInterface(GeneralInterface): | |
"flash_attention_3": flash_attention_forward, | ||
"flash_attention_2": flash_attention_forward, | ||
"flex_attention": flex_attention_forward, | ||
"paged_attention": paged_attention_forward, | ||
"sdpa": sdpa_attention_forward, | ||
"sdpa_paged": sdpa_attention_paged_forward, | ||
"eager_paged": eager_paged_attention_forward, | ||
"paged|flash_attention2": paged_attention_forward, | ||
"paged|sdpa": sdpa_attention_paged_forward, | ||
"paged|eager": eager_paged_attention_forward, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not supported yet AFAIK There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok good to know 👍 ty |
||
} | ||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.