59 | 59 | require_torch_2,
60 | 60 | require_torch_accelerator_with_training,
61 | 61 | require_torch_gpu,
| 62 | + require_torch_accelerator, |
62 | 63 | require_torch_multi_gpu, |
63 | 64 | run_test_in_subprocess, |
64 | 65 | torch_all_close, |
@@ -543,7 +544,7 @@ def test_set_xformers_attn_processor_for_determinism(self):
543 | 544 | assert torch.allclose(output, output_3, atol=self.base_precision)
544 | 545 | assert torch.allclose(output_2, output_3, atol=self.base_precision)
545 | 546
546 | | - @require_torch_gpu |
| 547 | + @require_torch_accelerator |
547 | 548 | def test_set_attn_processor_for_determinism(self): |
548 | 549 | if self.uses_custom_attn_processor: |
549 | 550 | return |
@@ -1068,7 +1069,7 @@ def test_wrong_adapter_name_raises_error(self):
1068 | 1069
1069 | 1070 | self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
1070 | 1071
1071 | | - @require_torch_gpu |
| 1072 | + @require_torch_accelerator |
1072 | 1073 | def test_cpu_offload(self): |
1073 | 1074 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
1074 | 1075 | model = self.model_class(**config).eval() |
@@ -1098,7 +1099,7 @@ def test_cpu_offload(self):
1098 | 1099
1099 | 1100 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
1100 | 1101
1101 | | - @require_torch_gpu |
| 1102 | + @require_torch_accelerator |
1102 | 1103 | def test_disk_offload_without_safetensors(self): |
1103 | 1104 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
1104 | 1105 | model = self.model_class(**config).eval() |
@@ -1132,7 +1133,7 @@ def test_disk_offload_without_safetensors(self):
1132 | 1133
1133 | 1134 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
1134 | 1135
1135 | | - @require_torch_gpu |
| 1136 | + @require_torch_accelerator |
1136 | 1137 | def test_disk_offload_with_safetensors(self): |
1137 | 1138 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
1138 | 1139 | model = self.model_class(**config).eval() |
@@ -1191,7 +1192,7 @@ def test_model_parallelism(self):
1191 | 1192
1192 | 1193 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
1193 | 1194
1194 | | - @require_torch_gpu |
| 1195 | + @require_torch_accelerator |
1195 | 1196 | def test_sharded_checkpoints(self): |
1196 | 1197 | torch.manual_seed(0) |
1197 | 1198 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
@@ -1223,7 +1224,7 @@ def test_sharded_checkpoints(self):
1223 | 1224
1224 | 1225 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
1225 | 1226
1226 | | - @require_torch_gpu |
| 1227 | + @require_torch_accelerator |
1227 | 1228 | def test_sharded_checkpoints_with_variant(self): |
1228 | 1229 | torch.manual_seed(0) |
1229 | 1230 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
@@ -1261,7 +1262,7 @@ def test_sharded_checkpoints_with_variant(self):
1261 | 1262
1262 | 1263 | self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
1263 | 1264
1264 | | - @require_torch_gpu |
| 1265 | + @require_torch_accelerator |
1265 | 1266 | def test_sharded_checkpoints_device_map(self): |
1266 | 1267 | config, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
1267 | 1268 | model = self.model_class(**config).eval() |
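
For context, the decorator swap above loosens the CUDA-only gate so these tests can run on any supported accelerator backend (e.g. CUDA, XPU, MPS) rather than only on an NVIDIA GPU. Below is a minimal sketch of how such skip decorators could be written, assuming a module-level `torch_device` string resolved by the test utilities; this is illustrative only, not the library's actual implementation:

```python
# Illustrative sketch only -- not the library's actual implementation.
# Assumes `torch_device` is a module-level string (e.g. "cuda", "xpu", "mps",
# or "cpu") detected elsewhere in the testing utilities.
import unittest

torch_device = "cuda"  # placeholder; normally resolved at import time


def require_torch_gpu(test_case):
    """Skip the test unless a CUDA device is available."""
    return unittest.skipUnless(torch_device == "cuda", "test requires a CUDA GPU")(test_case)


def require_torch_accelerator(test_case):
    """Skip the test unless any non-CPU accelerator backend is available."""
    return unittest.skipUnless(torch_device != "cpu", "test requires an accelerator")(test_case)
```

Under this reading, replacing `@require_torch_gpu` with `@require_torch_accelerator` keeps the tests skipped on CPU-only runners while letting them execute on non-CUDA accelerators.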