Skip to content

Commit a8bcb03

Browse files
committed
replace with is_torchao_version
1 parent f0b1b11 commit a8bcb03

File tree

2 files changed

+6
-15
lines changed

2 files changed

+6
-15
lines changed

src/diffusers/quantizers/quantization_config.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from packaging import version
3636

37-
from ..utils import is_torch_available, is_torchao_available, logging
37+
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
3838

3939

4040
if is_torch_available():
@@ -516,7 +516,6 @@ def __init__(
516516

517517
def post_init(self):
518518
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
519-
AO_VERSION = self._get_ao_version()
520519

521520
if isinstance(self.quant_type, str):
522521
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
@@ -546,7 +545,7 @@ def post_init(self):
546545
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
547546
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
548547
)
549-
elif AO_VERSION > version.parse("0.9.0"):
548+
elif is_torchao_version(">", "0.9.0"):
550549
from torchao.quantization.quant_api import AOBaseConfig
551550

552551
if not isinstance(self.quant_type, AOBaseConfig):
@@ -590,8 +589,8 @@ def to_dict(self):
590589
@classmethod
591590
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
592591
"""Create configuration from a dictionary."""
593-
ao_version = cls._get_ao_version()
594-
assert ao_version > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
592+
if not is_torchao_version(">", "0.9.0"):
593+
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
595594
config_dict = config_dict.copy()
596595
quant_type = config_dict.pop("quant_type")
597596

@@ -611,14 +610,6 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
611610

612611
return cls(quant_type=quant_type, **config_dict)
613612

614-
@staticmethod
615-
def _get_ao_version() -> version.Version:
616-
"""Centralized check for TorchAO availability and version requirements."""
617-
if not is_torchao_available():
618-
raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")
619-
620-
return version.parse(importlib.metadata.version("torchao"))
621-
622613
@classmethod
623614
def _get_torchao_quant_type_to_method(cls):
624615
r"""

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
235235
elif quant_type.startswith("float") or quant_type.startswith("fp"):
236236
return torch.bfloat16
237237

238-
elif self.quantization_config._get_ao_version() > version.Version("0.9.0"):
238+
elif is_torchao_version(">", "0.9.0"):
239239
from torchao.core.config import AOBaseConfig
240240

241241
quant_type = self.quantization_config.quant_type
@@ -332,7 +332,7 @@ def get_cuda_warm_up_factor(self):
332332
# Original mapping for non-AOBaseConfig types
333333
# For the uint types, this is a best guess. Once these types become more used
334334
# we can look into their nuances.
335-
if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
335+
if is_torchao_version(">", "0.9.0"):
336336
from torchao.core.config import AOBaseConfig
337337

338338
quant_type = self.quantization_config.quant_type

0 commit comments

Comments
 (0)