diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7eda13716025..2a3feae967d7 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -101,6 +101,8 @@ mps_backend_registered = hasattr(torch.backends, "mps") torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device + from .torch_utils import get_torch_cuda_device_capability + def torch_all_close(a, b, *args, **kwargs): if not is_torch_available(): @@ -282,6 +284,20 @@ def require_torch_gpu(test_case): ) +def require_torch_cuda_compatibility(expected_compute_capability): + def decorator(test_case): + if not torch.cuda.is_available(): + return unittest.skip(test_case) + else: + current_compute_capability = get_torch_cuda_device_capability() + return unittest.skipUnless( + float(current_compute_capability) == float(expected_compute_capability), + "Test not supported for this compute capability.", + ) + + return decorator + + # These decorators are for accelerator-specific behaviours that are not GPU-specific def require_torch_accelerator(test_case): """Decorator marking a test that requires an accelerator backend and PyTorch.""" diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py index 51ca0bfdc0ab..9eb6958d2183 100644 --- a/tests/quantization/quanto/test_quanto.py +++ b/tests/quantization/quanto/test_quanto.py @@ -10,6 +10,7 @@ numpy_cosine_similarity_distance, require_accelerate, require_big_gpu_with_torch_cuda, + require_torch_cuda_compatibility, torch_device, ) @@ -311,6 +312,7 @@ def get_dummy_init_kwargs(self): return {"weights_dtype": "int8"} +@require_torch_cuda_compatibility(8.0) class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): expected_memory_reduction = 0.55 @@ -318,6 +320,7 @@ def get_dummy_init_kwargs(self): return {"weights_dtype": "int4"} +@require_torch_cuda_compatibility(8.0) class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): expected_memory_reduction = 0.65