# Patch summary (two files touched).
#
# src/transformers/cache_utils.py
#   - Removes the module-level, version-gated `from optimum.quanto import ...` statement
#     together with the `_is_quanto_greater_than_0_2_5` flag it defined.
#   - Calls `super().__init__()` without passing `self` explicitly in one `__init__`.
#   - Imports `MaxOptimizer`, `qint2`, `qint4` inside another `__init__`, guarded by
#     `is_quanto_greater("0.2.5", accept_dev=True)`, and raises ImportError otherwise.
#     The in-code comment attributes this to a circular import through
#     optimum/quanto/models/transformers_models.py.
#   - Drops the "Detected version {optimum_quanto_version}." fragment from the ImportError
#     message (it was written as a plain string literal, not an f-string).
#   - Imports `quantize_weight` locally inside `_quantize`.
#
# src/transformers/utils/import_utils.py
#   - `is_quanto_greater` now calls `_is_package_available("optimum.quanto")` instead of
#     `_is_package_available("optimum-quanto")`.
#
# NOTE(review): this patch arrived with its line breaks and indentation collapsed. The layout
# below was reconstructed from Python syntax and the hunk line counts in each `@@` header;
# confirm the context lines against the target files before applying.
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 8e3aff3c8bd3..9ef198a76703 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -16,9 +16,6 @@
 )
 
 
-if _is_quanto_greater_than_0_2_5 := is_quanto_greater("0.2.5", accept_dev=True):
-    from optimum.quanto import MaxOptimizer, qint2, qint4, quantize_weight
-
 if is_hqq_available():
     from hqq.core.quantize import Quantizer as HQQQuantizer
 
@@ -558,7 +555,7 @@ def __init__(
         q_group_size: int = 64,
         residual_length: int = 128,
     ):
-        super().__init__(self)
+        super().__init__()
         self.nbits = nbits
         self.axis_key = axis_key
         self.axis_value = axis_value
@@ -635,10 +632,12 @@ def __init__(
             residual_length=residual_length,
         )
 
-        if not _is_quanto_greater_than_0_2_5:
+        # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
+        if is_quanto_greater("0.2.5", accept_dev=True):
+            from optimum.quanto import MaxOptimizer, qint2, qint4
+        else:
             raise ImportError(
                 "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
-                "Detected version {optimum_quanto_version}."
             )
 
         if self.nbits not in [2, 4]:
@@ -656,6 +655,8 @@ def __init__(
         self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization
 
     def _quantize(self, tensor, axis):
+        from optimum.quanto import quantize_weight
+
         scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
         qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
         return qtensor
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 163ee0410944..1ae4d9d71159 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1286,7 +1286,7 @@ def is_quanto_greater(library_version: str, accept_dev: bool = False):
     given version. If `accept_dev` is True, it will also accept development versions (e.g.
     2.7.0.dev20250320 matches 2.7.0).
     """
-    if not _is_package_available("optimum-quanto"):
+    if not _is_package_available("optimum.quanto"):
         return False
 
     if accept_dev: