@@ -36,6 +36,9 @@
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     enable_full_determinism,
     floats_tensor,
     is_peft_available,
@@ -1014,7 +1017,7 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1025,7 +1028,7 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
@@ -1040,7 +1043,7 @@ def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
@@ -1055,7 +1058,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, va
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1065,7 +1068,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1165,11 +1168,11 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
 
         return model
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_auto(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice("auto")
@@ -1181,15 +1184,15 @@ def test_set_attention_slice_auto(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_max(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice("max")
@@ -1201,15 +1204,15 @@ def test_set_attention_slice_max(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
-
+        mem_bytes = backend_max_memory_allocated(torch_device)
+
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_int(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice(2)
@@ -1221,15 +1224,15 @@ def test_set_attention_slice_int(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_list(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         # there are 32 sliceable layers
         slice_list = 16 * [2, 3]
@@ -1243,7 +1246,7 @@ def test_set_attention_slice_list(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
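Why the decorator swap works: `require_torch_gpu` gates a test on CUDA specifically, while `require_torch_accelerator` gates it on any supported torch accelerator backend, which is what lets these tests run on non-NVIDIA hardware. A minimal sketch of the idea behind such a decorator — the real one lives in `diffusers.utils.testing_utils`, and the `torch_device` resolution shown here is a simplified assumption for illustration:

```python
import unittest

import torch

# Simplified assumption: the test session resolves torch_device once; here we
# only probe CUDA and fall back to CPU, while the real helper also covers
# backends such as XPU and MPS.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


def require_torch_accelerator(test_case):
    """Skip `test_case` unless a non-CPU torch backend is available."""
    return unittest.skipUnless(torch_device != "cpu", "test requires an accelerator")(test_case)
```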
0 commit comments
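The `backend_*` helpers play the same role for the memory bookkeeping: instead of hard-coding `torch.cuda.*` calls, they dispatch on the `torch_device` string. A minimal sketch of that dispatch pattern — the table contents and helper names below are illustrative assumptions, not the actual `diffusers.utils.testing_utils` implementation:

```python
import torch

# Illustrative dispatch tables keyed by device type; a real implementation
# would also cover backends such as XPU and MPS.
_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None}
_MAX_MEMORY_ALLOCATED = {"cuda": torch.cuda.max_memory_allocated, "cpu": 0}


def _dispatch(device, table, *args):
    """Call the backend-specific function, or return a constant fallback."""
    fn = table[device]
    return fn(*args) if callable(fn) else fn


def backend_empty_cache(device):
    _dispatch(device, _EMPTY_CACHE)


def backend_max_memory_allocated(device):
    return _dispatch(device, _MAX_MEMORY_ALLOCATED)
```

With this shape, `backend_max_memory_allocated(torch_device)` returns the CUDA peak-allocation counter on GPU runners and a harmless constant elsewhere, so the `mem_bytes < 5 * 10**9` assertions stay backend-agnostic.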