@@ -57,8 +57,8 @@
     get_python_version,
     is_torch_compile,
     require_torch_2,
+    require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
     torch_all_close,
@@ -543,7 +543,7 @@ def test_set_xformers_attn_processor_for_determinism(self): |
         assert torch.allclose(output, output_3, atol=self.base_precision)
         assert torch.allclose(output_2, output_3, atol=self.base_precision)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attn_processor_for_determinism(self):
         if self.uses_custom_attn_processor:
             return
@@ -1068,7 +1068,7 @@ def test_wrong_adapter_name_raises_error(self): |

         self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1098,7 +1098,7 @@ def test_cpu_offload(self): |

         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1132,7 +1132,7 @@ def test_disk_offload_without_safetensors(self): |

         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1191,7 +1191,7 @@ def test_model_parallelism(self): |

         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints(self):
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1223,7 +1223,7 @@ def test_sharded_checkpoints(self): |

         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints_with_variant(self):
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1261,7 +1261,7 @@ def test_sharded_checkpoints_with_variant(self): |

         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
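The change swaps the CUDA-only `@require_torch_gpu` skip decorator for the device-agnostic `@require_torch_accelerator`, so these tests can run on non-CUDA backends as well. As a minimal sketch of how such a decorator can be written, assuming availability checks for CUDA, XPU, and MPS (the real helper lives in `diffusers.utils.testing_utils` and may differ in detail):

```python
# Illustrative sketch only: the actual require_torch_accelerator in
# diffusers.utils.testing_utils may detect backends differently.
import unittest

import torch


def require_torch_accelerator(test_case):
    """Skip a test unless some hardware accelerator is available to torch."""
    has_accelerator = (
        torch.cuda.is_available()  # NVIDIA (and ROCm-backed) GPUs
        or (hasattr(torch, "xpu") and torch.xpu.is_available())  # Intel XPU
        or torch.backends.mps.is_available()  # Apple Silicon
    )
    return unittest.skipUnless(
        has_accelerator, "test requires a hardware accelerator"
    )(test_case)
```

Applied as in the diff above (e.g. on `test_cpu_offload`), the test is skipped only when no supported accelerator is present, rather than whenever CUDA specifically is missing.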