11 changes: 11 additions & 0 deletions src/diffusers/utils/torch_utils.py
@@ -149,3 +149,14 @@ def apply_freeu(
     res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
 
     return hidden_states, res_hidden_states
+
+
+def get_torch_cuda_device_capability():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = f"{compute_capability[0]}.{compute_capability[1]}"
+        return float(compute_capability)
+    else:
+        return None

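For reference, a minimal sketch of how the new helper could be exercised outside this test suite, e.g. to gate bf16-sensitive checks on the GPU generation. The `MIN_BF16_CAPABILITY` threshold and the printed messages are illustrative only; the import path matches the module added above.

```python
# Hypothetical usage sketch of the helper added in this PR.
from diffusers.utils.torch_utils import get_torch_cuda_device_capability

MIN_BF16_CAPABILITY = 8.0  # illustrative cutoff: Ampere and newer report >= 8.0

capability = get_torch_cuda_device_capability()
if capability is not None and capability >= MIN_BF16_CAPABILITY:
    print(f"Compute capability {capability}: running bf16-sensitive checks")
else:
    print("No CUDA device (or pre-Ampere GPU): skipping bf16-sensitive checks")
```

Returning `None` when CUDA is unavailable lets callers skip capability-gated logic entirely, and the float return value keeps threshold comparisons simple: the CI's Tesla T4 reports 7.5, below the 8.0 cutoff used in the test change below.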
10 changes: 7 additions & 3 deletions tests/models/test_modeling_common.py
@@ -68,6 +68,7 @@
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability
 
 from ..others.test_utils import TOKEN, USER, is_staging_test

@@ -1384,6 +1385,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
     @require_torch_gpu
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
+        LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
@@ -1412,10 +1414,12 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )
 
+        compute_capability = get_torch_cuda_device_capability()
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
-        # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
-        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
-        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
+        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
+        if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
+            self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
         # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
         # bytes. This only happens for some models, so we allow a small tolerance.
         # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
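The conditional assertion above relies on peak-memory bookkeeping done by `reset_memory_stats` and `get_memory_usage`, whose bodies fall outside this diff. Below is a minimal, self-contained sketch of that measurement pattern, assuming the helpers use the standard `torch.cuda` peak-memory APIs; the tiny `nn.Linear` model is purely illustrative and not part of the test.

```python
# Sketch of the peak-memory measurement pattern the test depends on (assumed,
# not copied from the test suite).
import gc

import torch


def reset_memory_stats():
    # Clear cached allocations and reset the peak-memory counter.
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def measure_peak_memory_mb(compute_dtype):
    reset_memory_stats()
    model = torch.nn.Linear(1024, 1024).to("cuda", dtype=compute_dtype)
    with torch.no_grad():
        model(torch.randn(8, 1024, device="cuda", dtype=compute_dtype))
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2  # peak usage in MB


if torch.cuda.is_available():
    bf16_peak = measure_peak_memory_mb(torch.bfloat16)
    fp32_peak = measure_peak_memory_mb(torch.float32)
    print(f"bf16 peak: {bf16_peak:.2f} MB, fp32 peak: {fp32_peak:.2f} MB")
```

As the updated NOTE explains, bf16 uses more memory than fp32 on the CI's Tesla T4, which is why the `fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory` assertion is now gated on `LEAST_COMPUTE_CAPABILITY`.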