 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    require_torch_accelerator,
+    torch_device,
+)
 
 
 class DummyBlock(torch.nn.Module):
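
This hunk swaps the CUDA-only imports for the device-agnostic `backend_*` test helpers. For readers unfamiliar with them, here is a minimal sketch, assuming the real helpers in `diffusers.utils.testing_utils` dispatch on the device type; the actual implementations may differ.

```python
# Minimal sketch of a device-agnostic cache helper, assuming the real
# backend_empty_cache in diffusers.utils.testing_utils dispatches on the
# device type roughly like this (the actual implementation may differ).
import torch

def backend_empty_cache_sketch(device: str) -> None:
    device_type = torch.device(device).type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "xpu":
        torch.xpu.empty_cache()  # requires a PyTorch build with XPU support
    # CPU has no allocator cache to clear, so fall through as a no-op.
```

This keeps the tests portable across CUDA and Intel XPU backends without repeating device checks in every test.
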
@@ -107,7 +113,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-@require_torch_gpu
+@require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
     hidden_features = 256
@@ -125,8 +131,8 @@ def tearDown(self):
         del self.model
         del self.input
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
     def get_model(self):
         torch.manual_seed(0)
@@ -141,8 +147,8 @@ def test_offloading_forward_pass(self):
         @torch.no_grad()
         def run_forward(model):
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
             self.assertTrue(
                 all(
                     module._diffusers_hook.get_hook("group_offloading") is not None
@@ -152,7 +158,7 @@ def run_forward(model):
             )
             model.eval()
             output = model(self.input)[0].cpu()
-            max_memory_allocated = torch.cuda.max_memory_allocated()
+            max_memory_allocated = backend_max_memory_allocated(torch_device)
             return output, max_memory_allocated
 
         self.model.to(torch_device)
@@ -187,10 +193,10 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
 
         # Memory assertions - offloading should reduce memory usage
-        self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
+        self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
 
-    def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
         logger = get_logger("diffusers.models.modeling_utils")
@@ -199,8 +205,8 @@ def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
             self.model.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
 
-    def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         pipe = DummyPipeline(self.model)
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
@@ -210,19 +216,20 @@ def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
             pipe.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
 
-    def test_error_raised_if_streams_used_and_no_cuda_device(self):
-        original_is_available = torch.cuda.is_available
-        torch.cuda.is_available = lambda: False
+    def test_error_raised_if_streams_used_and_no_accelerator_device(self):
+        torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+        original_is_available = torch_accelerator_module.is_available
+        torch_accelerator_module.is_available = lambda: False
         with self.assertRaises(ValueError):
             self.model.enable_group_offload(
-                onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
+                onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
             )
-        torch.cuda.is_available = original_is_available
+        torch_accelerator_module.is_available = original_is_available
 
     def test_error_raised_if_supports_group_offloading_false(self):
         self.model._supports_group_offloading = False
         with self.assertRaisesRegex(ValueError, "does not support group offloading"):
-            self.model.enable_group_offload(onload_device=torch.device("cuda"))
+            self.model.enable_group_offload(onload_device=torch.device(torch_device))
 
     def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
         pipe = DummyPipeline(self.model)
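
The stream test now resolves the accelerator module by name instead of hard-coding `torch.cuda`. A small illustration of the pattern, with a hypothetical `torch_device` value for demonstration:

```python
# Illustrative only: torch_device is a string such as "cuda" or "xpu",
# so getattr picks the matching torch submodule, falling back to
# torch.cuda when no attribute of that name exists.
import torch

torch_device = "cuda"  # hypothetical; the suite imports it from testing_utils
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
assert torch_accelerator_module is torch.cuda
# Monkeypatching is_available on the resolved module then simulates a
# machine without that accelerator, whichever backend is active.
```

Restoring `is_available` inside a try/finally would be slightly more robust, since a failing assertion would otherwise leave the patched function in place; the test restores it unconditionally afterward.
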
@@ -249,7 +256,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
         pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
 
     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
-        if torch.device(torch_device).type != "cuda":
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,