Skip to content

Commit ccdd96c

Browse files
authored
[tests] Test attention backends (huggingface#12388)
* add a lightweight test suite for attention backends. * up * up * Apply suggestions from code review * formatting
1 parent 4c723d8 commit ccdd96c

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
This test suite exists for the maintainers currently. It's not run in our CI at the moment.
3+
4+
Once attention backends become more mature, we can consider including this in our CI.
5+
6+
To run this test suite:
7+
8+
```bash
9+
export RUN_ATTENTION_BACKEND_TESTS=yes
10+
export DIFFUSERS_ENABLE_HUB_KERNELS=yes
11+
12+
pytest tests/others/test_attention_backends.py
13+
```
14+
15+
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
16+
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
17+
"""
18+
19+
import os
20+
21+
import pytest
22+
import torch
23+
24+
25+
# Opt-in switch: this suite is maintainer-only and not wired into CI, so the
# whole module is skipped unless the env var is explicitly set (see docstring).
_run_flag = os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false")
pytestmark = pytest.mark.skipif(_run_flag == "false", reason="Feature not mature enough.")

# Deliberately imported after the skip marker (hence the E402 suppressions) so
# collection stays cheap when the suite is disabled.
from diffusers import FluxPipeline  # noqa: E402
from diffusers.utils import is_torch_version  # noqa: E402
30+
31+
32+
# fmt: off
# Each forward case: (backend name, expected 16-element output slice or None).
# A None slice means no reference has been recorded for that backend yet; the
# test then only verifies the forward pass runs. Slices were captured on an
# H100 (see module docstring for the exact torch versions).
FORWARD_CASES = [
    ("flash_hub", None),
    (
        "_flash_3_hub",
        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
    ),
    (
        "native",
        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
    ),
    (
        "_native_cudnn",
        torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
    ),
]

# Each compile case: (backend name, expected slice or None, error_on_recompile).
# error_on_recompile=True makes dynamo raise if the compiled transformer ever
# recompiles during inference.
COMPILE_CASES = [
    ("flash_hub", None, True),
    (
        "_flash_3_hub",
        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
        True,
    ),
    (
        "native",
        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
        True,
    ),
    (
        "_native_cudnn",
        torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
        True,
    ),
]
# fmt: on

# Shared inference kwargs: a small 256x256 / 2-step run so each case stays fast.
INFER_KW = {
    "prompt": "dance doggo dance",
    "height": 256,
    "width": 256,
    "num_inference_steps": 2,
    "guidance_scale": 3.5,
    "max_sequence_length": 128,
    "output_type": "pt",
}
78+
79+
80+
def _backend_is_probably_supported(pipe, name: str):
81+
try:
82+
pipe.transformer.set_attention_backend(name)
83+
return pipe, True
84+
except Exception:
85+
return False
86+
87+
88+
def _check_if_slices_match(output, expected_slice):
89+
img = output.images.detach().cpu()
90+
generated_slice = img.flatten()
91+
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
92+
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
93+
94+
95+
@pytest.fixture(scope="session")
def device():
    """Session-wide CUDA device; skips every test when CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    pytest.skip("CUDA is required for these tests.")
100+
101+
102+
@pytest.fixture(scope="session")
def pipe(device):
    """Load FLUX.1-dev once per session, in bf16, on the CUDA device."""
    flux = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to(device)
    # Progress bars just add noise under pytest.
    flux.set_progress_bar_config(disable=True)
    return flux
108+
109+
110+
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
    """Run a short eager-mode generation per backend and compare output slices."""
    probed = _backend_is_probably_supported(pipe, backend_name)
    if isinstance(probed, bool):
        # The probe returns the bare bool False only on failure.
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    modified_pipe, _ = probed
    result = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
    _check_if_slices_match(result, expected_slice)
119+
120+
121+
@pytest.mark.parametrize(
    "backend_name,expected_slice,error_on_recompile",
    COMPILE_CASES,
    ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
    """Same as test_forward, but with the transformer compiled fullgraph.

    `error_on_recompile=True` turns any dynamo recompilation during inference
    into a hard failure. Recompile-free runs for the "native" variants require
    torch >= 2.9, hence the version gate below (see module docstring).
    """
    if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")

    probed = _backend_is_probably_supported(pipe, backend_name)
    if isinstance(probed, bool):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    modified_pipe, _ = probed
    modified_pipe.transformer.compile(fullgraph=True)

    # Fresh compiler state and inductor cache so cases don't contaminate each other.
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(
        error_on_recompile=error_on_recompile
    ):
        result = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))

    _check_if_slices_match(result, expected_slice)

0 commit comments

Comments
 (0)