|
48 | 48 | require_torch_2, |
49 | 49 | require_torch_accelerator_with_training, |
50 | 50 | require_torch_gpu, |
| 51 | + require_torch_accelerator, |
51 | 52 | require_torch_multi_gpu, |
52 | 53 | run_test_in_subprocess, |
53 | 54 | torch_device, |
@@ -405,7 +406,7 @@ def test_set_xformers_attn_processor_for_determinism(self): |
405 | 406 | assert torch.allclose(output, output_3, atol=self.base_precision) |
406 | 407 | assert torch.allclose(output_2, output_3, atol=self.base_precision) |
407 | 408 |
|
408 | | - @require_torch_gpu |
| 409 | + @require_torch_accelerator |
409 | 410 | def test_set_attn_processor_for_determinism(self): |
410 | 411 | if self.uses_custom_attn_processor: |
411 | 412 | return |
@@ -752,7 +753,7 @@ def test_deprecated_kwargs(self): |
752 | 753 | " from `_deprecated_kwargs = [<deprecated_argument>]`" |
753 | 754 | ) |
754 | 755 |
|
755 | | - @require_torch_gpu |
| 756 | + @require_torch_accelerator |
756 | 757 | def test_cpu_offload(self): |
757 | 758 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
758 | 759 | model = self.model_class(**config).eval() |
@@ -782,7 +783,7 @@ def test_cpu_offload(self): |
782 | 783 |
|
783 | 784 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) |
784 | 785 |
|
785 | | - @require_torch_gpu |
| 786 | + @require_torch_accelerator |
786 | 787 | def test_disk_offload_without_safetensors(self): |
787 | 788 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
788 | 789 | model = self.model_class(**config).eval() |
@@ -816,7 +817,7 @@ def test_disk_offload_without_safetensors(self): |
816 | 817 |
|
817 | 818 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) |
818 | 819 |
|
819 | | - @require_torch_gpu |
| 820 | + @require_torch_accelerator |
820 | 821 | def test_disk_offload_with_safetensors(self): |
821 | 822 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
822 | 823 | model = self.model_class(**config).eval() |
@@ -875,7 +876,7 @@ def test_model_parallelism(self): |
875 | 876 |
|
876 | 877 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) |
877 | 878 |
|
878 | | - @require_torch_gpu |
| 879 | + @require_torch_accelerator |
879 | 880 | def test_sharded_checkpoints(self): |
880 | 881 | torch.manual_seed(0) |
881 | 882 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
@@ -907,7 +908,7 @@ def test_sharded_checkpoints(self): |
907 | 908 |
|
908 | 909 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) |
909 | 910 |
|
910 | | - @require_torch_gpu |
| 911 | + @require_torch_accelerator |
911 | 912 | def test_sharded_checkpoints_with_variant(self): |
912 | 913 | torch.manual_seed(0) |
913 | 914 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
@@ -944,7 +945,7 @@ def test_sharded_checkpoints_with_variant(self): |
944 | 945 |
|
945 | 946 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) |
946 | 947 |
|
947 | | - @require_torch_gpu |
| 948 | + @require_torch_accelerator |
948 | 949 | def test_sharded_checkpoints_device_map(self): |
949 | 950 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
950 | 951 | model = self.model_class(**config).eval() |
|
0 commit comments