
Commit c51e5f2 ("up")

1 parent: b62fec0
File tree: 1 file changed (+177, -35 lines)


tests/others/test_attention_backends.py

Lines changed: 177 additions & 35 deletions
@@ -5,12 +5,15 @@
 
 To run this test suite:
 
-```
+```bash
 export RUN_ATTENTION_BACKEND_TESTS=yes
 export DIFFUSERS_ENABLE_HUB_KERNELS=yes
 
 pytest tests/others/test_attention_backends.py
 ```
+
+Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
+"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
 """
 
 import os
@@ -20,28 +23,165 @@
 
 
 pytestmark = pytest.mark.skipif(
-    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "true", reason="Feature not mature enough."
+    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
 )
-
-from pytest import mark as parameterize  # noqa: E402
-from torch._dynamo import config as dynamo_config  # noqa: E402
-
 from diffusers import FluxPipeline  # noqa: E402
+from diffusers.utils import is_torch_version  # noqa: E402
 
 
 FORWARD_CASES = [
     ("flash_hub", None),
-    ("_flash_3_hub", None),
-    ("native", None),
-    ("_native_cudnn", None),
+    (
+        "_flash_3_hub",
+        torch.tensor(
+            [
+                0.0820,
+                0.0859,
+                0.0938,
+                0.1016,
+                0.0977,
+                0.0996,
+                0.1016,
+                0.1016,
+                0.2188,
+                0.2246,
+                0.2344,
+                0.2480,
+                0.2539,
+                0.2480,
+                0.2441,
+                0.2715,
+            ],
+            dtype=torch.bfloat16,
+        ),
+    ),
+    (
+        "native",
+        torch.tensor(
+            [
+                0.0820,
+                0.0859,
+                0.0938,
+                0.1016,
+                0.0957,
+                0.0996,
+                0.0996,
+                0.1016,
+                0.2188,
+                0.2266,
+                0.2363,
+                0.2500,
+                0.2539,
+                0.2480,
+                0.2461,
+                0.2734,
+            ],
+            dtype=torch.bfloat16,
+        ),
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor(
+            [
+                0.0781,
+                0.0840,
+                0.0879,
+                0.0957,
+                0.0898,
+                0.0957,
+                0.0957,
+                0.0977,
+                0.2168,
+                0.2246,
+                0.2324,
+                0.2500,
+                0.2539,
+                0.2480,
+                0.2441,
+                0.2695,
+            ],
+            dtype=torch.bfloat16,
+        ),
+    ),
 ]
 
 COMPILE_CASES = [
     ("flash_hub", None, True),
-    ("_flash_3_hub", None, True),
-    ("native", None, True),
-    ("_native_cudnn", None, True),
-    ("native", None, True),
+    (
+        "_flash_3_hub",
+        torch.tensor(
+            [
+                0.0410,
+                0.0410,
+                0.0449,
+                0.0508,
+                0.0508,
+                0.0605,
+                0.0625,
+                0.0605,
+                0.2344,
+                0.2461,
+                0.2578,
+                0.2734,
+                0.2852,
+                0.2812,
+                0.2773,
+                0.3047,
+            ],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
+    (
+        "native",
+        torch.tensor(
+            [
+                0.0410,
+                0.0410,
+                0.0449,
+                0.0508,
+                0.0508,
+                0.0605,
+                0.0605,
+                0.0605,
+                0.2344,
+                0.2461,
+                0.2578,
+                0.2773,
+                0.2871,
+                0.2832,
+                0.2773,
+                0.3066,
+            ],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor(
+            [
+                0.0410,
+                0.0410,
+                0.0430,
+                0.0508,
+                0.0488,
+                0.0586,
+                0.0605,
+                0.0586,
+                0.2344,
+                0.2461,
+                0.2578,
+                0.2773,
+                0.2871,
+                0.2832,
+                0.2793,
+                0.3086,
+            ],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
 ]
 
 INFER_KW = {
@@ -55,19 +195,18 @@
 }
 
 
-def _backend_is_probably_supported(pipe, name: str) -> bool:
+def _backend_is_probably_supported(pipe, name: str):
     try:
         pipe.transformer.set_attention_backend(name)
-        return True
-    except (NotImplementedError, RuntimeError, ValueError):
+        return pipe, True
+    except Exception:
         return False
 
 
 def _check_if_slices_match(output, expected_slice):
-    img = output.images
+    img = output.images.detach().cpu()
     generated_slice = img.flatten()
     generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-    print(f"{generated_slice=}")
     assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
 
 
@@ -88,35 +227,38 @@ def pipe(device):
     return pipe
 
 
-@parameterize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
 def test_forward(pipe, backend_name, expected_slice):
-    if not _backend_is_probably_supported(pipe, backend_name):
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
         pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
 
-    out = pipe(
-        "a tiny toy cat in a box",
-        **INFER_KW,
-        generator=torch.manual_seed(0),
-    )
+    modified_pipe = out[0]
+    out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
    _check_if_slices_match(out, expected_slice)
 
 
-@parameterize(
+@pytest.mark.parametrize(
     "backend_name,expected_slice,error_on_recompile",
     COMPILE_CASES,
     ids=[c[0] for c in COMPILE_CASES],
 )
 def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
-    if not _backend_is_probably_supported(pipe, backend_name):
+    if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
+        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
+
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
         pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
 
-    pipe.transformer.compile(fullgraph=True)
-    with dynamo_config.patch(error_on_recompile=bool(error_on_recompile)):
-        torch.manual_seed(0)
-        out = pipe(
-            "a tiny toy cat in a box",
-            **INFER_KW,
-            generator=torch.manual_seed(0),
-        )
+    modified_pipe = out[0]
+    modified_pipe.transformer.compile(fullgraph=True)
+
+    torch.compiler.reset()
+    with (
+        torch._inductor.utils.fresh_inductor_cache(),
+        torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
+    ):
+        out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
 
     _check_if_slices_match(out, expected_slice)
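
For context on how the hardcoded slices above are produced: the removed `print(f"{generated_slice=}")` line suggests they come from flattening the generated image tensor and keeping its first and last eight values. Below is a rough, standalone sketch of regenerating such a slice for one backend; the checkpoint id and inference kwargs are placeholders, since the test's actual `INFER_KW` is not visible in this diff.

```python
import torch
from diffusers import FluxPipeline

# Hypothetical setup: the checkpoint and the inference kwargs below are assumptions,
# not necessarily what INFER_KW contains in the test file.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer.set_attention_backend("native")

out = pipe(
    "a tiny toy cat in a box",
    num_inference_steps=2,
    guidance_scale=3.5,
    height=256,
    width=256,
    output_type="pt",  # keep the images as a torch tensor, as the test expects
    generator=torch.manual_seed(0),
)

# Mirror _check_if_slices_match: flatten and keep the first/last eight values.
flat = out.images.detach().cpu().flatten()
print(torch.cat([flat[:8], flat[-8:]]))
```

Because the outputs are bfloat16 and compared with `atol=1e-4`, the reference slices are only expected to reproduce on the setup named in the docstring (H100, PyTorch 2.8.0 / CUDA 12.9, or the pinned nightly for the compiled "native" variants).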

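The compile test wraps inference in `torch._dynamo.config.patch(error_on_recompile=True)` together with a fresh inductor cache, so any recompilation during the run fails the test. A minimal sketch of what that guard does in isolation, independent of diffusers:

```python
import torch


@torch.compile(fullgraph=True)
def double(x):
    return x * 2


torch.compiler.reset()
with torch._dynamo.config.patch(error_on_recompile=True):
    double(torch.randn(4))  # first call compiles the graph
    # Calling again with a different shape, e.g. double(torch.randn(8)),
    # would trigger a recompile and raise under error_on_recompile=True.
```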