@@ -36,6 +36,9 @@
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     enable_full_determinism,
     floats_tensor,
     is_peft_available,
@@ -1002,7 +1005,7 @@ def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1013,7 +1016,7 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1024,7 +1027,7 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
@@ -1039,7 +1042,7 @@ def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
@@ -1054,7 +1057,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, va
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1064,7 +1067,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1164,11 +1167,11 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
 
         return model
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_auto(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice("auto")
@@ -1180,15 +1183,15 @@ def test_set_attention_slice_auto(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_max(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice("max")
@@ -1200,15 +1203,15 @@ def test_set_attention_slice_max(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_int(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         unet = self.get_unet_model()
         unet.set_attention_slice(2)
@@ -1220,15 +1223,15 @@ def test_set_attention_slice_int(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attention_slice_list(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)
 
         # there are 32 sliceable layers
         slice_list = 16 * [2, 3]
@@ -1242,7 +1245,7 @@ def test_set_attention_slice_list(self):
         with torch.no_grad():
             _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
 
-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
 
         assert mem_bytes < 5 * 10**9
 
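For reference, the `backend_*` helpers imported at the top of the diff come from `diffusers.utils.testing_utils`: they route each memory-stat call to whatever backend `torch_device` names, which is why the tests above can drop the hard-coded `torch.cuda` calls. The snippet below is a minimal sketch of that dispatch pattern, assuming hypothetical CUDA/CPU-only tables and a hypothetical `_dispatch` helper; it illustrates the idea, not the library's actual implementation.

import torch

# Hypothetical dispatch tables: map a device type to the backend call,
# or to a constant fallback for backends without that capability.
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None}
BACKEND_MAX_MEMORY_ALLOCATED = {"cuda": torch.cuda.max_memory_allocated, "cpu": 0}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"cuda": torch.cuda.reset_max_memory_allocated, "cpu": None}
BACKEND_RESET_PEAK_MEMORY_STATS = {"cuda": torch.cuda.reset_peak_memory_stats, "cpu": None}


def _dispatch(device, table, *args):
    # Normalize "cuda:0" -> "cuda" so indexed devices hit the right entry.
    device_type = torch.device(device).type
    fn = table.get(device_type)
    if fn is None:
        return None  # no-op on backends that lack this capability
    if not callable(fn):
        return fn  # constant fallback, e.g. 0 bytes allocated on CPU
    return fn(*args)


def backend_empty_cache(torch_device):
    _dispatch(torch_device, BACKEND_EMPTY_CACHE)


def backend_max_memory_allocated(torch_device):
    return _dispatch(torch_device, BACKEND_MAX_MEMORY_ALLOCATED)


def backend_reset_max_memory_allocated(torch_device):
    _dispatch(torch_device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)


def backend_reset_peak_memory_stats(torch_device):
    _dispatch(torch_device, BACKEND_RESET_PEAK_MEMORY_STATS)

Other accelerators (an XPU or MPS build of PyTorch, for instance) would register their equivalents in the same tables, which is also what `@require_torch_accelerator` relies on: the test is skipped only when no supported accelerator is present, rather than whenever CUDA specifically is missing.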