Commit c9b6296

committed
add a lightweight test suite for attention backends.
1 parent 310fdaf commit c9b6296

File tree

tests/others/test_attention_backends.py

1 file changed: 121 additions, 0 deletions

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
"""
This test suite currently exists for the maintainers and is not run in our CI.

Once attention backends become more mature, we can consider including it in our CI.

To run this test suite:

```
export RUN_ATTENTION_BACKEND_TESTS=yes
export DIFFUSERS_ENABLE_HUB_KERNELS=yes

pytest tests/others/test_attention_backends.py
```
"""

import os

import pytest
import torch


# Skip the whole module unless the maintainer opts in via the environment
# variable documented above.
pytestmark = pytest.mark.skipif(
    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") != "yes", reason="Feature not mature enough."
)

from torch._dynamo import config as dynamo_config  # noqa: E402

from diffusers import FluxPipeline  # noqa: E402

# Shorthand for pytest's parametrization decorator.
parameterize = pytest.mark.parametrize


# (backend_name, expected_slice); the slices are placeholders until reference
# values are recorded.
FORWARD_CASES = [
    ("flash_hub", None),
    ("_flash_3_hub", None),
    ("native", None),
    ("_native_cudnn", None),
]

# (backend_name, expected_slice, error_on_recompile)
COMPILE_CASES = [
    ("flash_hub", None, True),
    ("_flash_3_hub", None, True),
    ("native", None, True),
    ("_native_cudnn", None, True),
]

# The prompt is passed explicitly in the tests below.
INFER_KW = {
    "height": 256,
    "width": 256,
    "num_inference_steps": 2,
    "guidance_scale": 3.5,
    "max_sequence_length": 128,
    "output_type": "pt",
}


def _backend_is_probably_supported(pipe, name: str) -> bool:
    """Try to activate the backend; report False when this environment lacks it."""
    try:
        pipe.transformer.set_attention_backend(name)
        return True
    except (NotImplementedError, RuntimeError, ValueError):
        return False


def _check_if_slices_match(output, expected_slice):
    # The case tables above still carry None placeholders for expected slices.
    if expected_slice is None:
        pytest.skip("No reference slice has been recorded for this backend yet.")
    img = output.images
    generated_slice = img.flatten()
    generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
    assert torch.allclose(generated_slice, expected_slice, atol=1e-4)


@pytest.fixture(scope="session")
def device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for these tests.")
    return torch.device("cuda:0")


@pytest.fixture(scope="session")
def pipe(device):
    torch.set_grad_enabled(False)
    model_id = "black-forest-labs/FLUX.1-dev"
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
    pipe.set_progress_bar_config(disable=True)
    pipe.transformer.eval()
    return pipe


@parameterize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
    if not _backend_is_probably_supported(pipe, backend_name):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    out = pipe(
        "a tiny toy cat in a box",
        **INFER_KW,
        generator=torch.manual_seed(0),
    )
    _check_if_slices_match(out, expected_slice)


@parameterize(
    "backend_name,expected_slice,error_on_recompile",
    COMPILE_CASES,
    ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
    if not _backend_is_probably_supported(pipe, backend_name):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    pipe.transformer.compile(fullgraph=True)
    with dynamo_config.patch(error_on_recompile=bool(error_on_recompile)):
        out = pipe(
            "a tiny toy cat in a box",
            **INFER_KW,
            generator=torch.manual_seed(0),
        )

    _check_if_slices_match(out, expected_slice)
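For context, here is a minimal sketch of the flow these tests exercise, outside pytest. It assumes the same FLUX.1-dev checkpoint, a CUDA device, and that the chosen backend is available in the environment:

```
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda:0")

# Swap the attention implementation, then optionally compile the transformer.
pipe.transformer.set_attention_backend("_native_cudnn")
pipe.transformer.compile(fullgraph=True)

image = pipe(
    "a tiny toy cat in a box",
    height=256,
    width=256,
    num_inference_steps=2,
    guidance_scale=3.5,
    max_sequence_length=128,
    generator=torch.manual_seed(0),
).images[0]
```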

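Since every `expected_slice` in the case tables is still `None`, a maintainer has to record reference values once per backend. A possible sketch, mirroring the 16-value slice extraction in `_check_if_slices_match` (the helper name `record_slice` is hypothetical, not part of this diff):

```
import torch

def record_slice(output):
    # Same extraction as _check_if_slices_match: first and last 8 values
    # of the flattened image tensor (requires output_type="pt").
    flat = output.images.flatten()
    slice_ = torch.cat([flat[:8], flat[-8:]])
    # Round for readability; paste the result into FORWARD_CASES/COMPILE_CASES.
    print([round(v, 4) for v in slice_.float().cpu().tolist()])
```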