2 changes: 1 addition & 1 deletion docs/source/en/api/models/chroma_transformer.md
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.

# ChromaTransformer2DModel

A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)

## ChromaTransformer2DModel

13 changes: 7 additions & 6 deletions docs/source/en/api/pipelines/chroma.md
@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.

Chroma is a text-to-image generation model based on Flux.

Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
Original model checkpoints for Chroma can be found here:
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)

> [!TIP]
> Chroma can use all the same optimizations as Flux.

## Inference

The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).

```python
import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = [
@@ -63,10 +64,10 @@ Then run the following example
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline

model_id = "lodestones/Chroma"
model_id = "lodestones/Chroma1-HD"
dtype = torch.bfloat16

transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)

pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
2 changes: 2 additions & 0 deletions docs/source/en/optimization/attention_backends.md
@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
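For context on the new `aiter` rows added to the tables above, here is a short illustrative sketch of how the backend would be selected (not part of the diff; it assumes a ROCm GPU, `aiter>=0.1.5`, and the `set_attention_backend` helper documented on this page):

```python
import torch
from diffusers import ChromaPipeline

# Sketch: route the transformer's attention through the `aiter` backend on a ROCm GPU.
# Assumes aiter>=0.1.5 is installed and the `set_attention_backend` helper is available.
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe.transformer.set_attention_backend("aiter")

image = pipe("a photo of an astronaut riding a horse on the moon").images[0]
image.save("chroma_aiter.png")
```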
60 changes: 60 additions & 0 deletions src/diffusers/models/attention_dispatch.py
@@ -27,6 +27,8 @@

from ..utils import (
get_logger,
is_aiter_available,
is_aiter_version,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
@@ -47,13 +49,15 @@
from ._modeling_parallel import ParallelConfig

_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"

_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -78,6 +82,12 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None


if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.

# `aiter`
AITER = "aiter"

# PyTorch native
FLEX = "flex"
NATIVE = "native"
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
)

elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
AttentionBackendName.AITER,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=True,
)
else:
out = aiter_flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_lse=return_lse,
)
if return_lse:
out, lse, *_ = out

return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
AttentionBackendName.FLEX,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_chroma.py
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
"""
The Transformer model introduced in Flux, modified for Chroma.

Reference: https://huggingface.co/lodestones/Chroma
Reference: https://huggingface.co/lodestones/Chroma1-HD

Args:
patch_size (`int`, defaults to `1`):
25 changes: 14 additions & 11 deletions src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -53,8 +53,8 @@
>>> import torch
>>> from diffusers import ChromaPipeline

>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> model_id = "lodestones/Chroma1-HD"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
@@ -158,7 +158,7 @@ class ChromaPipeline(
r"""
The Chroma pipeline for text-to-image generation.

Reference: https://huggingface.co/lodestones/Chroma/
Reference: https://huggingface.co/lodestones/Chroma1-HD/

Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -233,20 +233,23 @@ def _get_t5_prompt_embeds(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
tokenizer_mask = text_inputs.attention_mask

# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
tokenizer_mask_device = tokenizer_mask.to(device)

# unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
text_input_ids.to(device),
output_hidden_states=False,
attention_mask=tokenizer_mask_device,
)[0]

dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
attention_mask = attention_mask.to(device=device)

# for the text tokens, Chroma requires that all padding tokens except the first one are masked out during the forward pass through the transformer
seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

_, seq_len, _ = prompt_embeds.shape

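To make the masking rule in `_get_t5_prompt_embeds` above concrete, here is a small self-contained sketch (illustrative only, not part of the change) showing that the transformer-side mask keeps the real tokens plus exactly one padding position:

```python
import torch

# Toy tokenizer mask: 3 real tokens, padded to length 6.
tokenizer_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])

seq_lengths = tokenizer_mask.sum(dim=1)                           # tensor([3])
mask_indices = torch.arange(tokenizer_mask.size(1)).unsqueeze(0)  # [[0, 1, 2, 3, 4, 5]]

# Positions 0..seq_length (inclusive) stay unmasked: the real tokens plus the
# first padding token, matching the comment above.
attention_mask = mask_indices <= seq_lengths.unsqueeze(1)
print(attention_mask)  # tensor([[ True,  True,  True,  True, False, False]])
```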
23 changes: 12 additions & 11 deletions src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -53,8 +53,8 @@
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline

>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> model_id = "lodestones/Chroma1-HD"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... model_id,
... transformer=transformer,
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
r"""
The Chroma pipeline for image-to-image generation.

Reference: https://huggingface.co/lodestones/Chroma/
Reference: https://huggingface.co/lodestones/Chroma1-HD/

Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -247,20 +247,21 @@ def _get_t5_prompt_embeds(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.clone()
tokenizer_mask = text_inputs.attention_mask

# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
tokenizer_mask_device = tokenizer_mask.to(device)

prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
text_input_ids.to(device),
output_hidden_states=False,
attention_mask=tokenizer_mask_device,
)[0]

dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
attention_mask = attention_mask.to(dtype=dtype, device=device)

seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

_, seq_len, _ = prompt_embeds.shape

2 changes: 2 additions & 0 deletions src/diffusers/utils/__init__.py
@@ -64,6 +64,8 @@
get_objects_from_module,
is_accelerate_available,
is_accelerate_version,
is_aiter_available,
is_aiter_version,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
21 changes: 21 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)

Expand Down Expand Up @@ -406,6 +407,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available


def is_aiter_available():
return _aiter_available


def is_kornia_available():
return _kornia_available

@@ -911,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str):
return compare_versions(parse(_flash_attn_version), operation, version)


@cache
def is_aiter_version(operation: str, version: str):
"""
Compares the current aiter version to a given reference with an operation.

Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _aiter_available:
return False
return compare_versions(parse(_aiter_version), operation, version)


def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
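A minimal usage sketch of the two helpers added above, mirroring the `_CAN_USE_AITER_ATTN` guard in `attention_dispatch.py` (illustrative only):

```python
from diffusers.utils import is_aiter_available, is_aiter_version

# Only enable the aiter attention path when the package is importable and
# recent enough; otherwise fall back to another backend.
if is_aiter_available() and is_aiter_version(">=", "0.1.5"):
    print("aiter attention backend can be used")
else:
    print("aiter unavailable or too old; use a different attention backend")
```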
13 changes: 13 additions & 0 deletions tests/others/test_attention_backends.py
@@ -14,6 +14,10 @@

Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).

Tests for the aiter backend were conducted, and their slices collected, on an MI355X with a torch
nightly from 2025-09-25 (ad2f7315ca66b42497047bb7951f696b50f1e81b) and aiter 0.1.5.post4.dev20+ga25e55e79.
"""

import os
@@ -44,6 +48,10 @@
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
)
]

COMPILE_CASES = [
@@ -63,6 +71,11 @@
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
True,
),
(
"aiter",
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
True,
)
]
# fmt: on
