1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import gc
1617import random
1718import unittest
1819
2223from diffusers .models .attention_processor import AttnAddedKVProcessor
2324from diffusers .utils .import_utils import is_xformers_available
2425from diffusers .utils .testing_utils import (
26+ backend_empty_cache ,
27+ backend_max_memory_allocated ,
28+ backend_reset_max_memory_allocated ,
29+ backend_reset_peak_memory_stats ,
2530 floats_tensor ,
26- flush_memory ,
2731 load_numpy ,
2832 require_accelerator ,
2933 require_torch_accelerator ,
@@ -104,12 +108,14 @@ class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
104108 def setUp (self ):
105109 # clean up the VRAM before each test
106110 super ().setUp ()
107- flush_memory (torch_device , gc_collect = True )
111+ gc .collect ()
112+ backend_empty_cache (torch_device )
108113
109114 def tearDown (self ):
110115 # clean up the VRAM after each test
111116 super ().tearDown ()
112- flush_memory (torch_device , gc_collect = True )
117+ gc .collect ()
118+ backend_empty_cache (torch_device )
113119
114120 def test_if_img2img_superresolution (self ):
115121 pipe = IFImg2ImgSuperResolutionPipeline .from_pretrained (
@@ -120,7 +126,9 @@ def test_if_img2img_superresolution(self):
120126 pipe .unet .set_attn_processor (AttnAddedKVProcessor ())
121127 pipe .enable_model_cpu_offload (device = torch_device )
122128
123- flush_memory (torch_device , reset_mem_stats = True )
129+ backend_reset_max_memory_allocated (torch_device )
130+ backend_empty_cache (torch_device )
131+ backend_reset_peak_memory_stats (torch_device )
124132
125133 generator = torch .Generator (device = "cpu" ).manual_seed (0 )
126134
@@ -140,10 +148,7 @@ def test_if_img2img_superresolution(self):
140148
141149 assert image .shape == (256 , 256 , 3 )
142150
143- if torch_device == "cuda" :
144- mem_bytes = torch .cuda .max_memory_allocated ()
145- elif torch_device == "xpu" :
146- mem_bytes = torch .xpu .max_memory_allocated ()
151+ mem_bytes = backend_max_memory_allocated (torch_device )
147152
148153 assert mem_bytes < 12 * 10 ** 9
149154
0 commit comments