@@ -41,7 +41,7 @@ def quantize( # noqa C901
     checkpoint_dtype: Optional[DType] = None,
     checkpoint_path: Optional[Path] = None,
     # following arguments only available when setting int4 or gptq quantization.
-    group_size: Optional[int] = 128,
+    group_size: Optional[int] = None,
     # following arguments are only used for GPTQ
     calibration_tasks: Optional[list] = None,
     calibration_limit: Optional[int] = None,
@@ -146,9 +146,9 @@ def quantize( # noqa C901
         print("quantized model:", model)
         return model
     elif qmode == "8da4w":
-        # Check for required args
         if group_size is None:
-            raise Exception("For 8da4w quantization, group size must be specified.")
+            # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
+            group_size = 128

         from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
         from torchao.utils import unwrap_tensor_subclass
@@ -784,16 +784,20 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################


-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
+def get_quant_embedding_transform(
+    embedding_quantize: str,
+    use_shared_embedding: bool = False,
+    dtype_override: Optional[DType] = None,
+):
+    if embedding_quantize.startswith("torchao:"):
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType

-        quant_args = args.embedding_quantize.split(":")[1].split(",")
+        quant_args = embedding_quantize.split(":")[1].split(",")
         if len(quant_args) == 2:
             bitwidth, group_size = quant_args
             is_asymmetric = True
@@ -814,7 +818,7 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
 
     def _torchao_embedding_quantizer(model):
         with torch.no_grad():
-            if not args.use_shared_embedding:
+            if not use_shared_embedding:
                 EmbeddingQuantizer(
                     weight_dtype=weight_dtype,
                     granularity=granularity,
@@ -831,7 +835,7 @@ def _torchao_embedding_quantizer(model):
 
         return _torchao_embedding_quantizer
 
-    bitwidth, group_size = args.embedding_quantize.split(",")
+    bitwidth, group_size = embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
     else:
@@ -848,34 +852,27 @@ def _torchao_embedding_quantizer(model):
 
 
 def get_quant_weight_transform(
-    args,
+    quantization_mode: str,
+    group_size: Optional[int] = None,
     computation_dtype: Optional[DType] = None,
     checkpoint_dtype: Optional[DType] = None,
+    checkpoint_path: Optional[Path] = None,
+    tokenizer_path: Optional[Path] = None,
+    calibration_tasks: Optional[list] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
 ):
-    # If these optional args are None, don't provide them to quantize().
-    quant_args_str = [
-        "group_size",
-        "calibration_tasks",
-        "calibration_limit",
-        "calibration_seq_length",
-    ]
-    arg_dict = vars(args)
-    quant_args = {
-        param: val
-        for param in quant_args_str
-        if (val := arg_dict.get(param)) is not None
-    }
-
     return partial(
         quantize,
-        **quant_args,
-        qmode=args.quantization_mode,
+        qmode=quantization_mode,
         computation_dtype=computation_dtype,
         checkpoint_dtype=checkpoint_dtype,
-        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
-        tokenizer_path=(
-            Path(path) if (path := args.tokenizer_path) is not None else None
-        ),
+        checkpoint_path=(Path(path) if (path := checkpoint_path) is not None else None),
+        group_size=group_size,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=calibration_limit,
+        calibration_seq_length=calibration_seq_length,
+        tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
     )
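For readers following the refactor, here is a minimal usage sketch of the two helpers after this change, called with explicit keyword arguments instead of an argparse namespace. The bit widths, group sizes, checkpoint path, and the `model` variable are illustrative assumptions, not values from this commit.

```python
from pathlib import Path

# Assumes get_quant_embedding_transform / get_quant_weight_transform are
# imported from the module modified in this diff, and that `model` is an
# eager nn.Module (hypothetical variable for illustration).

embedding_transform = get_quant_embedding_transform(
    embedding_quantize="4,32",  # "<bitwidth>,<group_size>" form parsed by the helper
    use_shared_embedding=False,
)
# Assumption: the returned callable is applied to the model, mirroring the
# _torchao_embedding_quantizer(model) signature shown in the diff.
model = embedding_transform(model)

# Returns functools.partial(quantize, ...); with this change, group_size may be
# left as None and the 8da4w path defaults it to 128 inside quantize().
weight_transform = get_quant_weight_transform(
    quantization_mode="8da4w",
    group_size=128,                    # optional after this change
    checkpoint_path=Path("llama.pth"),  # illustrative path
)
model = weight_transform(model)
```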