Skip to content

Commit cd9c0d6

Browse files
committed
restrict memory tests for quanto for certain schemes.
1 parent 20e4b6a commit cd9c0d6

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@
101101
mps_backend_registered = hasattr(torch.backends, "mps")
102102
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
103103

104+
from .torch_utils import get_torch_cuda_device_capability
105+
104106

105107
def torch_all_close(a, b, *args, **kwargs):
106108
if not is_torch_available():
@@ -282,6 +284,18 @@ def require_torch_gpu(test_case):
282284
)
283285

284286

287+
def require_torch_cuda_compatibility(expected_compute_capability):
    """Decorator factory marking a test that requires a specific CUDA compute capability.

    The wrapped test is skipped when no CUDA device is available, and skipped
    unless the current device's compute capability (major.minor) equals
    ``expected_compute_capability`` (compared as floats, so ``8.0`` matches ``"8.0"``).

    Args:
        expected_compute_capability: the required capability, e.g. ``8.0``.

    Returns:
        A decorator to apply to a test case.
    """

    def decorator(test_case):
        if not torch.cuda.is_available():
            # No CUDA at all: skip unconditionally. unittest.skip(reason) returns a
            # decorator which must then be applied to the test case.
            return unittest.skip("test requires CUDA")(test_case)
        compute_capability = get_torch_cuda_device_capability()
        current_compute_capability = f"{compute_capability[0]}.{compute_capability[1]}"
        # skipUnless requires both the condition and a reason, and its returned
        # decorator must be applied to the test case.
        return unittest.skipUnless(
            float(current_compute_capability) == float(expected_compute_capability),
            f"test requires CUDA compute capability {expected_compute_capability}",
        )(test_case)

    return decorator
297+
298+
285299
# These decorators are for accelerator-specific behaviours that are not GPU-specific
286300
def require_torch_accelerator(test_case):
287301
"""Decorator marking a test that requires an accelerator backend and PyTorch."""

tests/quantization/quanto/test_quanto.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
numpy_cosine_similarity_distance,
1111
require_accelerate,
1212
require_big_gpu_with_torch_cuda,
13+
require_torch_cuda_compatibility,
1314
torch_device,
1415
)
1516

@@ -311,13 +312,19 @@ def get_dummy_init_kwargs(self):
311312
return {"weights_dtype": "int8"}
312313

313314

315+
@require_torch_cuda_compatibility(8.0)
316+
317+
314318
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
315319
expected_memory_reduction = 0.55
316320

317321
def get_dummy_init_kwargs(self):
318322
return {"weights_dtype": "int4"}
319323

320324

325+
@require_torch_cuda_compatibility(8.0)
326+
327+
321328
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
322329
expected_memory_reduction = 0.65
323330

0 commit comments

Comments
 (0)