Commit f6ae056

enable on xpu

1 parent c091bcc
File tree

16 files changed: +114 additions, -102 deletions


src/diffusers/utils/testing_utils.py

Lines changed: 2 additions & 0 deletions
@@ -1137,6 +1137,8 @@ def backend_device_count(device: str):
 def backend_reset_peak_memory_stats(device: str):
     return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
 
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
 
 def backend_max_memory_allocated(device: str):
     return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
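
The new helper mirrors its siblings: it routes through _device_agnostic_dispatch so the caller never names a backend. As a rough sketch of the mechanism, assuming a per-device function table (the table below is illustrative, not the actual diffusers source):

    import torch

    # Illustrative dispatch table; the real one in diffusers.utils.testing_utils
    # may cover more backends (e.g. XPU) and choose different fallbacks.
    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
        "cuda": torch.cuda.reset_max_memory_allocated,
        "default": None,
    }

    def _device_agnostic_dispatch(device: str, dispatch_table: dict, *args, **kwargs):
        # Pick the backend-specific function; unknown devices fall back to "default".
        fn = dispatch_table.get(device, dispatch_table["default"])
        if fn is None:
            return None  # no-op on backends that do not expose this statistic
        return fn(*args, **kwargs)

    def backend_reset_max_memory_allocated(device: str):
        return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)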

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 29 additions & 26 deletions
@@ -36,6 +36,9 @@
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    backend_max_memory_allocated,
     enable_full_determinism,
     floats_tensor,
     is_peft_available,
@@ -1014,7 +1017,7 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1025,7 +1028,7 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
@@ -1040,7 +1043,7 @@ def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
@@ -1055,7 +1058,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, va
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1065,7 +1068,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1165,11 +1168,11 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
 
         return model
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_auto(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice("auto")
@@ -1181,15 +1184,15 @@ def test_set_attention_slice_auto(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_max(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice("max")
@@ -1201,15 +1204,15 @@ def test_set_attention_slice_max(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
-
+        mem_bytes = backend_max_memory_allocated(torch_device)
+
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_int(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice(2)
@@ -1221,15 +1224,15 @@ def test_set_attention_slice_int(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_list(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         # there are 32 sliceable layers
         slice_list = 16 * [2, 3]
@@ -1243,7 +1246,7 @@ def test_set_attention_slice_list(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
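
Every attention-slicing test above now follows the same device-agnostic recipe: clear the cache, reset both memory counters, run inference, then read the peak. A condensed sketch of that pattern (measure_peak_memory is a hypothetical helper, not part of the test file):

    from diffusers.utils.testing_utils import (
        backend_empty_cache,
        backend_max_memory_allocated,
        backend_reset_max_memory_allocated,
        backend_reset_peak_memory_stats,
        torch_device,
    )

    def measure_peak_memory(run_inference):
        # Start from a clean slate so the reading reflects only this run.
        backend_empty_cache(torch_device)
        backend_reset_max_memory_allocated(torch_device)
        backend_reset_peak_memory_stats(torch_device)
        run_inference()
        # Peak bytes allocated on whichever accelerator torch_device names.
        return backend_max_memory_allocated(torch_device)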

tests/pipelines/controlnet/test_controlnet_sdxl.py

Lines changed: 2 additions & 2 deletions
@@ -222,12 +222,12 @@ def test_stable_diffusion_xl_offloads(self):
 
         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_model_cpu_offload()
+        sd_pipe.enable_model_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)
 
         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_sequential_cpu_offload()
+        sd_pipe.enable_sequential_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)
 
         image_slices = []
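
Passing the device explicitly matters here: enable_model_cpu_offload and enable_sequential_cpu_offload default to CUDA, so on an XPU-only machine the offload hooks would otherwise target a device that does not exist. Usage:

    # torch_device resolves to "cuda", "xpu", etc. depending on the host.
    sd_pipe.enable_model_cpu_offload(device=torch_device)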

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 4 additions & 3 deletions
@@ -9,6 +9,7 @@
 
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
     slow,
@@ -219,12 +220,12 @@ class FluxPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, device, seed=0):
         if str(device).startswith("mps"):
@@ -254,7 +255,7 @@ def test_flux_inference(self):
         pipe = self.pipeline_class.from_pretrained(
             self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
 
         inputs = self.get_inputs(torch_device)
 
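
For context, torch_device is the device string that diffusers.utils.testing_utils exports for the current host. A rough sketch of how such a value could be resolved (an assumption for illustration; the library's actual detection logic may differ):

    import torch

    # Prefer CUDA, then Intel XPU, then Apple MPS, else CPU.
    if torch.cuda.is_available():
        torch_device = "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch_device = "xpu"
    elif torch.backends.mps.is_available():
        torch_device = "mps"
    else:
        torch_device = "cpu"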

tests/pipelines/pag/test_pag_sdxl_img2img.py

Lines changed: 7 additions & 6 deletions
@@ -39,10 +39,11 @@
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_image,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -267,19 +268,19 @@ def test_pag_inference(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class StableDiffusionXLPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
     repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0):
         img_url = (
@@ -303,7 +304,7 @@ def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0)
 
     def test_pag_cfg(self):
         pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
-        pipeline.enable_model_cpu_offload()
+        pipeline.enable_model_cpu_offload(device=torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
         inputs = self.get_inputs(torch_device)
@@ -320,7 +321,7 @@
 
     def test_pag_uncond(self):
         pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
-        pipeline.enable_model_cpu_offload()
+        pipeline.enable_model_cpu_offload(device=torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
         inputs = self.get_inputs(torch_device, guidance_scale=0.0)
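
The class-level swap from require_torch_gpu to require_torch_accelerator is what lets this integration suite run on XPU at all. A hypothetical sketch of such a decorator (the real diffusers implementation may key off torch_device and support additional backends):

    import unittest

    import torch

    def require_torch_accelerator(test_case):
        # Skip unless some supported accelerator is present, not just CUDA.
        has_accelerator = torch.cuda.is_available() or (
            hasattr(torch, "xpu") and torch.xpu.is_available()
        )
        return unittest.skipUnless(has_accelerator, "test requires an accelerator")(test_case)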

tests/pipelines/pag/test_pag_sdxl_inpaint.py

Lines changed: 7 additions & 6 deletions
@@ -40,10 +40,11 @@
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_image,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -272,19 +273,19 @@ def test_pag_inference(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase):
     repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0):
         img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
@@ -309,7 +310,7 @@ def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0)
 
     def test_pag_cfg(self):
         pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
-        pipeline.enable_model_cpu_offload()
+        pipeline.enable_model_cpu_offload(device=torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
         inputs = self.get_inputs(torch_device)
@@ -326,7 +327,7 @@
 
     def test_pag_uncond(self):
         pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
-        pipeline.enable_model_cpu_offload()
+        pipeline.enable_model_cpu_offload(device=torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
         inputs = self.get_inputs(torch_device, guidance_scale=0.0)

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -988,7 +988,7 @@ def test_stable_diffusion_attention_slicing(self):
         assert max_diff < 1e-3
 
     def test_stable_diffusion_vae_slicing(self):
-        torch.cuda.reset_peak_memory_stats()
+        backend_reset_peak_memory_stats(torch_device)
         pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

0 commit comments
