 
 import torch
 
-from diffusers import Flux2Transformer2DModel
+from diffusers import Flux2Transformer2DModel, attention_backend
 from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
 from diffusers.models.embeddings import ImageProjection
 
@@ -166,11 +166,12 @@ def test_flux2_consistency(self, seed=0):
         model.to(torch_device)
         model.eval()
 
-        with torch.no_grad():
-            output = model(**inputs_dict)
+        with attention_backend("native"):
+            with torch.no_grad():
+                output = model(**inputs_dict)
 
-        if isinstance(output, dict):
-            output = output.to_tuple()[0]
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
 
         self.assertIsNotNone(output)
 
@@ -181,12 +182,12 @@ def test_flux2_consistency(self, seed=0):
 
         # Check against expected slice
         # fmt: off
-        expected_slice = torch.tensor([-0.3180, 0.4818, 0.6621, -0.3386, 0.2313, 0.0688, 0.0985, -0.2686, -0.1480, -0.1607, -0.7245, 0.5385, -0.2842, 0.6575, -0.0697, 0.4951])
+        expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
         # fmt: on
 
         flat_output = output.cpu().flatten()
         generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
-        self.assertTrue(torch.allclose(expected_slice, generated_slice))
+        self.assertTrue(torch.allclose(generated_slice, expected_slice))
 
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"Flux2Transformer2DModel"}
0 commit comments