Skip to content

Conversation

@sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Oct 6, 2025

What does this PR do?

Code to test:

from diffusers import DiffusionPipeline 
import torch 

repo_id = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attention_backend("sage_hub")

image = pipe(
    prompt="a dog sitting by the sea, waiting for its companion to come",
    guidance_scale=3.5,
    num_inference_steps=30,
    max_sequence_length=512,
    generator=torch.manual_seed(0)
).images[0]
image.save("sage_flux.png")

Result:
image


Notes

  1. It would be nice to get torch.compile support when using sage attention like we have for flash and flash 3. Currently, this fails.
Code to test
from diffusers import DiffusionPipeline 
import torch 

repo_id = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attention_backend("sage_hub")
pipe.transformer.compile_repeated_blocks(fullgraph=True)

with (
    torch._inductor.utils.fresh_inductor_cache(),
    torch._dynamo.config.patch(error_on_recompile=True),
):
    image = pipe(
        prompt="a dog sitting by the sea, waiting for its companion to come",
        guidance_scale=3.5,
        num_inference_steps=30,
        max_sequence_length=512,
        generator=torch.manual_seed(0)
    ).images[0]
image.save("sage_flux.png")

Error: https://pastebin.com/3HS6HNzR

  1. We have other sageattn variants (see here), which would be cool to expose from the Hub kernel.

Cc: @MekkCyber

@sayakpaul sayakpaul added the performance Anything related to performance improvements, profiling and benchmarking label Oct 6, 2025
@sayakpaul sayakpaul requested a review from DN6 October 6, 2025 05:48
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link

@MekkCyber MekkCyber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool ! I will try to look into the torch compile compatibility, but for the other variants, they are the same as sageattn, what i mean is sageattn is just a wrapper that dispatches to the correct kernel depending on the hardware used : https://github.com/thu-ml/SageAttention/blob/main/sageattention/core.py#L140

@sayakpaul
Copy link
Member Author

they are the same as sageattn, what i mean is sageattn is just a wrapper that dispatches to the correct kernel depending on the hardware used :

So, you mean we shouldn't have to have different dispatched functions like this?

_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"

@MekkCyber
Copy link

Yes I think we don't need that because it depends on the hardware. For example if a user chooses : _sage_qk_int8_pv_fp8_cuda on A100 (8.0) it will fail, because this function is only supported and compiled for 8.9 gpus

@sayakpaul sayakpaul marked this pull request as draft October 7, 2025 13:38
Comment on lines -165 to -167
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see their usage, hence removed.

@woct0rdho
Copy link

woct0rdho commented Oct 10, 2025

FYI, I've ported SageAttention to Python stable ABI (ABI3) and libtorch stable ABI, which should simplify building for HF Kernels:
woct0rdho/SageAttention@main...abi3_stable

There are also some refactors in my main branch to simplify building. If someone can maintain the build system, then I no longer need to maintain my repo :)

@sayakpaul sayakpaul marked this pull request as ready for review October 13, 2025 10:27
@sayakpaul
Copy link
Member Author

This PR is ready to be reviewed now. As discussed with @MekkCyber over DMs, we're disabling torch.compile support now as the compile branch leads to garbage outputs.

In order for us to support it with torch.compile, some kind of lightweight dispatcher might be needed. d344134 added support for that but I have removed it for now for the above-mentioned purpos. Those changes are still safe in sage-kernels-dispatch branch.

I think we should be good with the PR.

Cc: @MekkCyber @DN6

Comment on lines -86 to +92
from ..utils.kernels_utils import _get_fa3_from_hub
from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub

flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func

sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
sage_attn_func_hub = sage_interface_hub.sageattn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a huge fan of downloading all kernels if the env variable is set, since it's downloading stuff without explicit user consent. I think we need to rethink this part a bit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly. This is why #12475. Let's get that reviewed and merged first as it will unblock this PR and also #12387

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

performance Anything related to performance improvements, profiling and benchmarking

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants