
Conversation

@lauri9 (Contributor) commented Oct 27, 2025

What does this PR do?

AITER is AMD’s centralized repository of high-performance AI operators, such as attention kernels, for AMD ROCm-enabled accelerators. This PR adds support for FlashAttention through AITER by introducing a new attention backend.

Test code for Flux inference below. Requires aiter>=0.15.0 and a supported ROCm-enabled accelerator.

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"

# Load the transformer separately so its attention backend can be switched to AITER.
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
)
transformer.set_attention_backend("aiter")

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")

prompt = "A cat holding a sign that says 'hello world'"

image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")
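
diffusers also exposes attention_backend as a context manager for scoped backend switching, as an alternative to set_attention_backend. A minimal sketch, reusing the pipe and prompt from above:

from diffusers import attention_backend

# Only the forward passes inside the `with` block use the AITER backend;
# the previously configured backend applies again on exit.
with attention_backend("aiter"):
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]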

As a follow-up to this PR, we are interested in eventually also enabling AITER backend support for context parallelism across multiple devices as the feature matures.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @sayakpaul @DN6 for review and any comments

@lauri9 force-pushed the add-aiter-backend branch from 7482105 to 89903c3 on October 27, 2025 09:52
@sayakpaul (Member) commented:

Thanks for this PR!

transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda")

Pardon my unwisdom, but for AMD devices, does this string not change? 👀

@sayakpaul (Member) left a review:

Very cool PR!

| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
@sayakpaul (Member):

Not related to this PR.

Do you think it might be possible to package the aiter kernels with kernels? If so, we could then also support it through the kernel hub, like we do for FA3 and others (FA2 and SAGE).

Cc: @danieldk

@lauri9 (Contributor, Author) replied:

That's a great project and would also make for a good follow-up, though it's perhaps best handled in a separate issue/PR? If I understand correctly, the kernel would first need to make it into kernels before it can be integrated into diffusers.
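
For context, once a kernel is published on the Hub, the kernels library can fetch and load it at runtime. A minimal sketch using the activation kernel from the kernels README; an AITER kernel would need its own published Hub repo, which is hypothetical at this point:

import torch
from kernels import get_kernel

# Download a pre-built kernel from the Hugging Face Hub.
activation = get_kernel("kernels-community/activation")

x = torch.randn((16, 16), dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
activation.gelu_fast(out, x)  # call into the downloaded kernel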

@sayakpaul (Member):

100% not related.

@lauri9 (Contributor, Author) commented Oct 27, 2025

Pardon my unwisdom, but for AMD devices, does this string not change? 👀

Existing PyTorch code that uses torch.cuda functions (e.g., tensor.to('cuda')) will generally work directly with ROCm; see the PyTorch docs on HIP/ROCm semantics. To set up the environment, you can build PyTorch from source or use a ROCm Docker image, among other options; further details are in the ROCm docs.
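
For example, a quick sanity check on a ROCm build of PyTorch; a minimal sketch, where the exact version string is illustrative:

import torch

# On ROCm builds, the torch.cuda API is backed by HIP, so "cuda" device strings work unchanged.
print(torch.cuda.is_available())      # True on a supported AMD accelerator
print(torch.version.hip)              # e.g. a "6.x" string on ROCm builds; None on CUDA builds
print(torch.cuda.get_device_name(0))  # reports the AMD GPU name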

Anecdotally, over the past few months of running diffusers code on ROCm, I haven't hit any compatibility issues with PyTorch.

@HuggingFaceDocBuilderDev (bot) commented:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member):

@bot /style

@github-actions (bot) commented Oct 27, 2025

Style bot fixed some files and pushed the changes.

@sayakpaul sayakpaul merged commit 250f5cb into huggingface:main Oct 27, 2025
11 checks passed
@sayakpaul (Member):

Let's go! Thanks a lot for adding this!
