Skip to content

Commit 9a6470a

Browse files
committed
enable semantic diffusion and stable diffusion panorama cases on XPU
Signed-off-by: Yao Matrix <[email protected]>
1 parent 58431f1 commit 9a6470a

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
2626
from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline
2727
from diffusers.utils.testing_utils import (
28+
backend_empty_cache,
2829
enable_full_determinism,
2930
floats_tensor,
3031
nightly,
31-
require_accelerator,
32-
require_torch_gpu,
32+
require_torch_accelerator,
3333
torch_device,
3434
)
3535

@@ -42,13 +42,13 @@ def setUp(self):
4242
# clean up the VRAM before each test
4343
super().setUp()
4444
gc.collect()
45-
torch.cuda.empty_cache()
45+
backend_empty_cache(torch_device)
4646

4747
def tearDown(self):
4848
# clean up the VRAM after each test
4949
super().tearDown()
5050
gc.collect()
51-
torch.cuda.empty_cache()
51+
backend_empty_cache(torch_device)
5252

5353
@property
5454
def dummy_image(self):
@@ -238,7 +238,7 @@ def test_semantic_diffusion_no_safety_checker(self):
238238
image = pipe("example prompt", num_inference_steps=2).images[0]
239239
assert image is not None
240240

241-
@require_accelerator
241+
@require_torch_accelerator
242242
def test_semantic_diffusion_fp16(self):
243243
"""Test that stable diffusion works with fp16"""
244244
unet = self.dummy_cond_unet
@@ -272,22 +272,21 @@ def test_semantic_diffusion_fp16(self):
272272

273273

274274
@nightly
275-
@require_torch_gpu
275+
@require_torch_accelerator
276276
class SemanticDiffusionPipelineIntegrationTests(unittest.TestCase):
277277
def setUp(self):
278278
# clean up the VRAM before each test
279279
super().setUp()
280280
gc.collect()
281-
torch.cuda.empty_cache()
281+
backend_empty_cache(torch_device)
282282

283283
def tearDown(self):
284284
# clean up the VRAM after each test
285285
super().tearDown()
286286
gc.collect()
287-
torch.cuda.empty_cache()
287+
backend_empty_cache(torch_device)
288288

289289
def test_positive_guidance(self):
290-
torch_device = "cuda"
291290
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
292291
pipe = pipe.to(torch_device)
293292
pipe.set_progress_bar_config(disable=None)
@@ -370,7 +369,6 @@ def test_positive_guidance(self):
370369
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
371370

372371
def test_negative_guidance(self):
373-
torch_device = "cuda"
374372
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
375373
pipe = pipe.to(torch_device)
376374
pipe.set_progress_bar_config(disable=None)
@@ -453,7 +451,6 @@ def test_negative_guidance(self):
453451
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
454452

455453
def test_multi_cond_guidance(self):
456-
torch_device = "cuda"
457454
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
458455
pipe = pipe.to(torch_device)
459456
pipe.set_progress_bar_config(disable=None)
@@ -536,7 +533,6 @@ def test_multi_cond_guidance(self):
536533
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
537534

538535
def test_guidance_fp16(self):
539-
torch_device = "cuda"
540536
pipe = StableDiffusionPipeline.from_pretrained(
541537
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
542538
)

tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,17 @@
2929
StableDiffusionPanoramaPipeline,
3030
UNet2DConditionModel,
3131
)
32-
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device
32+
from diffusers.utils.testing_utils import (
33+
backend_empty_cache,
34+
backend_max_memory_allocated,
35+
backend_reset_max_memory_allocated,
36+
backend_reset_peak_memory_stats,
37+
enable_full_determinism,
38+
nightly,
39+
require_torch_accelerator,
40+
skip_mps,
41+
torch_device,
42+
)
3343

3444
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
3545
from ..test_pipelines_common import (
@@ -267,17 +277,17 @@ def test_encode_prompt_works_in_isolation(self):
267277

268278

269279
@nightly
270-
@require_torch_gpu
280+
@require_torch_accelerator
271281
class StableDiffusionPanoramaNightlyTests(unittest.TestCase):
272282
def setUp(self):
273283
super().setUp()
274284
gc.collect()
275-
torch.cuda.empty_cache()
285+
backend_empty_cache(torch_device)
276286

277287
def tearDown(self):
278288
super().tearDown()
279289
gc.collect()
280-
torch.cuda.empty_cache()
290+
backend_empty_cache(torch_device)
281291

282292
def get_inputs(self, seed=0):
283293
generator = torch.manual_seed(seed)
@@ -415,9 +425,9 @@ def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
415425
assert number_of_steps == 3
416426

417427
def test_stable_diffusion_panorama_pipeline_with_sequential_cpu_offloading(self):
418-
torch.cuda.empty_cache()
419-
torch.cuda.reset_max_memory_allocated()
420-
torch.cuda.reset_peak_memory_stats()
428+
backend_empty_cache(torch_device)
429+
backend_reset_max_memory_allocated(torch_device)
430+
backend_reset_peak_memory_stats(torch_device)
421431

422432
model_ckpt = "stabilityai/stable-diffusion-2-base"
423433
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
@@ -429,6 +439,6 @@ def test_stable_diffusion_panorama_pipeline_with_sequential_cpu_offloading(self)
429439
inputs = self.get_inputs()
430440
_ = pipe(**inputs)
431441

432-
mem_bytes = torch.cuda.max_memory_allocated()
442+
mem_bytes = backend_max_memory_allocated(torch_device)
433443
# make sure that less than 5.5 GB is allocated
434444
assert mem_bytes < 5.5 * 10**9

0 commit comments

Comments
 (0)