1717)
1818from diffusers .utils import load_image
1919from diffusers .utils .testing_utils import (
20+ Expectations ,
21+ backend_empty_cache ,
22+ backend_max_memory_allocated ,
23+ backend_reset_peak_memory_stats ,
24+ enable_full_determinism ,
2025 is_gguf_available ,
2126 nightly ,
2227 numpy_cosine_similarity_distance ,
2328 require_accelerate ,
24- require_big_gpu_with_torch_cuda ,
29+ require_big_accelerator ,
2530 require_gguf_version_greater_or_equal ,
2631 require_peft_backend ,
2732 torch_device ,
3136if is_gguf_available ():
3237 from diffusers .quantizers .gguf .utils import GGUFLinear , GGUFParameter
3338
39+ enable_full_determinism ()
40+
3441
3542@nightly
36- @require_big_gpu_with_torch_cuda
43+ @require_big_accelerator
3744@require_accelerate
3845@require_gguf_version_greater_or_equal ("0.10.0" )
3946class GGUFSingleFileTesterMixin :
@@ -68,15 +75,15 @@ def test_gguf_memory_usage(self):
6875 model = self .model_cls .from_single_file (
6976 self .ckpt_path , quantization_config = quantization_config , torch_dtype = self .torch_dtype
7077 )
71- model .to ("cuda" )
78+ model .to (torch_device )
7279 assert (model .get_memory_footprint () / 1024 ** 3 ) < self .expected_memory_use_in_gb
7380 inputs = self .get_dummy_inputs ()
7481
75- torch . cuda . reset_peak_memory_stats ( )
76- torch . cuda . empty_cache ( )
82+ backend_reset_peak_memory_stats ( torch_device )
83+ backend_empty_cache ( torch_device )
7784 with torch .no_grad ():
7885 model (** inputs )
79- max_memory = torch . cuda . max_memory_allocated ( )
86+ max_memory = backend_max_memory_allocated ( torch_device )
8087 assert (max_memory / 1024 ** 3 ) < self .expected_memory_use_in_gb
8188
8289 def test_keep_modules_in_fp32 (self ):
@@ -106,7 +113,8 @@ def test_dtype_assignment(self):
106113
107114 with self .assertRaises (ValueError ):
108115 # Tries with a `device` and `dtype`
109- model .to (device = "cuda:0" , dtype = torch .float16 )
116+ device_0 = f"{ torch_device } :0"
117+ model .to (device = device_0 , dtype = torch .float16 )
110118
111119 with self .assertRaises (ValueError ):
112120 # Tries with a cast
@@ -117,7 +125,7 @@ def test_dtype_assignment(self):
117125 model .half ()
118126
119127 # This should work
120- model .to ("cuda" )
128+ model .to (torch_device )
121129
122130 def test_dequantize_model (self ):
123131 quantization_config = GGUFQuantizationConfig (compute_dtype = self .torch_dtype )
@@ -146,11 +154,11 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
146154
147155 def setUp (self ):
148156 gc .collect ()
149- torch . cuda . empty_cache ( )
157+ backend_empty_cache ( torch_device )
150158
151159 def tearDown (self ):
152160 gc .collect ()
153- torch . cuda . empty_cache ( )
161+ backend_empty_cache ( torch_device )
154162
155163 def get_dummy_inputs (self ):
156164 return {
@@ -233,11 +241,11 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase)
233241
234242 def setUp (self ):
235243 gc .collect ()
236- torch . cuda . empty_cache ( )
244+ backend_empty_cache ( torch_device )
237245
238246 def tearDown (self ):
239247 gc .collect ()
240- torch . cuda . empty_cache ( )
248+ backend_empty_cache ( torch_device )
241249
242250 def get_dummy_inputs (self ):
243251 return {
@@ -267,40 +275,79 @@ def test_pipeline_inference(self):
267275
268276 prompt = "a cat holding a sign that says hello"
269277 output = pipe (
270- prompt = prompt , num_inference_steps = 2 , generator = torch .Generator ("cpu" ).manual_seed (0 ), output_type = "np"
278+ prompt = prompt ,
279+ num_inference_steps = 2 ,
280+ generator = torch .Generator ("cpu" ).manual_seed (0 ),
281+ output_type = "np" ,
271282 ).images [0 ]
272283 output_slice = output [:3 , :3 , :].flatten ()
273- expected_slice = np .array (
274- [
275- 0.17578125 ,
276- 0.27539062 ,
277- 0.27734375 ,
278- 0.11914062 ,
279- 0.26953125 ,
280- 0.25390625 ,
281- 0.109375 ,
282- 0.25390625 ,
283- 0.25 ,
284- 0.15039062 ,
285- 0.26171875 ,
286- 0.28515625 ,
287- 0.13671875 ,
288- 0.27734375 ,
289- 0.28515625 ,
290- 0.12109375 ,
291- 0.26757812 ,
292- 0.265625 ,
293- 0.16210938 ,
294- 0.29882812 ,
295- 0.28515625 ,
296- 0.15625 ,
297- 0.30664062 ,
298- 0.27734375 ,
299- 0.14648438 ,
300- 0.29296875 ,
301- 0.26953125 ,
302- ]
284+ expected_slices = Expectations (
285+ {
286+ ("xpu" , 3 ): np .array (
287+ [
288+ 0.19335938 ,
289+ 0.3125 ,
290+ 0.3203125 ,
291+ 0.1328125 ,
292+ 0.3046875 ,
293+ 0.296875 ,
294+ 0.11914062 ,
295+ 0.2890625 ,
296+ 0.2890625 ,
297+ 0.16796875 ,
298+ 0.30273438 ,
299+ 0.33203125 ,
300+ 0.14648438 ,
301+ 0.31640625 ,
302+ 0.33007812 ,
303+ 0.12890625 ,
304+ 0.3046875 ,
305+ 0.30859375 ,
306+ 0.17773438 ,
307+ 0.33789062 ,
308+ 0.33203125 ,
309+ 0.16796875 ,
310+ 0.34570312 ,
311+ 0.32421875 ,
312+ 0.15625 ,
313+ 0.33203125 ,
314+ 0.31445312 ,
315+ ]
316+ ),
317+ ("cuda" , 7 ): np .array (
318+ [
319+ 0.17578125 ,
320+ 0.27539062 ,
321+ 0.27734375 ,
322+ 0.11914062 ,
323+ 0.26953125 ,
324+ 0.25390625 ,
325+ 0.109375 ,
326+ 0.25390625 ,
327+ 0.25 ,
328+ 0.15039062 ,
329+ 0.26171875 ,
330+ 0.28515625 ,
331+ 0.13671875 ,
332+ 0.27734375 ,
333+ 0.28515625 ,
334+ 0.12109375 ,
335+ 0.26757812 ,
336+ 0.265625 ,
337+ 0.16210938 ,
338+ 0.29882812 ,
339+ 0.28515625 ,
340+ 0.15625 ,
341+ 0.30664062 ,
342+ 0.27734375 ,
343+ 0.14648438 ,
344+ 0.29296875 ,
345+ 0.26953125 ,
346+ ]
347+ ),
348+ }
303349 )
350+ expected_slice = expected_slices .get_expectation ()
304351 max_diff = numpy_cosine_similarity_distance (expected_slice , output_slice )
305352 assert max_diff < 1e-4
306353
@@ -313,11 +360,11 @@ class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase
313360
314361 def setUp (self ):
315362 gc .collect ()
316- torch . cuda . empty_cache ( )
363+ backend_empty_cache ( torch_device )
317364
318365 def tearDown (self ):
319366 gc .collect ()
320- torch . cuda . empty_cache ( )
367+ backend_empty_cache ( torch_device )
321368
322369 def get_dummy_inputs (self ):
323370 return {
@@ -393,11 +440,11 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
393440
394441 def setUp (self ):
395442 gc .collect ()
396- torch . cuda . empty_cache ( )
443+ backend_empty_cache ( torch_device )
397444
398445 def tearDown (self ):
399446 gc .collect ()
400- torch . cuda . empty_cache ( )
447+ backend_empty_cache ( torch_device )
401448
402449 def get_dummy_inputs (self ):
403450 return {
@@ -463,7 +510,7 @@ def test_pipeline_inference(self):
463510
464511@require_peft_backend
465512@nightly
466- @require_big_gpu_with_torch_cuda
513+ @require_big_accelerator
467514@require_accelerate
468515@require_gguf_version_greater_or_equal ("0.10.0" )
469516class FluxControlLoRAGGUFTests (unittest .TestCase ):
@@ -478,7 +525,7 @@ def test_lora_loading(self):
478525 "black-forest-labs/FLUX.1-dev" ,
479526 transformer = transformer ,
480527 torch_dtype = torch .bfloat16 ,
481- ).to ("cuda" )
528+ ).to (torch_device )
482529 pipe .load_lora_weights ("black-forest-labs/FLUX.1-Canny-dev-lora" )
483530
484531 prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
0 commit comments