from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
from diffusers.utils.import_utils import compare_versions
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, backend_max_memory_allocated, backend_reset_peak_memory_stats, require_torch_accelerator, require_torch_gpu, torch_device


class DummyBlock(torch.nn.Module):
@@ -107,7 +107,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x


-@require_torch_gpu
+@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
    in_features = 64
    hidden_features = 256
@@ -125,8 +125,8 @@ def tearDown(self):
        del self.model
        del self.input
        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

    def get_model(self):
        torch.manual_seed(0)
@@ -141,8 +141,8 @@ def test_offloading_forward_pass(self):
        @torch.no_grad()
        def run_forward(model):
            gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
            self.assertTrue(
                all(
                    module._diffusers_hook.get_hook("group_offloading") is not None
@@ -152,7 +152,7 @@ def run_forward(model):
            )
            model.eval()
            output = model(self.input)[0].cpu()
-            max_memory_allocated = torch.cuda.max_memory_allocated()
+            max_memory_allocated = backend_max_memory_allocated(torch_device)
            return output, max_memory_allocated

        self.model.to(torch_device)
@@ -187,10 +187,10 @@ def run_forward(model):
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))

        # Memory assertions - offloading should reduce memory usage
-        self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
+        self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)

-    def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return
        self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
        logger = get_logger("diffusers.models.modeling_utils")
@@ -199,8 +199,8 @@ def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
            self.model.to(torch_device)
        self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])

-    def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return
        pipe = DummyPipeline(self.model)
        self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
@@ -210,19 +210,20 @@ def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
            pipe.to(torch_device)
        self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])

-    def test_error_raised_if_streams_used_and_no_cuda_device(self):
-        original_is_available = torch.cuda.is_available
-        torch.cuda.is_available = lambda: False
+    def test_error_raised_if_streams_used_and_no_accelerator_device(self):
+        torch_accelerator_module = getattr(torch, torch_device)
+        original_is_available = torch_accelerator_module.is_available
+        torch_accelerator_module.is_available = lambda: False
        with self.assertRaises(ValueError):
            self.model.enable_group_offload(
-                onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
+                onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
            )
-        torch.cuda.is_available = original_is_available
+        torch_accelerator_module.is_available = original_is_available

    def test_error_raised_if_supports_group_offloading_false(self):
        self.model._supports_group_offloading = False
        with self.assertRaisesRegex(ValueError, "does not support group offloading"):
-            self.model.enable_group_offload(onload_device=torch.device("cuda"))
+            self.model.enable_group_offload(onload_device=torch.device(torch_device))

    def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
        pipe = DummyPipeline(self.model)
@@ -249,7 +250,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
            pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)

    def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
-        if torch.device(torch_device).type != "cuda":
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            return
        model = DummyModelWithMultipleBlocks(
            in_features=self.in_features,
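
For readers skimming the diff: the change swaps CUDA-only calls for the device-agnostic helpers imported from diffusers.utils.testing_utils, which take torch_device and (judging from the getattr(torch, torch_device) lookup in the streams test above) route to the matching backend such as torch.cuda or torch.xpu. Below is a minimal sketch of the pattern the updated tests follow; the measure_peak_memory helper is illustrative only and not part of the commit:

    import gc

    from diffusers.utils.testing_utils import (
        backend_empty_cache,
        backend_max_memory_allocated,
        backend_reset_peak_memory_stats,
        torch_device,
    )


    def measure_peak_memory(fn):
        # Drop cached allocations and reset the peak-memory counter on the
        # active accelerator backend before running the workload.
        gc.collect()
        backend_empty_cache(torch_device)
        backend_reset_peak_memory_stats(torch_device)
        fn()
        # Peak memory allocated on the accelerator while fn() ran.
        return backend_max_memory_allocated(torch_device)

The same idea drives the monkeypatch in test_error_raised_if_streams_used_and_no_accelerator_device: rather than patching torch.cuda.is_available directly, the accelerator module is resolved via getattr(torch, torch_device), so the test runs unchanged on XPU and other backends.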