Conversation

@sayakpaul
Member

What does this PR do?

Refactors how we load the attention kernels from the Hub.

Currently, when a user sets the DIFFUSERS_ENABLE_HUB_KERNELS env var, we always download every supported kernel. Today that is only FA3, but there are ongoing PRs adding FA and SAGE support: #12387 and #12439. Once those land, we would download ALL of the kernels even when they're not required. This is not good.

This PR makes it so that only the relevant kernel gets downloaded, without breaking torch.compile compatibility (fullgraph compilation and no recompilation triggers).
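Roughly, the intended flow looks like this (the model checkpoint and backend string below are illustrative; the actual backend names come from AttentionBackendName in attention_dispatch.py):

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Selecting the hub-backed FA3 backend should fetch only the FA3 kernel from the Hub;
# other hub kernels (FA, SAGE, once supported) are not downloaded unless requested.
transformer.set_attention_backend("_flash_3_hub")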

Cc: @MekkCyber

@sayakpaul sayakpaul requested a review from DN6 October 13, 2025 10:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@MekkCyber left a comment


Very good to only load the invoked attention implementation! Thanks for adding this.

@DN6
Collaborator

DN6 commented Oct 23, 2025

I think the control flow here is a bit difficult to follow. We should aim to minimize the number of new objects/concepts introduced in this module since there's already quite a lot of routing going on in here.

My recommendations:

  1. Add a _set_attention_backend function to attention_dispatch.py that handles checking requirements and downloading kernels if they are needed. Call this function from modeling_utils when setting the backend:
def _set_attention_backend(backend: AttentionBackendName) -> None:
    _check_attention_backend_requirements(backend)
    _maybe_download_kernel_for_backend(backend)
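For illustration, the call site in modeling_utils could then look roughly like this (everything apart from _set_attention_backend is a simplified assumption about the existing method, not the actual diffusers code):

def set_attention_backend(self, backend: str) -> None:
    # Validate the user-supplied string against the known backend names.
    backend = AttentionBackendName(backend)
    # Check requirements and download the relevant hub kernel (if any) up front,
    # so nothing has to be downloaded from inside the dispatch function.
    _set_attention_backend(backend)
    # ... then record the chosen backend on the model's attention processors as usual.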

In attention_dispatch.py, let's create a _HubKernelConfig and a _HUB_KERNELS_REGISTRY:

from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class _HubKernelConfig:
    """Configuration for downloading and using a hub-based attention kernel."""

    repo_id: str
    function_attr: str
    revision: Optional[str] = None
    kernel_fn: Optional[Callable] = None


# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
    )
}
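If the FA and SAGE hub kernels from the PRs mentioned in the description land later, they would get their own entries. The enum members and repo ids below are placeholders, shown only to illustrate how the registry scales:

# Hypothetical future entries (names are illustrative, not taken from the linked PRs):
# _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_HUB] = _HubKernelConfig(
#     repo_id="kernels-community/flash-attn", function_attr="flash_attn_func"
# )
# _HUB_KERNELS_REGISTRY[AttentionBackendName._SAGE_HUB] = _HubKernelConfig(
#     repo_id="kernels-community/sage-attention", function_attr="sageattn"
# )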

_maybe_download_kernel_for_backend(backend) would download the kernel and set kernel_fn for a given backend if the backend is supported; otherwise it's a no-op.

Then, in your hub function, fetch the downloaded kernel from the registry:

    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
    out = func(
        q=query,
        ....

We shouldn't attempt kernel downloads from the dispatch function. It should already be downloaded/available beforehand.
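Putting it together, the hub dispatch path reduces to a lookup plus a call, roughly like this (the signature and keyword names are assumptions; the real kernel may also return extra tensors such as the LSE):

def _flash_attention_3_hub(query, key, value, is_causal=False, scale=None):
    # The kernel is expected to have been cached by _maybe_download_kernel_for_backend() already.
    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
    if func is None:
        raise RuntimeError("The FA3 hub kernel has not been downloaded. Set the attention backend first.")
    # Keyword names simply mirror the snippet above; match them to the actual kernel signature.
    out = func(q=query, k=key, v=value, softmax_scale=scale, causal=is_causal)
    return out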

@sayakpaul
Member Author

Okay but

  • Why would I maintain a separate _set_attention_backend() function and call it from set_attention_backend()? Does that not complicate things further? I would rather have _check_attention_backend_requirements(backend) and _maybe_download_kernel_for_backend(backend) called directly in set_attention_backend(). That's better, no?
  • How does _maybe_download_kernel_for_backend(backend) interact with the proposed registry? After downloading the kernel from the given config spec, do we set kernel_fn?

@DN6
Collaborator

DN6 commented Oct 23, 2025

Why would I maintain a separate _set_attention_backend() function and call it from set_attention_backend()? Does that not complicate things further? I would rather have _check_attention_backend_requirements(backend) and _maybe_download_kernel_for_backend(backend) called directly in set_attention_backend(). That's better, no?

This is also good 👍🏽

How does _maybe_download_kernel_for_backend(backend) interact with the proposed registry? After downloading the kernel from the given config spec, do we set kernel_fn?

Something like:

def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
    if backend not in _HUB_KERNELS_REGISTRY:
        return
    config = _HUB_KERNELS_REGISTRY[backend]

    # Kernel already downloaded and cached; nothing to do.
    if config.kernel_fn is not None:
        return

    from kernels import get_kernel

    kernel_module = get_kernel(config.repo_id, revision=config.revision)
    kernel_func = getattr(kernel_module, config.function_attr)

    # Cache the downloaded kernel function in the config object
    config.kernel_fn = kernel_func
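Usage-wise, repeated calls are then cheap because the downloaded function is cached on the config (the backend names below are just examples):

_maybe_download_kernel_for_backend(AttentionBackendName._FLASH_3_HUB)  # downloads and caches once
_maybe_download_kernel_for_backend(AttentionBackendName._FLASH_3_HUB)  # no-op, kernel_fn already set
_maybe_download_kernel_for_backend(AttentionBackendName.NATIVE)        # no-op, not a hub backend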

@sayakpaul
Member Author

@DN6 check now; your feedback should be addressed. I was able to completely get rid of kernels_utils.py as a consequence. I also decided to get rid of the DIFFUSERS_ENABLE_HUB_KERNELS environment variable.
