Skip to content

Conversation

@ParagEkbote
Copy link
Contributor

@ParagEkbote ParagEkbote commented Sep 11, 2025

What does this PR do?

As discussed in the issue, this PR adds support for kernels-community/flash-attn kernel. Could you please review?

Fixes #12308

Before submitting

Who can review?

@sayakpaul

@sayakpaul
Copy link
Member

Thanks for this PR. Could you update it with some code examples and results?

@ParagEkbote
Copy link
Contributor Author

ParagEkbote commented Sep 11, 2025

This is the test command, but unable to generate images.

import os
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

# Debug: Verify the env var is set
print(f"DIFFUSERS_ENABLE_HUB_KERNELS = {os.environ.get('DIFFUSERS_ENABLE_HUB_KERNELS')}")

import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Debug: Check if diffusers sees the env var
from diffusers.models.attention_dispatch import DIFFUSERS_ENABLE_HUB_KERNELS
print(f"Diffusers sees DIFFUSERS_ENABLE_HUB_KERNELS = {DIFFUSERS_ENABLE_HUB_KERNELS}")

# ✅ 3. Load pipeline with quantization
model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer"],
    ),
).to("cuda")

pipe.transformer.set_attention_backend("_flash_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")

@ParagEkbote
Copy link
Contributor Author

I'm having issues regarding some of the parameters with the following traceback:

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/diffusers/main.py", line 34, in <module>
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 944, in __call__
    noise_pred = self.transformer(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
    attention_outputs = self.attn(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
    hidden_states = dispatch_attention_fn(
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 304, in dispatch_attention_fn
    return backend_fn(**kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 765, in _flash_attention_hub
    out = flash_attn_func_hub(
TypeError: flash_attn_func() got an unexpected keyword argument 'alibi_slopes'

The same error occurs with dropout_p parameter as well. WDYT?

cc: @sayakpaul

@sayakpaul
Copy link
Member

@ParagEkbote I think we can close this PR in favor of #12387. You're more than welcome to test the PR and let us know of any feedback.

@ParagEkbote
Copy link
Contributor Author

@sayakpaul Thanks for letting me know and being a patient reviewer. Closing the PR..

@ParagEkbote ParagEkbote deleted the Add-FA2 branch September 27, 2025 15:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support flash-attn kernel support for non-Hopper GPUs

2 participants