 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_synchronize,
     enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torchao_version_greater_or_equal,
     slow,
@@ -61,7 +64,7 @@


 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
@@ -119,12 +122,12 @@ def test_repr(self):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(
         self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -518,14 +521,14 @@ def test_sequential_cpu_offload(self):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
         quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -596,14 +599,14 @@ def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs,
     def test_int_a8w8_cuda(self):
         quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
         expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

     def test_int_a16w8_cuda(self):
         quant_method, quant_method_kwargs = "int8_weight_only", {}
         expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

@@ -624,14 +627,14 @@ def test_int_a16w8_cpu(self):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(self, quantization_config: TorchAoConfig):
         # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -713,8 +716,8 @@ def test_quantization(self):
             quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
             self._test_quant_type(quantization_config, expected_slice)
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)

     def test_serialization_int8wo(self):
         quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@ def test_serialization_int8wo(self):
             pipe.remove_all_hooks()
             del pipe.transformer
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)
             transformer = FluxTransformer2DModel.from_pretrained(
                 tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
             )
@@ -783,14 +786,14 @@ def test_memory_footprint_int8wo(self):


 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self, device: torch.device, seed: int = 0):
         if str(device).startswith("mps"):
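Every hunk above applies the same migration: CUDA-only plumbing (torch.cuda.empty_cache(), torch.cuda.synchronize(), a hard-coded device = "cuda", the @require_torch_gpu decorator) becomes device-agnostic, driven by the detected torch_device string. backend_empty_cache and backend_synchronize are imported from diffusers.utils.testing_utils; as a rough sketch of what such helpers do internally (an illustrative assumption, not the actual diffusers implementation), they dispatch on the device prefix:

import torch


def backend_empty_cache(device: str) -> None:
    # Sketch only, not the diffusers implementation: release cached
    # allocator memory on whichever accelerator backend is active.
    if device.startswith("cuda"):
        torch.cuda.empty_cache()
    elif device.startswith("xpu"):
        torch.xpu.empty_cache()
    elif device.startswith("mps"):
        torch.mps.empty_cache()
    # On "cpu" there is no allocator cache to clear, so this is a no-op.


def backend_synchronize(device: str) -> None:
    # Sketch only: block until all queued kernels on the active
    # accelerator finish, so memory stats are read at a stable point.
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    elif device.startswith("xpu"):
        torch.xpu.synchronize()
    elif device.startswith("mps"):
        torch.mps.synchronize()

The decorator swap mirrors this: @require_torch_gpu skips a test unless CUDA is available, while @require_torch_accelerator only requires that some supported accelerator is present, so the same suite can run on non-CUDA runners.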