39 changes: 39 additions & 0 deletions docs/source/en/optimization/attention_backends.md
@@ -81,6 +81,45 @@ with attention_backend("_flash_3_hub"):
> [!TIP]
> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
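
A minimal sketch of combining the two, assuming a CUDA machine with FlashAttention installed; the checkpoint and the compiled module are illustrative, not prescriptive.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_dispatch import attention_backend

# Example checkpoint; any diffusers pipeline works the same way.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Compile the denoiser once; the attention backend is selected at dispatch time.
pipeline.unet = torch.compile(pipeline.unet)

with attention_backend("flash"):
    image = pipeline("a photo of an astronaut riding a horse").images[0]
```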

## Checks

The attention dispatcher includes debugging checks that catch common errors before they cause problems.

1. Device checks verify that query, key, and value tensors live on the same device.
2. Data type checks confirm tensors have matching dtypes and use either bfloat16 or float16.
3. Shape checks validate tensor dimensions and prevent mixing attention masks with causal flags.

Enable these checks by setting the `DIFFUSERS_ATTN_CHECKS` environment variable before launching your script. Checks add overhead to every attention operation, so they're disabled by default.

```bash
export DIFFUSERS_ATTN_CHECKS=yes
```
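
The flag is read once when diffusers is imported (it is evaluated at module level in `diffusers.utils.constants`), so it has to be set before the first diffusers import. If you prefer to set it from Python rather than the shell, a minimal sketch:

```py
import os

# Must be set before the first `import diffusers`, because the flag is read at import time.
os.environ["DIFFUSERS_ATTN_CHECKS"] = "yes"

import diffusers  # noqa: E402
```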

The checks now run before every attention operation.

```py
import torch

# `attention_backend` and `dispatch_attention_fn` are assumed to live in the
# same module as the registry shown further below.
from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")

try:
with attention_backend("flash"):
output = dispatch_attention_fn(query, key, value)
print("✓ Flash Attention works with checks enabled")
except Exception as e:
print(f"✗ Flash Attention failed: {e}")
```

Comment on lines +107 to +112:

**Member:** I think this needs to be set as `export DIFFUSERS_ATTN_CHECKS=yes` before running any execution because of how `DIFFUSERS_ATTN_CHECKS` is used:

```py
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS

DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
```

**Member:** Yeah, `DIFFUSERS_ATTN_CHECKS` is determined at the module level, so I am not sure setting it that way would be effective.

**Member Author:** Ah ok, updated to use `export DIFFUSERS_ATTN_CHECKS=yes` then!

You can also configure the registry directly.

```py
from diffusers.models.attention_dispatch import _AttentionBackendRegistry

_AttentionBackendRegistry._checks_enabled = True
```
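
With the checks on, a call that violates one of them should fail fast. A minimal sketch of the device check firing, assuming a CUDA machine; the exact error type and message come from the dispatcher and may differ.

```py
import torch
from diffusers.models.attention_dispatch import _AttentionBackendRegistry, dispatch_attention_fn

_AttentionBackendRegistry._checks_enabled = True

# The key deliberately lives on a different device, so the device check should reject the call.
query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cpu")
value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")

try:
    dispatch_attention_fn(query, key, value)
except Exception as e:
    print(f"✗ check failed as expected: {e}")
```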

## Available backends

Refer to the table below for a complete list of available attention backends and their variants.