Skip to content

Commit 88919c0

Browse files
committed
fix empty cache: replace per-device `torch.cuda.empty_cache()` / `torch.xpu.empty_cache()` branches with the device-agnostic `backend_empty_cache(torch_device)` test helper
1 parent 8d0f387 commit 88919c0

File tree

10 files changed

+29
-73
lines changed

10 files changed

+29
-73
lines changed

tests/pipelines/animatediff/test_animatediff.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from diffusers.models.attention import FreeNoiseTransformerBlock
2121
from diffusers.utils import is_xformers_available, logging
2222
from diffusers.utils.testing_utils import (
23+
backend_empty_cache,
2324
numpy_cosine_similarity_distance,
2425
require_accelerator,
2526
require_torch_accelerator,
@@ -553,19 +554,13 @@ def setUp(self):
553554
# clean up the VRAM before each test
554555
super().setUp()
555556
gc.collect()
556-
if torch_device == "cuda":
557-
torch.cuda.empty_cache()
558-
elif torch_device == "xpu":
559-
torch.xpu.empty_cache()
557+
backend_empty_cache(torch_device)
560558

561559
def tearDown(self):
562560
# clean up the VRAM after each test
563561
super().tearDown()
564562
gc.collect()
565-
if torch_device == "cuda":
566-
torch.cuda.empty_cache()
567-
elif torch_device == "xpu":
568-
torch.xpu.empty_cache()
563+
backend_empty_cache(torch_device)
569564

570565
def test_animatediff(self):
571566
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

tests/pipelines/cogvideo/test_cogvideox_image2video.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
2525
from diffusers.utils import load_image
2626
from diffusers.utils.testing_utils import (
27+
backend_empty_cache,
2728
enable_full_determinism,
2829
numpy_cosine_similarity_distance,
2930
require_torch_accelerator,
@@ -351,18 +352,12 @@ class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
351352
def setUp(self):
352353
super().setUp()
353354
gc.collect()
354-
if torch_device == "cuda":
355-
torch.cuda.empty_cache()
356-
elif torch_device == "xpu":
357-
torch.xpu.empty_cache()
355+
backend_empty_cache(torch_device)
358356

359357
def tearDown(self):
360358
super().tearDown()
361359
gc.collect()
362-
if torch_device == "cuda":
363-
torch.cuda.empty_cache()
364-
elif torch_device == "xpu":
365-
torch.xpu.empty_cache()
360+
backend_empty_cache(torch_device)
366361

367362
def test_cogvideox(self):
368363
generator = torch.Generator("cpu").manual_seed(0)

tests/pipelines/controlnet/test_controlnet.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
3535
from diffusers.utils.import_utils import is_xformers_available
3636
from diffusers.utils.testing_utils import (
37+
backend_empty_cache,
3738
enable_full_determinism,
3839
get_python_version,
3940
is_torch_compile,
@@ -705,18 +706,12 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
705706
def setUp(self):
706707
super().setUp()
707708
gc.collect()
708-
if torch_device == "cuda":
709-
torch.cuda.empty_cache()
710-
elif torch_device == "xpu":
711-
torch.xpu.empty_cache()
709+
backend_empty_cache(torch_device)
712710

713711
def tearDown(self):
714712
super().tearDown()
715713
gc.collect()
716-
if torch_device == "cuda":
717-
torch.cuda.empty_cache()
718-
elif torch_device == "xpu":
719-
torch.xpu.empty_cache()
714+
backend_empty_cache(torch_device)
720715

721716
def test_canny(self):
722717
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

tests/pipelines/controlnet/test_controlnet_sdxl.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
3636
from diffusers.utils.import_utils import is_xformers_available
3737
from diffusers.utils.testing_utils import (
38+
backend_empty_cache,
3839
enable_full_determinism,
3940
load_image,
4041
require_torch_accelerator,
@@ -894,18 +895,12 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
894895
def setUp(self):
895896
super().setUp()
896897
gc.collect()
897-
if torch_device == "cuda":
898-
torch.cuda.empty_cache()
899-
elif torch_device == "xpu":
900-
torch.xpu.empty_cache()
898+
backend_empty_cache(torch_device)
901899

902900
def tearDown(self):
903901
super().tearDown()
904902
gc.collect()
905-
if torch_device == "cuda":
906-
torch.cuda.empty_cache()
907-
elif torch_device == "xpu":
908-
torch.xpu.empty_cache()
903+
backend_empty_cache(torch_device)
909904

910905
def test_canny(self):
911906
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")

tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from diffusers.models import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
3030
from diffusers.utils import load_image
3131
from diffusers.utils.testing_utils import (
32+
backend_empty_cache,
3233
enable_full_determinism,
3334
require_torch_accelerator,
3435
slow,
@@ -185,18 +186,12 @@ class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
185186
def setUp(self):
186187
super().setUp()
187188
gc.collect()
188-
if torch_device == "cuda":
189-
torch.cuda.empty_cache()
190-
elif torch_device == "xpu":
191-
torch.xpu.empty_cache()
189+
backend_empty_cache(torch_device)
192190

193191
def tearDown(self):
194192
super().tearDown()
195193
gc.collect()
196-
if torch_device == "cuda":
197-
torch.cuda.empty_cache()
198-
elif torch_device == "xpu":
199-
torch.xpu.empty_cache()
194+
backend_empty_cache(torch_device)
200195

201196
def test_canny(self):
202197
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(

tests/pipelines/controlnet_xs/test_controlnetxs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from diffusers.utils.import_utils import is_xformers_available
3636
from diffusers.utils.testing_utils import (
37+
backend_empty_cache,
3738
enable_full_determinism,
3839
is_torch_compile,
3940
load_image,
@@ -339,10 +340,7 @@ class ControlNetXSPipelineSlowTests(unittest.TestCase):
339340
def tearDown(self):
340341
super().tearDown()
341342
gc.collect()
342-
if torch_device == "cuda":
343-
torch.cuda.empty_cache()
344-
elif torch_device == "xpu":
345-
torch.xpu.empty_cache()
343+
backend_empty_cache(torch_device)
346344

347345
def test_canny(self):
348346
controlnet = ControlNetXSAdapter.from_pretrained(

tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from diffusers.utils.import_utils import is_xformers_available
3434
from diffusers.utils.testing_utils import (
35+
backend_empty_cache,
3536
enable_full_determinism,
3637
load_image,
3738
require_torch_accelerator,
@@ -380,10 +381,7 @@ class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
380381
def tearDown(self):
381382
super().tearDown()
382383
gc.collect()
383-
if torch_device == "cuda":
384-
torch.cuda.empty_cache()
385-
elif torch_device == "xpu":
386-
torch.xpu.empty_cache()
384+
backend_empty_cache(torch_device)
387385

388386
def test_canny(self):
389387
controlnet = ControlNetXSAdapter.from_pretrained(

tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from diffusers.models.attention_processor import AttnAddedKVProcessor
2424
from diffusers.utils.import_utils import is_xformers_available
2525
from diffusers.utils.testing_utils import (
26+
backend_empty_cache,
2627
floats_tensor,
2728
load_numpy,
2829
require_accelerator,
@@ -105,10 +106,7 @@ def setUp(self):
105106
# clean up the VRAM before each test
106107
super().setUp()
107108
gc.collect()
108-
if torch_device == "cuda":
109-
torch.cuda.empty_cache()
110-
elif torch_device == "xpu":
111-
torch.xpu.empty_cache()
109+
backend_empty_cache(torch_device)
112110

113111
def tearDown(self):
114112
# clean up the VRAM after each test

tests/pipelines/i2vgen_xl/test_i2vgenxl.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from diffusers.models.unets import I2VGenXLUNet
3737
from diffusers.utils import is_xformers_available, load_image
3838
from diffusers.utils.testing_utils import (
39+
backend_empty_cache,
3940
enable_full_determinism,
4041
floats_tensor,
4142
numpy_cosine_similarity_distance,
@@ -232,19 +233,13 @@ def setUp(self):
232233
# clean up the VRAM before each test
233234
super().setUp()
234235
gc.collect()
235-
if torch_device == "cuda":
236-
torch.cuda.empty_cache()
237-
elif torch_device == "xpu":
238-
torch.xpu.empty_cache()
236+
backend_empty_cache(torch_device)
239237

240238
def tearDown(self):
241239
# clean up the VRAM after each test
242240
super().tearDown()
243241
gc.collect()
244-
if torch_device == "cuda":
245-
torch.cuda.empty_cache()
246-
elif torch_device == "xpu":
247-
torch.xpu.empty_cache()
242+
backend_empty_cache(torch_device)
248243

249244
def test_i2vgen_xl(self):
250245
pipe = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")

tests/pipelines/test_pipelines.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
)
6767
from diffusers.utils.testing_utils import (
6868
CaptureLogger,
69+
backend_empty_cache,
6970
enable_full_determinism,
7071
floats_tensor,
7172
get_python_version,
@@ -1820,19 +1821,13 @@ def setUp(self):
18201821
# clean up the VRAM before each test
18211822
super().setUp()
18221823
gc.collect()
1823-
if torch_device == "cuda":
1824-
torch.cuda.empty_cache()
1825-
elif torch_device == "xpu":
1826-
torch.xpu.empty_cache()
1824+
backend_empty_cache(torch_device)
18271825

18281826
def tearDown(self):
18291827
# clean up the VRAM after each test
18301828
super().tearDown()
18311829
gc.collect()
1832-
if torch_device == "cuda":
1833-
torch.cuda.empty_cache()
1834-
elif torch_device == "xpu":
1835-
torch.xpu.empty_cache()
1830+
backend_empty_cache(torch_device)
18361831

18371832
def test_smart_download(self):
18381833
model_id = "hf-internal-testing/unet-pipeline-dummy"
@@ -2057,16 +2052,13 @@ def setUp(self):
20572052
# clean up the VRAM before each test
20582053
super().setUp()
20592054
gc.collect()
2060-
if torch_device == "cuda":
2061-
torch.cuda.empty_cache()
2062-
elif torch_device == "xpu":
2063-
torch.xpu.empty_cache()
2055+
backend_empty_cache(torch_device)
20642056

20652057
def tearDown(self):
20662058
# clean up the VRAM after each test
20672059
super().tearDown()
20682060
gc.collect()
2069-
torch.cuda.empty_cache()
2061+
backend_empty_cache(torch_device)
20702062

20712063
def test_ddpm_ddim_equality_batched(self):
20722064
seed = 0

0 commit comments

Comments (0)