 from diffusers.utils import load_image, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     floats_tensor,
     is_peft_available,
     nightly,
     numpy_cosine_similarity_distance,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_peft_backend,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -809,10 +810,10 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):

 @slow
 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 @require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.

@@ -827,7 +828,7 @@ def setUp(self):
         super().setUp()

         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

         self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

@@ -836,13 +837,13 @@ def tearDown(self):

         del self.pipeline
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_flux_the_last_ben(self):
         self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
         self.pipeline.fuse_lora()
         self.pipeline.unload_lora_weights()
-        # Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
+        # Instead of calling `enable_model_cpu_offload()`, we do an accelerator placement here because the CI
         # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
         # `enable_model_cpu_offload()`. We repeat this for the other tests, too.
         self.pipeline = self.pipeline.to(torch_device)
@@ -956,10 +957,10 @@ def test_flux_xlabs_load_lora_with_single_blocks(self):


 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 @require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
 class FluxControlLoRAIntegrationTests(unittest.TestCase):
     num_inference_steps = 10
     seed = 0
@@ -969,17 +970,17 @@ def setUp(self):
         super().setUp()

         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

         self.pipeline = FluxControlPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
-        ).to("cuda")
+        ).to(torch_device)

     def tearDown(self):
         super().tearDown()

         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
     def test_lora(self, lora_ckpt_id):
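
Reviewer note: every hunk above swaps a CUDA-only call for a device-agnostic helper from `diffusers.utils.testing_utils`, so the same integration tests can run on CUDA, XPU, or MPS runners. As a rough illustration of the dispatch pattern behind such a helper (a minimal sketch under assumed backends, not the actual diffusers implementation; `empty_cache_for` is a hypothetical name):

    import torch

    def empty_cache_for(device: str) -> None:
        # Hypothetical sketch of a backend_empty_cache-style dispatcher:
        # route the cache flush to whichever accelerator backend the test
        # session resolved `torch_device` to.
        if device == "cuda":
            torch.cuda.empty_cache()
        elif device == "xpu":
            torch.xpu.empty_cache()
        elif device == "mps":
            torch.mps.empty_cache()
        # "cpu" has no accelerator cache, so there is nothing to flush.

Called as `empty_cache_for(torch_device)` in `setUp`/`tearDown`, this pattern keeps the suite free of hard-coded `torch.cuda` references, which is what the diff achieves via `backend_empty_cache(torch_device)`.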