Skip to content

Commit 191f561

Browse files
manueldepradaburcgokden
authored andcommitted
Fix QuantoQuantizedCache import issues (huggingface#40109)
* fix quantoquantized
1 parent 0a87fce commit 191f561

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/transformers/cache_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
)
1717

1818

19-
if _is_quanto_greater_than_0_2_5 := is_quanto_greater("0.2.5", accept_dev=True):
20-
from optimum.quanto import MaxOptimizer, qint2, qint4, quantize_weight
21-
2219
if is_hqq_available():
2320
from hqq.core.quantize import Quantizer as HQQQuantizer
2421

@@ -558,7 +555,7 @@ def __init__(
558555
q_group_size: int = 64,
559556
residual_length: int = 128,
560557
):
561-
super().__init__(self)
558+
super().__init__()
562559
self.nbits = nbits
563560
self.axis_key = axis_key
564561
self.axis_value = axis_value
@@ -635,10 +632,12 @@ def __init__(
635632
residual_length=residual_length,
636633
)
637634

638-
if not _is_quanto_greater_than_0_2_5:
635+
# We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
636+
if is_quanto_greater("0.2.5", accept_dev=True):
637+
from optimum.quanto import MaxOptimizer, qint2, qint4
638+
else:
639639
raise ImportError(
640640
"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
641-
"Detected version {optimum_quanto_version}."
642641
)
643642

644643
if self.nbits not in [2, 4]:
@@ -656,6 +655,8 @@ def __init__(
656655
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
657656

658657
def _quantize(self, tensor, axis):
658+
from optimum.quanto import quantize_weight
659+
659660
scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
660661
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
661662
return qtensor

src/transformers/utils/import_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1286,7 +1286,7 @@ def is_quanto_greater(library_version: str, accept_dev: bool = False):
12861286
given version. If `accept_dev` is True, it will also accept development versions (e.g. 2.7.0.dev20250320 matches
12871287
2.7.0).
12881288
"""
1289-
if not _is_package_available("optimum-quanto"):
1289+
if not _is_package_available("optimum.quanto"):
12901290
return False
12911291

12921292
if accept_dev:

0 commit comments

Comments
 (0)