 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    backend_synchronize,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     floats_tensor,
     get_python_version,
     is_torch_compile,
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     run_test_in_subprocess,
     slow,
@@ -341,7 +343,7 @@ def test_weight_overwrite(self):
 
         assert model.config.in_channels == 9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_keep_modules_in_fp32(self):
         r"""
         A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -1480,16 +1482,16 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e5m2, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_synchronize(torch_device)
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
 
@@ -1502,7 +1504,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
 
             return model_memory_footprint, peak_inference_memory_allocated_mb
 
@@ -1512,7 +1514,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )
 
-        compute_capability = get_torch_cuda_device_capability()
+        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1527,7 +1529,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
         )
 
     @parameterized.expand([False, True])
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading(self, record_stream):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)
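
The diff swaps the direct torch.cuda.* calls for the backend_* helpers from diffusers.utils.testing_utils, so the memory tests can run on any accelerator selected via torch_device rather than only on CUDA. As a rough illustration of the dispatch idea behind such helpers, here is a minimal sketch; the demo_* names and the getattr-based lookup are assumptions for illustration, not the actual diffusers implementation.

# Minimal sketch, assuming the accelerator's torch backend module
# (torch.cuda, torch.xpu, ...) mirrors the torch.cuda memory API.
import torch


def demo_backend_max_memory_allocated(device: str) -> int:
    """Return peak allocated bytes on `device`, or 0 if the backend lacks the API."""
    backend = getattr(torch, device, None)  # e.g. torch.cuda when device == "cuda"
    if backend is not None and hasattr(backend, "max_memory_allocated"):
        return backend.max_memory_allocated()
    return 0


def demo_backend_reset_peak_memory_stats(device: str) -> None:
    """Reset the peak-memory counters on `device` when the backend supports it."""
    backend = getattr(torch, device, None)
    if backend is not None and hasattr(backend, "reset_peak_memory_stats"):
        backend.reset_peak_memory_stats()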