Skip to content

Commit 8702062

Browse files
committed
add 2 more
1 parent 59a00e4 commit 8702062

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
99
from diffusers.utils import load_image
1010
from diffusers.utils.testing_utils import (
11+
backend_empty_cache,
1112
enable_full_determinism,
1213
numpy_cosine_similarity_distance,
13-
require_torch_gpu,
14+
require_torch_accelerator,
1415
slow,
1516
torch_device,
1617
)
@@ -27,7 +28,7 @@
2728

2829

2930
@slow
30-
@require_torch_gpu
31+
@require_torch_accelerator
3132
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
3233
pipeline_class = StableDiffusionControlNetPipeline
3334
ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -39,12 +40,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
3940
def setUp(self):
4041
super().setUp()
4142
gc.collect()
42-
torch.cuda.empty_cache()
43+
backend_empty_cache(torch_device)
4344

4445
def tearDown(self):
4546
super().tearDown()
4647
gc.collect()
47-
torch.cuda.empty_cache()
48+
backend_empty_cache(torch_device)
4849

4950
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
5051
generator = torch.Generator(device=generator_device).manual_seed(seed)

tests/single_file/test_stable_diffusion_single_file.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
88
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
99
from diffusers.utils.testing_utils import (
10+
backend_empty_cache,
1011
enable_full_determinism,
11-
require_torch_gpu,
12+
require_torch_accelerator,
1213
slow,
14+
torch_device,
1315
)
1416

1517
from .single_file_testing_utils import (
@@ -23,7 +25,7 @@
2325

2426

2527
@slow
26-
@require_torch_gpu
28+
@require_torch_accelerator
2729
class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
2830
pipeline_class = StableDiffusionPipeline
2931
ckpt_path = "https://huggingface.co/Jiali/stable-diffusion-1.5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -35,12 +37,12 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
3537
def setUp(self):
3638
super().setUp()
3739
gc.collect()
38-
torch.cuda.empty_cache()
40+
backend_empty_cache(torch_device)
3941

4042
def tearDown(self):
4143
super().tearDown()
4244
gc.collect()
43-
torch.cuda.empty_cache()
45+
backend_empty_cache(torch_device)
4446

4547
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
4648
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -93,12 +95,12 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
9395
def setUp(self):
9496
super().setUp()
9597
gc.collect()
96-
torch.cuda.empty_cache()
98+
backend_empty_cache(torch_device)
9799

98100
def tearDown(self):
99101
super().tearDown()
100102
gc.collect()
101-
torch.cuda.empty_cache()
103+
backend_empty_cache(torch_device)
102104

103105
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
104106
generator = torch.Generator(device=generator_device).manual_seed(seed)

0 commit comments

Comments
 (0)