|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -import gc |
17 | 16 | import tempfile |
18 | 17 | import traceback |
19 | 18 | import unittest |
|
34 | 33 | from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel |
35 | 34 | from diffusers.utils.import_utils import is_xformers_available |
36 | 35 | from diffusers.utils.testing_utils import ( |
37 | | - backend_empty_cache, |
38 | 36 | enable_full_determinism, |
| 37 | + flush_memory, |
39 | 38 | get_python_version, |
40 | 39 | is_torch_compile, |
41 | 40 | load_image, |
@@ -704,13 +703,11 @@ def test_save_pretrained_raise_not_implemented_exception(self): |
704 | 703 | class ControlNetPipelineSlowTests(unittest.TestCase): |
705 | 704 | def setUp(self): |
706 | 705 | super().setUp() |
707 | | - gc.collect() |
708 | | - backend_empty_cache(torch_device) |
| 706 | + flush_memory(torch_device, gc_collect=True) |
709 | 707 |
|
710 | 708 | def tearDown(self): |
711 | 709 | super().tearDown() |
712 | | - gc.collect() |
713 | | - backend_empty_cache(torch_device) |
| 710 | + flush_memory(torch_device, gc_collect=True) |
714 | 711 |
|
715 | 712 | def test_canny(self): |
716 | 713 | controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") |
@@ -929,14 +926,7 @@ def test_seg(self): |
929 | 926 | assert np.abs(expected_image - image).max() < 8e-2 |
930 | 927 |
|
931 | 928 | def test_sequential_cpu_offloading(self): |
932 | | - if torch_device == "cuda": |
933 | | - torch.cuda.empty_cache() |
934 | | - torch.cuda.reset_max_memory_allocated() |
935 | | - torch.cuda.reset_peak_memory_stats() |
936 | | - elif torch_device == "xpu": |
937 | | - torch.xpu.empty_cache() |
938 | | - torch.xpu.reset_max_memory_allocated() |
939 | | - torch.xpu.reset_peak_memory_stats() |
| 929 | + flush_memory(torch_device, gc_collect=True, reset_mem_stats=True) |
940 | 930 |
|
941 | 931 | controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg") |
942 | 932 |
|
@@ -1077,13 +1067,11 @@ def test_v11_shuffle_global_pool_conditions(self): |
1077 | 1067 | class StableDiffusionMultiControlNetPipelineSlowTests(unittest.TestCase): |
1078 | 1068 | def setUp(self): |
1079 | 1069 | super().setUp() |
1080 | | - gc.collect() |
1081 | | - backend_empty_cache(torch_device) |
| 1070 | + flush_memory(torch_device, gc_collect=True) |
1082 | 1071 |
|
1083 | 1072 | def tearDown(self): |
1084 | 1073 | super().tearDown() |
1085 | | - gc.collect() |
1086 | | - backend_empty_cache(torch_device) |
| 1074 | + flush_memory(torch_device, gc_collect=True) |
1087 | 1075 |
|
1088 | 1076 | def test_pose_and_canny(self): |
1089 | 1077 | controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") |
|
0 commit comments