38 | 38 | from diffusers.pipelines.pipeline_utils import StableDiffusionMixin |
39 | 39 | from diffusers.schedulers import KarrasDiffusionSchedulers |
40 | 40 | from diffusers.utils import logging |
41 | | -from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available |
| 41 | +from diffusers.utils.import_utils import is_xformers_available |
42 | 42 | from diffusers.utils.testing_utils import ( |
43 | 43 | CaptureLogger, |
| 44 | + require_accelerate_version_greater, |
| 45 | + require_non_cpu, |
44 | 46 | require_torch, |
45 | 47 | skip_mps, |
46 | 48 | torch_device, |
@@ -770,10 +772,8 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3): |
770 | 772 | type(proc) == AttnProcessor for proc in component.attn_processors.values() |
771 | 773 | ), "`from_pipe` changed the attention processor in original pipeline." |
772 | 774 |
773 | | - @unittest.skipIf( |
774 | | -        torch_device == "cpu" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), |
775 | | - reason="CPU offload is only available with hardware accelerator and `accelerate v0.14.0` or higher", |
776 | | - ) |
| 775 | + @require_non_cpu |
| 776 | + @require_accelerate_version_greater("0.14.0") |
777 | 777 | def test_from_pipe_consistent_forward_pass_cpu_offload(self, expected_max_diff=1e-3): |
778 | 778 | components = self.get_dummy_components() |
779 | 779 | pipe = self.pipeline_class(**components) |
@@ -1201,7 +1201,7 @@ def test_components_function(self): |
1201 | 1201 | self.assertTrue(hasattr(pipe, "components")) |
1202 | 1202 | self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) |
1203 | 1203 |
1204 | | -    @unittest.skipIf(torch_device == "cpu", reason="float16 requires a hardware accelerator") |
| 1204 | + @require_non_cpu |
1205 | 1205 | def test_float16_inference(self, expected_max_diff=5e-2): |
1206 | 1206 | components = self.get_dummy_components() |
1207 | 1207 | pipe = self.pipeline_class(**components) |
@@ -1238,7 +1238,7 @@ def test_float16_inference(self, expected_max_diff=5e-2): |
1238 | 1238 | max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() |
1239 | 1239 | self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.") |
1240 | 1240 |
1241 | | -    @unittest.skipIf(torch_device == "cpu", reason="float16 requires a hardware accelerator") |
| 1241 | + @require_non_cpu |
1242 | 1242 | def test_save_load_float16(self, expected_max_diff=1e-2): |
1243 | 1243 | components = self.get_dummy_components() |
1244 | 1244 | for name, module in components.items(): |
@@ -1319,7 +1319,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): |
1319 | 1319 | max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() |
1320 | 1320 | self.assertLess(max_diff, expected_max_difference) |
1321 | 1321 |
1322 | | -    @unittest.skipIf(torch_device == "cpu", reason="Hardware accelerator and CPU are required to switch devices") |
| 1322 | + @require_non_cpu |
1323 | 1323 | def test_to_device(self): |
1324 | 1324 | components = self.get_dummy_components() |
1325 | 1325 | pipe = self.pipeline_class(**components) |
@@ -1393,10 +1393,8 @@ def _test_attention_slicing_forward_pass( |
1393 | 1393 | assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0])) |
1394 | 1394 | assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0])) |
1395 | 1395 |
1396 | | - @unittest.skipIf( |
1397 | | -        torch_device == "cpu" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), |
1398 | | - reason="CPU offload is only available with hardware accelerator and `accelerate v0.14.0` or higher", |
1399 | | - ) |
| 1396 | + @require_non_cpu |
| 1397 | + @require_accelerate_version_greater("0.14.0") |
1400 | 1398 | def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): |
1401 | 1399 | import accelerate |
1402 | 1400 |
@@ -1456,10 +1454,8 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4): |
1456 | 1454 | f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", |
1457 | 1455 | ) |
1458 | 1456 |
1459 | | - @unittest.skipIf( |
1460 | | -        torch_device == "cpu" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), |
1461 | | - reason="CPU offload is only available with hardware accelerator and `accelerate v0.17.0` or higher", |
1462 | | - ) |
| 1457 | + @require_non_cpu |
| 1458 | + @require_accelerate_version_greater("0.17.0") |
1463 | 1459 | def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): |
1464 | 1460 | import accelerate |
1465 | 1461 |
@@ -1513,10 +1509,8 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4): |
1513 | 1509 | f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", |
1514 | 1510 | ) |
1515 | 1511 |
1516 | | - @unittest.skipIf( |
1517 | | -        torch_device == "cpu" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"), |
1518 | | - reason="CPU offload is only available with hardware accelerator and `accelerate v0.17.0` or higher", |
1519 | | - ) |
| 1512 | + @require_non_cpu |
| 1513 | + @require_accelerate_version_greater("0.17.0") |
1520 | 1514 | def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4): |
1521 | 1515 | import accelerate |
1522 | 1516 |
@@ -1570,10 +1564,8 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4): |
1570 | 1564 | f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}", |
1571 | 1565 | ) |
1572 | 1566 |
1573 | | - @unittest.skipIf( |
1574 | | -        torch_device == "cpu" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), |
1575 | | - reason="CPU offload is only available with hardware accelerator and `accelerate v0.14.0` or higher", |
1576 | | - ) |
| 1567 | + @require_non_cpu |
| 1568 | + @require_accelerate_version_greater("0.14.0") |
1577 | 1569 | def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4): |
1578 | 1570 | import accelerate |
1579 | 1571 |
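
The replacement decorators split apart the two preconditions that each old `unittest.skipIf` guard bundled into one expression, so they can be stacked per test and each skip reports a specific reason. A minimal sketch of how such decorators can be written, reusing `is_accelerate_available`/`is_accelerate_version` from `diffusers.utils.import_utils`; this is an illustrative reconstruction under those assumptions, not the exact `diffusers.utils.testing_utils` implementation:

```python
import unittest

from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version
from diffusers.utils.testing_utils import torch_device


def require_non_cpu(test_case):
    """Skip the test unless a hardware accelerator is available (torch_device != "cpu")."""
    return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)


def require_accelerate_version_greater(accelerate_version):
    """Skip the test unless `accelerate` is installed at a version greater than the one given.

    Assumes a strict comparison, as the name suggests; the old `skipIf` guards used
    `is_accelerate_version("<", ...)`, i.e. they also accepted the exact version.
    """

    def decorator(test_case):
        correct_version = is_accelerate_available() and is_accelerate_version(">", accelerate_version)
        return unittest.skipUnless(
            correct_version, f"test requires accelerate version greater than {accelerate_version}"
        )(test_case)

    return decorator
```

Stacking `@require_non_cpu` with `@require_accelerate_version_greater(...)` keeps each precondition independently reported: a test skipped for lack of an accelerator says exactly that, instead of emitting the combined device-and-version message the old guards used.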