@@ -16,9 +16,6 @@
 )
 
 
-if _is_quanto_greater_than_0_2_5 := is_quanto_greater("0.2.5", accept_dev=True):
-    from optimum.quanto import MaxOptimizer, qint2, qint4, quantize_weight
-
 if is_hqq_available():
     from hqq.core.quantize import Quantizer as HQQQuantizer
 
@@ -558,7 +555,7 @@ def __init__(
         q_group_size: int = 64,
         residual_length: int = 128,
     ):
-        super().__init__(self)
+        super().__init__()
         self.nbits = nbits
         self.axis_key = axis_key
         self.axis_value = axis_value
@@ -635,10 +632,12 @@ def __init__(
             residual_length=residual_length,
         )
 
-        if not _is_quanto_greater_than_0_2_5:
+        # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
+        if is_quanto_greater("0.2.5", accept_dev=True):
+            from optimum.quanto import MaxOptimizer, qint2, qint4
+        else:
             raise ImportError(
                 "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
-                "Detected version {optimum_quanto_version}."
             )
 
         if self.nbits not in [2, 4]:
@@ -656,6 +655,8 @@ def __init__(
         self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization
 
     def _quantize(self, tensor, axis):
+        from optimum.quanto import quantize_weight
+
        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
        return qtensor
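The last hunk is the core of the fix: `quantize_weight` is now imported inside `_quantize` instead of at module import time, so importing transformers no longer participates in the import cycle created by optimum/quanto/models/transformers_models.py importing transformers back. A minimal standalone sketch of that deferred-import idiom, assuming optimum-quanto >= 0.2.5 is installed (the function name `quantize_lazily` is illustrative, not part of the commit):

    def quantize_lazily(tensor, qtype, axis, scale, zeropoint, group_size):
        # Deferred import: optimum.quanto is only loaded on the first call,
        # not when this module is imported, which breaks the circular import.
        from optimum.quanto import quantize_weight

        return quantize_weight(tensor, qtype, axis, scale, zeropoint, group_size)

The trade-off is standard for lazy imports: a missing or too-old optimum-quanto is only reported at call time, which is why the `__init__` hunk above still checks `is_quanto_greater("0.2.5", accept_dev=True)` eagerly and raises an ImportError before any quantization is attempted.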