[core] Refactor hub attn kernels #12475
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very good to only load the invoked attention implementation! Thanks for adding this.
I think the control flow here is a bit difficult to follow. We should aim to minimize the number of new objects/concepts introduced in this module since there's already quite a lot of routing going on in here. My recommendations:

```python
def _set_attention_backend(backend: AttentionBackendName) -> None:
    _check_attention_backend_requirements(backend)
    _maybe_download_kernel_for_backend(backend)
```

In attention dispatch, let's create a `@dataclass`:

```python
@dataclass
class _HubKernelConfig:
    """Configuration for downloading and using a hub-based attention kernel."""

    repo_id: str
    function_attr: str
    revision: Optional[str] = None
    kernel_fn: Optional[Callable] = None


# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
    )
}
```

Then in your hub function, fetch the downloaded kernel from the registry:

```python
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
    q=query,
    ...
)
```

We shouldn't attempt kernel downloads from the dispatch function. It should already be downloaded/available beforehand.
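For readers following along, here is a minimal sketch of what the dispatch-side lookup could look like under this suggestion; the function name `_flash_attention_3_hub` is hypothetical and the snippet assumes the `_HubKernelConfig` registry defined above, not the PR's actual code.

```python
# Illustrative sketch only: read the cached kernel from the registry at call
# time. Assumes `_HUB_KERNELS_REGISTRY` and `AttentionBackendName` from above.
import torch


def _flash_attention_3_hub(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
    if config.kernel_fn is None:
        raise RuntimeError(
            "The FA3 hub kernel has not been downloaded yet; select the backend "
            "via `_set_attention_backend` before dispatching attention."
        )
    # No download or network access happens here, so the op stays torch.compile friendly.
    return config.kernel_fn(q=query, k=key, v=value)
```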
Okay but
This is also good 👍🏽

Something like:

```python
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
    if backend not in _HUB_KERNELS_REGISTRY:
        return

    config = _HUB_KERNELS_REGISTRY[backend]
    if config.kernel_fn is not None:
        return

    try:
        from kernels import get_kernel

        kernel_module = get_kernel(config.repo_id, revision=config.revision)
        kernel_func = getattr(kernel_module, config.function_attr)
        # Cache the downloaded kernel function in the config object
        config.kernel_fn = kernel_func
    except Exception:
        # Surface download/import failures to the caller
        raise
```
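A self-contained toy version of that download-and-cache pattern (no `diffusers` or `kernels` imports; all names are hypothetical stand-ins) may help show why repeated backend selections never trigger a second download:

```python
# Toy illustration of caching the resolved kernel on the config object.
from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class _ToyKernelConfig:
    repo_id: str
    function_attr: str
    kernel_fn: Optional[Callable] = None


_TOY_REGISTRY: Dict[str, _ToyKernelConfig] = {
    "_flash_3_hub": _ToyKernelConfig(repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func"),
}


def _maybe_download(backend: str) -> None:
    config = _TOY_REGISTRY.get(backend)
    if config is None or config.kernel_fn is not None:
        return  # unknown backend, or kernel already cached: nothing to do
    # Stand-in for `kernels.get_kernel(...)` + `getattr(...)`.
    config.kernel_fn = lambda **kwargs: "kernel output"


_maybe_download("_flash_3_hub")
_maybe_download("_flash_3_hub")  # no-op: the kernel function is already cached
print(_TOY_REGISTRY["_flash_3_hub"].kernel_fn())  # -> "kernel output"
```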
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: dn6 <[email protected]>
@DN6 check now. Your feedback should have been addressed. I was able to completely get rid of
What does this PR do?
Refactors how we load the attention kernels from the Hub.
Currently, when a user specifies the `DIFFUSERS_ENABLE_HUB_KERNELS` env var, we always download the supported kernel. Right now that is only FA3, but we have ongoing PRs that add FA and SAGE support: #12387 and #12439. With those merged, we would download ALL of the kernels even when they're not required. This is not good.

This PR makes it so that only the relevant kernel gets downloaded, without breaking `torch.compile` compliance (fullgraph and no recompilation triggers).

Cc: @MekkCyber
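As a rough usage illustration (a sketch built from names in this thread, not verified documentation): assuming the private helpers discussed above live in `diffusers.models.attention_dispatch`, the opt-in flow would look roughly like this, and only the selected backend's kernel would be fetched.

```python
# Hypothetical usage sketch; the import path and accepted env-var value are
# assumptions drawn from this thread, not checked against the merged code.
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "1"  # opt in to hub kernels

from diffusers.models.attention_dispatch import AttentionBackendName, _set_attention_backend

# Only the FA3 kernel (kernels-community/flash-attn3) gets downloaded here;
# the FA and SAGE kernels from #12387 / #12439 would not be touched.
_set_attention_backend(AttentionBackendName._FLASH_3_HUB)
```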