 
 To run this test suite:
 
-```
+```bash
 export RUN_ATTENTION_BACKEND_TESTS=yes
 export DIFFUSERS_ENABLE_HUB_KERNELS=yes
 
 pytest tests/others/test_attention_backends.py
 ```
+
+Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests of the
+"native" variants were obtained with a torch nightly build (2.10.0.dev20250924+cu128).
 """
 
 import os
 
 import pytest
 import torch
 
 
 pytestmark = pytest.mark.skipif(
-    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "true", reason="Feature not mature enough."
+    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
 )
-
-from pytest import mark as parameterize  # noqa: E402
-from torch._dynamo import config as dynamo_config  # noqa: E402
-
 from diffusers import FluxPipeline  # noqa: E402
+from diffusers.utils import is_torch_version  # noqa: E402
 
 
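+# Each case is (backend_name, expected_slice). A slice holds the first and last eight
+# flattened pixel values of the generated image; None means no reference values were
+# recorded, so only successful execution is checked.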
 FORWARD_CASES = [
     ("flash_hub", None),
-    ("_flash_3_hub", None),
-    ("native", None),
-    ("_native_cudnn", None),
+    (
+        "_flash_3_hub",
+        torch.tensor(
+            [0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016,
+             0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715],
+            dtype=torch.bfloat16,
+        ),
+    ),
+    (
+        "native",
+        torch.tensor(
+            [0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016,
+             0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734],
+            dtype=torch.bfloat16,
+        ),
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor(
+            [0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977,
+             0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695],
+            dtype=torch.bfloat16,
+        ),
+    ),
 ]
 
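+# Each case is (backend_name, expected_slice, error_on_recompile). The compiled
+# reference slices differ from the eager ones above; the "native" variants were
+# captured with a torch nightly build (see the module docstring).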
 COMPILE_CASES = [
     ("flash_hub", None, True),
-    ("_flash_3_hub", None, True),
-    ("native", None, True),
-    ("_native_cudnn", None, True),
-    ("native", None, True),
+    (
+        "_flash_3_hub",
+        torch.tensor(
+            [0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605,
+             0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
+    (
+        "native",
+        torch.tensor(
+            [0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605,
+             0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
+    (
+        "_native_cudnn",
+        torch.tensor(
+            [0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586,
+             0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086],
+            dtype=torch.bfloat16,
+        ),
+        True,
+    ),
 ]
 
 INFER_KW = {
 }
 
 
-def _backend_is_probably_supported(pipe, name: str) -> bool:
+def _backend_is_probably_supported(pipe, name: str):
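+    # On success, return (pipe, True) with the backend applied; on any failure, return
+    # plain False so callers can detect support via isinstance(result, bool).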
     try:
         pipe.transformer.set_attention_backend(name)
-        return True
-    except (NotImplementedError, RuntimeError, ValueError):
+        return pipe, True
+    except Exception:
         return False
 
 
 def _check_if_slices_match(output, expected_slice):
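+    # Cases without recorded reference values (expected_slice=None, e.g. "flash_hub")
+    # only verify that inference completes; torch.allclose cannot compare against None.
+    if expected_slice is None:
+        return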
-    img = output.images
+    img = output.images.detach().cpu()
     generated_slice = img.flatten()
     generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-    print(f"{generated_slice=}")
     assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
 
 
@@ -88,35 +227,38 @@ def pipe(device):
     return pipe
 
 
-@parameterize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
 def test_forward(pipe, backend_name, expected_slice):
-    if not _backend_is_probably_supported(pipe, backend_name):
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
         pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
 
-    out = pipe(
-        "a tiny toy cat in a box",
-        **INFER_KW,
-        generator=torch.manual_seed(0),
-    )
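+    # The support check returns (pipe, True); unpack the pipeline whose transformer
+    # now uses the requested attention backend.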
+    modified_pipe = out[0]
+    out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
     _check_if_slices_match(out, expected_slice)
 
 
-@parameterize(
+@pytest.mark.parametrize(
     "backend_name,expected_slice,error_on_recompile",
     COMPILE_CASES,
     ids=[c[0] for c in COMPILE_CASES],
 )
 def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
-    if not _backend_is_probably_supported(pipe, backend_name):
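+    # Compiled "native" variants need a newer torch; their reference slices were
+    # captured with a nightly build (see the module docstring).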
+    if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
+        pytest.xfail(f"Test with {backend_name=} requires torch >= 2.9.0.")
+
+    out = _backend_is_probably_supported(pipe, backend_name)
+    if isinstance(out, bool):
         pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
 
-    pipe.transformer.compile(fullgraph=True)
-    with dynamo_config.patch(error_on_recompile=bool(error_on_recompile)):
-        torch.manual_seed(0)
-        out = pipe(
-            "a tiny toy cat in a box",
-            **INFER_KW,
-            generator=torch.manual_seed(0),
-        )
+    modified_pipe = out[0]
+    modified_pipe.transformer.compile(fullgraph=True)
+
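+    # Reset dynamo state and use a fresh inductor cache so each parametrized case
+    # compiles from scratch; error_on_recompile then fails the test if the compiled
+    # transformer recompiles during inference.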
+    torch.compiler.reset()
+    with (
+        torch._inductor.utils.fresh_inductor_cache(),
+        torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
+    ):
+        out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
 
     _check_if_slices_match(out, expected_slice)