diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 12eef8899bbb..3c8911773e39 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -149,3 +149,13 @@ def apply_freeu(
         res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
 
     return hidden_states, res_hidden_states
+
+
+def get_torch_cuda_device_capability():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = f"{compute_capability[0]}.{compute_capability[1]}"
+        return float(compute_capability)
+    else:
+        return None
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index b88b6f16b9fb..c3cb082b0ef1 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -68,6 +68,7 @@
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability
 
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
@@ -1384,6 +1385,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
     @require_torch_gpu
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
+        LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
@@ -1412,10 +1414,12 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )
 
+        compute_capability = get_torch_cuda_device_capability()
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
-        # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
-        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
-        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
+        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
+        if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
+            self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
         # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
         # bytes. This only happens for some models, so we allow a small tolerance.
         # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
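
For reference, a minimal sketch (not part of the patch) of how the new helper is meant to be used: it reports the CUDA compute capability of the default GPU as a float, or `None` on machines without CUDA, which is why the test guards the bf16 assertion behind a `>= 8.0` (Ampere or newer) check. The print statements below are illustrative only.

```python
from diffusers.utils.torch_utils import get_torch_cuda_device_capability

# Returns e.g. 7.5 on a Tesla T4, 8.0 on an A100, 8.9 on an Ada GPU,
# or None when torch.cuda.is_available() is False.
capability = get_torch_cuda_device_capability()

# Mirror the test's guard: only run bf16-sensitive checks on Ampere (8.0) or newer.
if capability is not None and capability >= 8.0:
    print(f"Running bf16 memory assertions (compute capability {capability})")
else:
    print("Skipping bf16 memory assertions on this device")
```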