@@ -49,6 +49,7 @@ def quantize(  # noqa C901
     blocksize: int = 128,
     tokenizer_path: Optional[Path] = None,
     verbose: bool = False,
+    quantize_with_hqq: bool = True,
 ) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
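Note: the new `quantize_with_hqq` parameter switches torchao's quantization-parameter selection from plain min/max affine fitting to HQQ (half-quadratic quantization) scale-only fitting. As an illustration only, not code from this PR, the two starting points differ roughly like this for a single int4 weight group (the exact torchao behavior, in particular zero-point handling under `"hqq_scale_only"`, is an assumption here):

```python
# Illustration only, not from this PR: the two qparams styles the flag toggles.
import torch

w = torch.randn(32)   # one quantization group
qmin, qmax = -8, 7    # int4 range

# "affine": scale and zero-point derived from the group's min/max.
scale = (w.max() - w.min()) / (qmax - qmin)
zero_point = qmin - torch.round(w.min() / scale)
q_affine = torch.clamp(torch.round(w / scale) + zero_point, qmin, qmax)

# "hqq_scale_only" (sketch): start from a symmetric scale, which the HQQ
# solver then refines to reduce weight-reconstruction error; the solver
# itself is omitted here.
scale_sym = w.abs().max() / qmax
q_hqq_init = torch.clamp(torch.round(w / scale_sym), qmin, qmax)
```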
@@ -119,7 +120,6 @@ def quantize(  # noqa C901
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
-            MappingType,
             quantize_,
         )
         from torchao.utils import unwrap_tensor_subclass
@@ -134,9 +134,12 @@ def quantize(  # noqa C901
                 weight_granularity=(
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
-                weight_mapping_type=MappingType.SYMMETRIC,
                 # pyre-ignore[6]
                 intx_packing_format="opaque_torchao_auto",
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if quantize_with_hqq else "affine"
+                ),
             ),
         )
         model = unwrap_tensor_subclass(model)
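The resulting 8da4w config can be exercised directly outside this file; a minimal sketch reusing only the kwargs visible in this hunk (the toy model and group size are assumptions):

```python
import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
group_size = 32          # assumed; 0 would mean per-axis granularity
quantize_with_hqq = True
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=(
            PerAxis(0) if group_size == 0 else PerGroup(group_size)
        ),
        intx_packing_format="opaque_torchao_auto",
        intx_choose_qparams_algorithm=(
            "hqq_scale_only" if quantize_with_hqq else "affine"
        ),
    ),
)
model = unwrap_tensor_subclass(model)
```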
@@ -170,6 +173,10 @@ def filter_fn(m, fqn):
                 # pyre-ignore[16]
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(group_size),
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if quantize_with_hqq else "affine"
+                ),
             ),
             filter_fn=filter_fn,
         )
@@ -191,6 +198,10 @@ def filter_fn(m, fqn):
             # pyre-ignore[16]
             weight_dtype=torch.int4,
             granularity=PerGroup(q_group_size),
+            # pyre-ignore[6]
+            intx_choose_qparams_algorithm=(
+                "hqq_scale_only" if quantize_with_hqq else "affine"
+            ),
         )
         quantize_(model, q_config)
         model = unwrap_tensor_subclass(model)
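The weight-only (4w) path takes the same toggle; a condensed sketch (model and group size assumed):

```python
import torch
import torch.nn as nn
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(128, 128))
q_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    intx_choose_qparams_algorithm="hqq_scale_only",  # or "affine"
)
quantize_(model, q_config)
model = unwrap_tensor_subclass(model)
```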
@@ -580,6 +591,7 @@ def __init__(
         group_size: Optional[int] = None,
         packed=False,
         precision: Optional[torch.dtype] = None,
+        quantize_with_hqq: bool = True,
     ):
         if isinstance(packed, str):
             packed = packed == "True"
@@ -592,15 +604,12 @@ def __init__(
         self.precision = precision
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
+        self.quantize_with_hqq = quantize_with_hqq

     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
         from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import (
-            IntxWeightOnlyConfig,
-            MappingType,
-            quantize_,
-        )
+        from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

         cur_state_dict = self.mod.state_dict()

@@ -627,7 +636,10 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     if (self.group_size is None or self.group_size == 0)
                     else PerGroup(self.group_size)
                 ),
-                mapping_type=MappingType.SYMMETRIC,
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if self.quantize_with_hqq else "affine"
+                ),
             )
             quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
             weight = tmp_model.weight.qdata  # pyre-ignore[16]
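The embedding handler drives the same config through `quantize_` with a module filter; a sketch of that flow on a standalone embedding (sizes assumed; the filter lambda and `.qdata` access mirror the code above):

```python
import torch
import torch.nn as nn
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

tmp_model = nn.Embedding(1024, 64)
config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    intx_choose_qparams_algorithm="hqq_scale_only",
)
# Only nn.Embedding modules are rewritten, mirroring the lambda above.
quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
weight = tmp_model.weight.qdata  # packed integer weight data
```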
@@ -765,6 +777,7 @@ def get_quant_embedding_transform(
     embedding_quantize: str,
     use_shared_embedding: bool = False,
     dtype_override: Optional[DType] = None,
+    quantize_with_hqq: bool = True,
 ):
     if embedding_quantize.startswith("torchao:"):
         from torchao.prototype.quantization.embedding.api import (
@@ -825,6 +838,7 @@ def _torchao_embedding_quantizer(model):
             group_size=group_size,
             packed=(bitwidth in [2, 4]),
             precision=torch_dtype,
+            quantize_with_hqq=quantize_with_hqq,
         ).quantized_model()


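Call sites pick the flag up through the transform factory; hypothetical usage (the `"<bitwidth>,<group_size>"` string format and the one-argument transform are assumptions, the kwargs come from the signature above):

```python
# Hypothetical call site; "4,32" as "<bitwidth>,<group_size>" is an assumption.
transform = get_quant_embedding_transform(
    embedding_quantize="4,32",
    quantize_with_hqq=True,
)
model = transform(model)
```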
@@ -838,6 +852,7 @@ def get_quant_weight_transform(
     calibration_tasks: Optional[list] = None,
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
+    quantize_with_hqq: bool = True,
 ):
     return partial(
         quantize,
@@ -850,6 +865,7 @@ def get_quant_weight_transform(
         calibration_limit=calibration_limit,
         calibration_seq_length=calibration_seq_length,
         tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
+        quantize_with_hqq=quantize_with_hqq,
     )


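The weight transform threads the flag into its `quantize` partial, so every later invocation of the returned callable runs with the choice pre-bound; a sketch of that pattern in isolation (the stand-in function is an assumption):

```python
from functools import partial

def quantize(model, quantize_with_hqq=True):  # stand-in for the real quantize()
    return model

# partial() pre-binds the flag; callers of `transform` need not know about it.
transform = partial(quantize, quantize_with_hqq=True)
```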
@@ -877,7 +893,6 @@ def _load_torchao_aten_lib(libname):
 def set_8da4w_computation_dtype(
     module: nn.Module, computation_dtype: torch.dtype
 ) -> nn.Module:
-
     from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear

     def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None: