
Commit e67c25a

Commit message: update

1 parent: 391b5a9

File tree: 4 files changed (+25, -15 lines)


src/diffusers/__init__.py

Lines changed: 5 additions & 9 deletions

@@ -31,7 +31,7 @@
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig", "GGUFQuantizationConfig"],
+    "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -428,8 +428,7 @@
     if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from .utils import \
-        dummy_torch_and_transformers_and_k_diffusion_objects  # noqa F403
+    from .utils import dummy_torch_and_transformers_and_k_diffusion_objects  # noqa F403

     _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
         name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
@@ -442,8 +441,7 @@
     if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from .utils import \
-        dummy_torch_and_transformers_and_sentencepiece_objects  # noqa F403
+    from .utils import dummy_torch_and_transformers_and_sentencepiece_objects  # noqa F403

     _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
         name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
@@ -456,8 +454,7 @@
     if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from .utils import \
-        dummy_torch_and_transformers_and_onnx_objects  # noqa F403
+    from .utils import dummy_torch_and_transformers_and_onnx_objects  # noqa F403

     _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
         name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
@@ -492,8 +489,7 @@
     if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from .utils import \
-        dummy_transformers_and_torch_and_note_seq_objects  # noqa F403
+    from .utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403

     _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
         name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
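Note: each hunk above only collapses a backslash-continued import onto one line; the surrounding lazy-import guard is unchanged. A condensed sketch of that guard pattern for context; the `else` branch and the pipeline name are an assumption based on the file's usual layout, not taken from this hunk:

```python
# Condensed sketch of the optional-dependency guard in src/diffusers/__init__.py
# (an approximation for illustration, not a verbatim copy of the module).
from diffusers.utils import (
    OptionalDependencyNotAvailable,
    is_k_diffusion_available,
    is_torch_available,
    is_transformers_available,
)

_import_structure = {"pipelines": []}  # the real module builds a much larger mapping

try:
    # Probe every backend the k-diffusion pipelines need.
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # A backend is missing: expose dummy objects that raise a helpful ImportError on use.
    from diffusers.utils import dummy_torch_and_transformers_and_k_diffusion_objects  # noqa: F403

    _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
        name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
    ]
else:
    # Backends present: register the real pipeline for lazy import (name illustrative).
    _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
```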

src/diffusers/quantizers/quantization_config.py

Lines changed: 15 additions & 5 deletions

@@ -395,19 +395,29 @@ def to_diff_dict(self) -> Dict[str, Any]:
         return serializable_config_dict


+@dataclass
 class GGUFQuantizationConfig(QuantizationConfigMixin):
-    def __init__(self, compute_dtype=None, quant_storage=None, modules_to_not_convert=None):
+    """This is a config class for GGUF Quantization techniques.
+
+    Args:
+        compute_dtype: (`torch.dtype`, defaults to `torch.float32`):
+            This sets the computational type which might be different than the input type. For example, inputs might be
+            fp32, but computation can be set to bf16 for speedups.
+
+    """
+
+    def __init__(self, compute_dtype: torch.dtype = None):
         self.quant_method = QuantizationMethod.GGUF
         self.compute_dtype = compute_dtype
-        self.quant_storage = quant_storage
         self.pre_quantized = True
-        self.modules_to_not_convert = modules_to_not_convert
+
+        # TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints.
+        self.modules_to_not_convert = []

         if self.compute_dtype is None:
             self.compute_dtype = torch.float32

-        if self.quant_storage is None:
-            self.quant_storage = torch.uint8
+
 @dataclass
 class TorchAoConfig(QuantizationConfigMixin):
     """This is a config class for torchao quantization/sparsity techniques.

src/diffusers/utils/import_utils.py

Lines changed: 4 additions & 1 deletion

@@ -479,6 +479,8 @@ def is_imageio_available():

 def is_gguf_available():
     return _is_gguf_available
+
+
 def is_torchao_available():
     return _is_torchao_available

@@ -622,7 +624,8 @@ def is_torchao_available():
 """

 TORCHAO_IMPORT_ERROR = """
-{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao`
+{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
+torchao`
 """

 BACKENDS_MAPPING = OrderedDict(
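As a side note, the `is_*_available()` checks and `*_IMPORT_ERROR` templates feed the backend-gating machinery in `BACKENDS_MAPPING`, and callers can also guard on the checks directly. A hedged sketch of such a guard, assuming `is_gguf_available` is re-exported from `diffusers.utils` like the other availability helpers; the error text below is illustrative, not the template from this file:

```python
from diffusers.utils import is_gguf_available

# Fail early with an actionable message instead of a deep ImportError later on.
if not is_gguf_available():
    raise ImportError(
        "Loading GGUF checkpoints requires the gguf library; install it with `pip install gguf`."
    )
```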

src/diffusers/utils/testing_utils.py

Lines changed: 1 addition & 0 deletions

@@ -487,6 +487,7 @@ def decorator(test_case):
             correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
         )(test_case)

+
 def require_torchao_version_greater(torchao_version):
     def decorator(test_case):
         correct_torchao_version = is_torchao_available() and version.parse(
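The decorator continued in that hunk wraps the test in `unittest.skipUnless(...)`, so the test is skipped whenever torchao is missing or too old. A small usage sketch; the version string, quant type, and test body are illustrative:

```python
import unittest

from diffusers.utils.testing_utils import require_torchao_version_greater


class TorchAoConfigTests(unittest.TestCase):
    @require_torchao_version_greater("0.6.0")  # illustrative minimum version
    def test_config_instantiates(self):
        from diffusers import TorchAoConfig

        # "int8_weight_only" is one of torchao's quant types; any supported type would do.
        config = TorchAoConfig("int8_weight_only")
        self.assertIsNotNone(config)


if __name__ == "__main__":
    unittest.main()
```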
