Skip to content

Commit 67f1700

Browse files
committed
update
1 parent 66ae46e commit 67f1700

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def validate_environment(self, *args, **kwargs):
4949
"To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
5050
)
5151

52+
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
5253
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
5354
# need more space for buffers that are created during quantization
5455
max_memory = {key: val * 0.90 for key, val in max_memory.items()}

src/diffusers/utils/testing_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
is_bitsandbytes_available,
3333
is_compel_available,
3434
is_flax_available,
35+
is_gguf_available,
3536
is_note_seq_available,
3637
is_onnx_available,
3738
is_opencv_available,
@@ -476,6 +477,18 @@ def decorator(test_case):
476477
return decorator
477478

478479

480+
def require_gguf_version_greater_or_equal(gguf_version):
    """Decorator marking a test that requires `gguf` at version >= `gguf_version`.

    The test is skipped unless `gguf` is installed and its base version
    (pre-release/local segments stripped) satisfies the comparison.

    Args:
        gguf_version (str): Minimum acceptable `gguf` version, e.g. ``"0.10.0"``.

    Returns:
        A decorator that wraps the test case with `unittest.skipUnless`.
    """

    def decorator(test_case):
        # Compare against base_version so pre-release suffixes (e.g. "0.10.0rc1")
        # don't make an otherwise-new install look older than required.
        correct_gguf_version = is_gguf_available() and version.parse(
            version.parse(importlib.metadata.version("gguf")).base_version
        ) >= version.parse(gguf_version)
        return unittest.skipUnless(
            # The check is `>=`, so the skip message says "greater than or equal to".
            correct_gguf_version,
            f"Test requires gguf with the version greater than or equal to {gguf_version}.",
        )(test_case)

    return decorator
491+
479492
def deprecate_after_peft_backend(test_case):
480493
"""
481494
Decorator marking a test that will be skipped after PEFT backend

tests/quantization/gguf/test_gguf.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,24 @@
44
import torch
55

66
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig
7-
from diffusers.quantizers.gguf.utils import GGUFParameter
87
from diffusers.utils.testing_utils import (
8+
is_gguf_available,
99
nightly,
10+
require_accelerate,
1011
require_big_gpu_with_torch_cuda,
12+
require_gguf_version_greater_or_equal,
1113
torch_device,
1214
)
1315

1416

17+
if is_gguf_available():
18+
from diffusers.quantizers.gguf.utils import GGUFParameter
19+
20+
1521
@nightly
1622
@require_big_gpu_with_torch_cuda
23+
@require_accelerate
24+
@require_gguf_version_greater_or_equal("0.10.0")
1725
class GGUFSingleFileTests(unittest.TestCase):
1826
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
1927
torch_dtype = torch.bfloat16

0 commit comments

Comments
 (0)