@@ -66,6 +66,7 @@
 )
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     get_python_version,
@@ -1820,19 +1821,13 @@ def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        if torch_device == "cuda":
-            torch.cuda.empty_cache()
-        elif torch_device == "xpu":
-            torch.xpu.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        if torch_device == "cuda":
-            torch.cuda.empty_cache()
-        elif torch_device == "xpu":
-            torch.xpu.empty_cache()
+        backend_empty_cache(torch_device)

     def test_smart_download(self):
         model_id = "hf-internal-testing/unet-pipeline-dummy"
@@ -2057,16 +2052,13 @@ def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        if torch_device == "cuda":
-            torch.cuda.empty_cache()
-        elif torch_device == "xpu":
-            torch.xpu.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_ddpm_ddim_equality_batched(self):
         seed = 0
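
For context: `backend_empty_cache` collapses the per-device `empty_cache()` branches deleted above into a single call. Below is a minimal sketch of that dispatch, inferred only from the branches it replaces; the real helper in `diffusers.utils.testing_utils` may cover additional backends.

```python
import torch

def backend_empty_cache(device: str) -> None:
    # Sketch of the device dispatch this PR relies on, reconstructed
    # from the removed if/elif branches; the actual testing_utils
    # helper may handle more accelerators as well.
    if device == "cuda":
        torch.cuda.empty_cache()  # release cached CUDA allocator blocks
    elif device == "xpu":
        torch.xpu.empty_cache()  # same for Intel XPU
    # On CPU (and any other backend) there is no allocator cache to
    # flush, so falling through as a no-op keeps tests device-agnostic.
```

Centralizing this also fixes the `tearDown` in the second hunk (old line 2069), which previously cleared only the CUDA cache even when the test suite ran on XPU.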