
Commit f7e85ae

conditionally check if compute capability is met.
1 parent c4d4ac2 commit f7e85ae

2 files changed: +14 −1 lines

src/diffusers/utils/torch_utils.py

Lines changed: 10 additions & 0 deletions
@@ -149,3 +149,13 @@ def apply_freeu(
         res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
 
     return hidden_states, res_hidden_states
+
+
+def get_torch_cuda_device_capability():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = f"{compute_capability[0]}.{compute_capability[1]}"
+        return float(compute_capability)
+    else:
+        return None
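
For context, the helper reads the CUDA compute capability via torch.cuda.get_device_capability and returns it as a float ("major.minor"), or None when no CUDA device is available. A minimal usage sketch (illustrative only, not part of this commit; the printed values are examples):

from diffusers.utils.torch_utils import get_torch_cuda_device_capability

capability = get_torch_cuda_device_capability()
# Typical return values: 7.5 on a Tesla T4, 8.0 on an A100, 8.9 on Ada-generation GPUs,
# or None on a CPU-only machine.
print(capability)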

tests/models/test_modeling_common.py

Lines changed: 4 additions & 1 deletion
@@ -68,6 +68,7 @@
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability
 
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
@@ -1412,10 +1413,12 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )
 
+        compute_capability = get_torch_cuda_device_capability()
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
-        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        if compute_capability < 8.9:
+            self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
         # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
         # bytes. This only happens for some models, so we allow a small tolerance.
         # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
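
One note on the gate above (an observation, not something this commit changes): get_torch_cuda_device_capability() returns None on a CPU-only host, and None < 8.9 raises a TypeError in Python 3. The test presumably only runs where CUDA is available, but a None-safe variant of the same gate, as a sketch in the context of the test body above, would be:

        compute_capability = get_torch_cuda_device_capability()
        # Run the bf16-vs-fp32 peak-memory assertion only on pre-Ada GPUs (capability < 8.9);
        # skip it when no CUDA device is reported.
        if compute_capability is not None and compute_capability < 8.9:
            self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)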
