 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     is_bitsandbytes_available,
     is_torch_available,
     is_transformers_available,
     require_bitsandbytes_version_greater,
     require_peft_backend,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -66,7 +67,7 @@ def get_some_linear_layer(model):
 @require_bitsandbytes_version_greater("0.43.2")
 @require_accelerate
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class Base4bitTests(unittest.TestCase):
     # We need to test on relatively large models (aka >1b parameters, otherwise the quantization may not work as expected)
@@ -84,13 +85,16 @@ class Base4bitTests(unittest.TestCase):
 
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+            torch_device,
         )
         pooled_prompt_embeds = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+            torch_device,
         )
         latent_model_input = load_pt(
-            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+            torch_device,
         )
 
         input_dict_for_transformer = {
@@ -106,7 +110,7 @@ def get_dummy_inputs(self):
 class BnB4BitBasicTests(Base4bitTests):
     def setUp(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
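# Aside: a minimal sketch, assuming only what this hunk already imports, of the
# device-agnostic teardown pattern it applies. backend_empty_cache and
# torch_device come from diffusers.utils.testing_utils; torch_device resolves
# to whichever accelerator is available ("cuda", "xpu", "mps", ...) instead of
# hard-coding CUDA.
import gc

from diffusers.utils.testing_utils import backend_empty_cache, torch_device


def reset_accelerator_state():
    # Replaces the old gc.collect() + torch.cuda.empty_cache() pair with a call
    # that clears the cache of whichever backend torch_device points at.
    gc.collect()
    backend_empty_cache(torch_device)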
@@ -128,7 +132,7 @@ def tearDown(self):
         del self.model_4bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_quantization_num_parameters(self):
         r"""
@@ -224,7 +228,7 @@ def test_keep_modules_in_fp32(self):
                     self.assertTrue(module.weight.dtype == torch.uint8)
 
         # test if inference works.
-        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
             input_dict_for_transformer = self.get_dummy_inputs()
             model_inputs = {
                 k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
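# Aside: a small illustrative sketch of why passing torch_device to
# torch.amp.autocast works here. autocast's first argument is a device *type*
# string ("cuda", "xpu", "cpu", ...), which is exactly what torch_device holds
# on these accelerator-gated tests. Chaining the contexts with a comma (rather
# than `and`) is the form that actually enters both of them.
import torch

from diffusers.utils.testing_utils import torch_device

with torch.no_grad(), torch.amp.autocast(torch_device, dtype=torch.float16):
    x = torch.ones(8, 8, device=torch_device)
    y = x @ x  # the matmul runs under autocast on whichever accelerator is present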
@@ -266,9 +270,9 @@ def test_device_assignment(self):
         self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
 
         # Move back to CUDA device
-        for device in [0, "cuda", "cuda:0", "call()"]:
+        for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
             if device == "call()":
-                self.model_4bit.cuda(0)
+                self.model_4bit.to(f"{torch_device}:0")
             else:
                 self.model_4bit.to(device)
             self.assertEqual(self.model_4bit.device, torch.device(0))
@@ -286,7 +290,7 @@ def test_device_and_dtype_assignment(self):
 
         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            self.model_4bit.to(device="cuda:0", dtype=torch.float16)
+            self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)
 
         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -297,7 +301,7 @@ def test_device_and_dtype_assignment(self):
             self.model_4bit.half()
 
         # This should work
-        self.model_4bit.to("cuda")
+        self.model_4bit.to(torch_device)
 
         # Test if we did not break anything
         self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
@@ -321,7 +325,7 @@ def test_device_and_dtype_assignment(self):
         _ = self.model_fp16.float()
 
         # Check that this does not throw an error
-        _ = self.model_fp16.cuda()
+        _ = self.model_fp16.to(torch_device)
 
     def test_bnb_4bit_wrong_config(self):
         r"""
@@ -398,7 +402,7 @@ def test_training(self):
         model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
 
         # Step 4: Check if the gradient is not None
-        with torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.amp.autocast(torch_device, dtype=torch.float16):
             out = self.model_4bit(**model_inputs)[0]
             out.norm().backward()
 
@@ -412,7 +416,7 @@ def test_training(self):
 class SlowBnb4BitTests(Base4bitTests):
     def setUp(self) -> None:
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -431,7 +435,7 @@ def tearDown(self):
         del self.pipeline_4bit
 
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_quality(self):
         output = self.pipeline_4bit(
@@ -501,7 +505,7 @@ def test_moving_to_cpu_throws_warning(self):
         reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
         strict=True,
     )
-    def test_pipeline_cuda_placement_works_with_nf4(self):
+    def test_pipeline_device_placement_works_with_nf4(self):
         transformer_nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -532,7 +536,7 @@ def test_pipeline_cuda_placement_works_with_nf4(self):
             transformer=transformer_4bit,
             text_encoder_3=text_encoder_3_4bit,
             torch_dtype=torch.float16,
-        ).to("cuda")
+        ).to(torch_device)
 
         # Check if inference works.
         _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
@@ -696,7 +700,7 @@ def test_lora_loading(self):
 class BaseBnb4BitSerializationTests(Base4bitTests):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
         r"""
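# Aside: a hedged sketch (not the actual diffusers implementation) of how a
# backend_empty_cache-style helper can dispatch on the device string, which is
# what every torch.cuda.empty_cache() replacement in this commit relies on.
import torch


def empty_cache_for(device: str) -> None:
    device_type = device.split(":")[0]
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.empty_cache()
    elif device_type == "mps" and hasattr(torch, "mps"):
        torch.mps.empty_cache()
    # CPU has no accelerator cache, so nothing to clear there.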