Skip to content

Commit 3052847

Browse files
committed
revert
1 parent 28a73ac commit 3052847

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

tests/models/test_modeling_common.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@
4646
get_python_version,
4747
is_torch_compile,
4848
require_torch_2,
49-
require_torch_accelerator,
5049
require_torch_accelerator_with_training,
50+
require_torch_gpu,
5151
require_torch_multi_gpu,
5252
run_test_in_subprocess,
5353
torch_device,
@@ -405,7 +405,7 @@ def test_set_xformers_attn_processor_for_determinism(self):
405405
assert torch.allclose(output, output_3, atol=self.base_precision)
406406
assert torch.allclose(output_2, output_3, atol=self.base_precision)
407407

408-
@require_torch_accelerator
408+
@require_torch_gpu
409409
def test_set_attn_processor_for_determinism(self):
410410
if self.uses_custom_attn_processor:
411411
return
@@ -752,7 +752,7 @@ def test_deprecated_kwargs(self):
752752
" from `_deprecated_kwargs = [<deprecated_argument>]`"
753753
)
754754

755-
@require_torch_accelerator
755+
@require_torch_gpu
756756
def test_cpu_offload(self):
757757
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
758758
model = self.model_class(**config).eval()
@@ -782,7 +782,7 @@ def test_cpu_offload(self):
782782

783783
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
784784

785-
@require_torch_accelerator
785+
@require_torch_gpu
786786
def test_disk_offload_without_safetensors(self):
787787
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
788788
model = self.model_class(**config).eval()
@@ -816,7 +816,7 @@ def test_disk_offload_without_safetensors(self):
816816

817817
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
818818

819-
@require_torch_accelerator
819+
@require_torch_gpu
820820
def test_disk_offload_with_safetensors(self):
821821
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
822822
model = self.model_class(**config).eval()
@@ -875,7 +875,7 @@ def test_model_parallelism(self):
875875

876876
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
877877

878-
@require_torch_accelerator
878+
@require_torch_gpu
879879
def test_sharded_checkpoints(self):
880880
torch.manual_seed(0)
881881
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -907,7 +907,7 @@ def test_sharded_checkpoints(self):
907907

908908
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
909909

910-
@require_torch_accelerator
910+
@require_torch_gpu
911911
def test_sharded_checkpoints_with_variant(self):
912912
torch.manual_seed(0)
913913
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -944,7 +944,7 @@ def test_sharded_checkpoints_with_variant(self):
944944

945945
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
946946

947-
@require_torch_accelerator
947+
@require_torch_gpu
948948
def test_sharded_checkpoints_device_map(self):
949949
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
950950
model = self.model_class(**config).eval()
@@ -1066,4 +1066,4 @@ def test_push_to_hub_library_name(self):
10661066
assert model_card.library_name == "diffusers"
10671067

10681068
# Reset repo
1069-
delete_repo(self.repo_id, token=TOKEN)
1069+
delete_repo(self.repo_id, token=TOKEN)

0 commit comments

Comments
 (0)