Commit 74a3164

enable group_offloading and PipelineDeviceAndDtypeStabilityTests on XPU, all passed

Signed-off-by: Matrix YAO <[email protected]>
1 parent 54cddc1 commit 74a3164

2 files changed: +25 -24 lines
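For context on the imports introduced below: the backend_* helpers from diffusers.utils.testing_utils replace hard-coded torch.cuda calls with device-agnostic ones so the same tests can run on CUDA or XPU. A minimal sketch of the dispatch idea, as an illustration only and not the actual diffusers implementation:

import torch

def _accelerator_module(device):
    # Resolve the per-backend torch module ("cuda" -> torch.cuda, "xpu" -> torch.xpu, ...).
    return getattr(torch, torch.device(device).type, None)

def backend_empty_cache_sketch(device):
    # Illustrative stand-in for backend_empty_cache(torch_device).
    module = _accelerator_module(device)
    if module is not None and hasattr(module, "empty_cache"):
        module.empty_cache()

def backend_reset_peak_memory_stats_sketch(device):
    # Illustrative stand-in for backend_reset_peak_memory_stats(torch_device).
    module = _accelerator_module(device)
    if module is not None and hasattr(module, "reset_peak_memory_stats"):
        module.reset_peak_memory_stats()

def backend_max_memory_allocated_sketch(device):
    # Illustrative stand-in for backend_max_memory_allocated(torch_device).
    module = _accelerator_module(device)
    if module is not None and hasattr(module, "max_memory_allocated"):
        return module.max_memory_allocated()
    return 0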

tests/hooks/test_group_offloading.py (20 additions, 19 deletions)

@@ -22,7 +22,7 @@
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, backend_max_memory_allocated, backend_reset_peak_memory_stats, require_torch_accelerator, require_torch_gpu, torch_device


 class DummyBlock(torch.nn.Module):

@@ -107,7 +107,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return x


-@require_torch_gpu
+@require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
     hidden_features = 256

@@ -125,8 +125,8 @@ def tearDown(self):
         del self.model
         del self.input
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

     def get_model(self):
         torch.manual_seed(0)

@@ -141,8 +141,8 @@ def test_offloading_forward_pass(self):
         @torch.no_grad()
         def run_forward(model):
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
             self.assertTrue(
                 all(
                     module._diffusers_hook.get_hook("group_offloading") is not None

@@ -152,7 +152,7 @@ def run_forward(model):
             )
             model.eval()
             output = model(self.input)[0].cpu()
-            max_memory_allocated = torch.cuda.max_memory_allocated()
+            max_memory_allocated = backend_max_memory_allocated(torch_device)
             return output, max_memory_allocated

         self.model.to(torch_device)

@@ -187,10 +187,10 @@ def run_forward(model):
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))

         # Memory assertions - offloading should reduce memory usage
-        self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
+        self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)

-    def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
         logger = get_logger("diffusers.models.modeling_utils")

@@ -199,8 +199,8 @@ def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
             self.model.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])

-    def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
-        if torch.device(torch_device).type != "cuda":
+    def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         pipe = DummyPipeline(self.model)
         self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)

@@ -210,19 +210,20 @@ def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
             pipe.to(torch_device)
         self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])

-    def test_error_raised_if_streams_used_and_no_cuda_device(self):
-        original_is_available = torch.cuda.is_available
-        torch.cuda.is_available = lambda: False
+    def test_error_raised_if_streams_used_and_no_accelerator_device(self):
+        torch_accelerator_module = getattr(torch, torch_device)
+        original_is_available = torch_accelerator_module.is_available
+        torch_accelerator_module.is_available = lambda: False
         with self.assertRaises(ValueError):
             self.model.enable_group_offload(
-                onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
+                onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
             )
-        torch.cuda.is_available = original_is_available
+        torch_accelerator_module.is_available = original_is_available

     def test_error_raised_if_supports_group_offloading_false(self):
         self.model._supports_group_offloading = False
         with self.assertRaisesRegex(ValueError, "does not support group offloading"):
-            self.model.enable_group_offload(onload_device=torch.device("cuda"))
+            self.model.enable_group_offload(onload_device=torch.device(torch_device))

     def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
         pipe = DummyPipeline(self.model)

@@ -249,7 +250,7 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
         pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)

     def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
-        if torch.device(torch_device).type != "cuda":
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
             return
         model = DummyModelWithMultipleBlocks(
             in_features=self.in_features,
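The rewritten stream test above patches is_available on whichever backend module torch_device names, rather than torch.cuda.is_available directly. The same pattern wrapped in a hypothetical restore-safe context manager, shown only as an illustration and not part of this commit:

import contextlib
import torch

@contextlib.contextmanager
def accelerator_reported_unavailable(device_type):
    # Temporarily make torch.<device_type>.is_available() return False
    # (device_type is e.g. "cuda" or "xpu"); the original is restored on exit.
    backend = getattr(torch, device_type)
    original_is_available = backend.is_available
    backend.is_available = lambda: False
    try:
        yield
    finally:
        backend.is_available = original_is_available

# Hypothetical usage: assert that enabling streamed group offload fails when the
# onload device reports itself unavailable.
# with accelerator_reported_unavailable("cuda"):
#     model.enable_group_offload(
#         onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
#     )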

tests/pipelines/test_pipeline_utils.py (5 additions, 5 deletions)

@@ -19,7 +19,7 @@
     UNet2DConditionModel,
 )
 from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import require_torch_accelerator, require_torch_gpu, torch_device


 class IsSafetensorsCompatibleTests(unittest.TestCase):

@@ -850,9 +850,9 @@ def test_video_to_video(self):
         self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")


-@require_torch_gpu
+@require_torch_accelerator
 class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
-    expected_pipe_device = torch.device("cuda:0")
+    expected_pipe_device = torch.device(f"{torch_device}:0")
     expected_pipe_dtype = torch.float64

     def get_dummy_components_image_generation(self):

@@ -921,8 +921,8 @@ def test_deterministic_device(self):
         pipe.to(device=torch_device, dtype=torch.float32)

         pipe.unet.to(device="cpu")
-        pipe.vae.to(device="cuda")
-        pipe.text_encoder.to(device="cuda:0")
+        pipe.vae.to(device=torch_device)
+        pipe.text_encoder.to(device=f"{torch_device}:0")

         pipe_device = pipe.device
