4646 get_python_version ,
4747 is_torch_compile ,
4848 require_torch_2 ,
49- require_torch_accelerator ,
5049 require_torch_accelerator_with_training ,
50+ require_torch_gpu ,
5151 require_torch_multi_gpu ,
5252 run_test_in_subprocess ,
5353 torch_device ,
@@ -405,7 +405,7 @@ def test_set_xformers_attn_processor_for_determinism(self):
405405 assert torch .allclose (output , output_3 , atol = self .base_precision )
406406 assert torch .allclose (output_2 , output_3 , atol = self .base_precision )
407407
408- @require_torch_accelerator
408+ @require_torch_gpu
409409 def test_set_attn_processor_for_determinism (self ):
410410 if self .uses_custom_attn_processor :
411411 return
@@ -752,7 +752,7 @@ def test_deprecated_kwargs(self):
752752 " from `_deprecated_kwargs = [<deprecated_argument>]`"
753753 )
754754
755- @require_torch_accelerator
755+ @require_torch_gpu
756756 def test_cpu_offload (self ):
757757 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
758758 model = self .model_class (** config ).eval ()
@@ -782,7 +782,7 @@ def test_cpu_offload(self):
782782
783783 self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
784784
785- @require_torch_accelerator
785+ @require_torch_gpu
786786 def test_disk_offload_without_safetensors (self ):
787787 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
788788 model = self .model_class (** config ).eval ()
@@ -816,7 +816,7 @@ def test_disk_offload_without_safetensors(self):
816816
817817 self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
818818
819- @require_torch_accelerator
819+ @require_torch_gpu
820820 def test_disk_offload_with_safetensors (self ):
821821 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
822822 model = self .model_class (** config ).eval ()
@@ -875,7 +875,7 @@ def test_model_parallelism(self):
875875
876876 self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
877877
878- @require_torch_accelerator
878+ @require_torch_gpu
879879 def test_sharded_checkpoints (self ):
880880 torch .manual_seed (0 )
881881 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
@@ -907,7 +907,7 @@ def test_sharded_checkpoints(self):
907907
908908 self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
909909
910- @require_torch_accelerator
910+ @require_torch_gpu
911911 def test_sharded_checkpoints_with_variant (self ):
912912 torch .manual_seed (0 )
913913 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
@@ -944,7 +944,7 @@ def test_sharded_checkpoints_with_variant(self):
944944
945945 self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
946946
947- @require_torch_accelerator
947+ @require_torch_gpu
948948 def test_sharded_checkpoints_device_map (self ):
949949 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
950950 model = self .model_class (** config ).eval ()
@@ -1066,4 +1066,4 @@ def test_push_to_hub_library_name(self):
10661066 assert model_card .library_name == "diffusers"
10671067
10681068 # Reset repo
1069- delete_repo (self .repo_id , token = TOKEN )
1069+ delete_repo (self .repo_id , token = TOKEN )
0 commit comments