
from sentencepiece import SentencePieceProcessor

+
try:
    from fairseq2.nn.embedding import (
        Embedding as fsEmbedding,
def quantize(  # noqa C901
    model: torch.nn.Module,
    qmode: str,
-    activation_dtype: Optional[DType],
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
    checkpoint_path: Optional[Path] = None,
    # following arguments only available when setting int4 or gptq quantization.
    group_size: Optional[int] = 128,
@@ -52,20 +54,33 @@ def quantize( # noqa C901
) -> torch.nn.Module:
    """
    Quantizes a model by converting all weights to int8.
+
    Args:
-        model: A model to quantize.
-        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
+        model: The model to quantize.
+        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
+            Also the dtype of the non-quantized components of the model.
+        checkpoint_dtype: The dtype of the checkpoint; this arg exists because it is more accurate to
+            quantize the weight in its original dtype.
+
    Returns:
        A quantized model.
    """
-    if activation_dtype is not None:
-        torch_dtype = activation_dtype.to_torch_dtype()
+    if computation_dtype:
+        computation_torch_dtype = computation_dtype.to_torch_dtype()
    else:
-        torch_dtype = torch.float16
+        computation_torch_dtype = torch.float32
+
+    if not checkpoint_dtype:
+        checkpoint_torch_dtype = computation_torch_dtype
+    else:
+        checkpoint_torch_dtype = checkpoint_dtype.to_torch_dtype()

    if qmode == "int8":
        # Add quantization mode options here: group size, bit width, etc.
-        return WeightOnlyInt8QuantHandler(model).quantized_model()
+        return WeightOnlyInt8QuantHandler(
+            model, precision=checkpoint_torch_dtype
+        ).quantized_model()
    elif qmode.startswith("torchao:fpa"):
        pattern = r"torchao:fpa(\d+)w"
        matches = re.findall(pattern, qmode)
@@ -75,10 +90,12 @@ def quantize( # noqa C901
        from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

        with torch.no_grad():
+            # This quantize() currently does a model.to(self.precision), so the
+            # computation and checkpoint dtypes cannot be decoupled.
            model = (
                UIntxWeightOnlyLinearQuantizer(
                    device="mps",
-                    precision=torch.float32,
+                    precision=computation_torch_dtype,
                    groupsize=group_size,
                    bitwidth=bitwidth,
                )
@@ -101,6 +118,8 @@ def quantize( # noqa C901
        from torchao.utils import unwrap_tensor_subclass

        with torch.no_grad():
+            # The computation dtype is fixed to fp32 in the implementation of quantize_,
+            # so there is no way to decouple the checkpoint and computation dtypes.
            quantize_(
                model,
                Int8DynamicActivationIntxWeightConfig(
@@ -121,9 +140,12 @@ def quantize( # noqa C901
            raise Exception("For 8da4w quantization, group size must be specified.")
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

+        # 1. Quantize in checkpoint dtype.
        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
+            precision=checkpoint_torch_dtype, groupsize=group_size
        ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)

        if verbose:
            print("quantized model:", model)
@@ -177,7 +199,7 @@ def quantize( # noqa C901
            blocksize,
            percdamp,
            group_size,
-        )
+        )  # TODO: separate computation and checkpoint dtype for GPTQ.
        model = gptq_quantizer.quantize(model, inputs)
        return model
    elif qmode == "vulkan_4w":
@@ -190,9 +212,12 @@ def quantize( # noqa C901
        # at the moment
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

+        # 1. Quantize in checkpoint dtype.
        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=q_group_size
+            precision=checkpoint_torch_dtype, groupsize=q_group_size
        ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)

        return model
    else:
@@ -348,6 +373,7 @@ def __init__(
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        group_size: Optional[int] = None,
+        precision: torch.dtype = torch.float32,
    ):
        self.mod = mod
        self.group_size = group_size
@@ -356,6 +382,7 @@ def __init__(
            self.bitwidth = 8
        else:
            self.bitwidth = bitwidth
+        self.precision = precision

    @torch.no_grad()
    def create_quantized_state_dict(self) -> Dict:
@@ -391,7 +418,7 @@ def create_quantized_state_dict(self) -> Dict:

                # print(f"expanded weight shape {input_weight.shape}")
                weight, scales, _ = dynamically_quantize_per_channel(
-                    input_weight,
+                    input_weight.to(dtype=self.precision),
                    range_min,
                    range_max,
                    torch.int8,
@@ -576,6 +603,7 @@ def __init__(
        bitwidth: int = 8,
        group_size: Optional[int] = None,
        packed=False,
+        precision: Optional[torch.dtype] = None,
    ):
        if isinstance(packed, str):
            packed = packed == "True"
@@ -584,6 +612,8 @@ def __init__(
        self.group_size = group_size
        self.bitwidth = bitwidth
        self.packed = packed
+        # Dtype of the weights right before quantization.
+        self.precision = precision
        if (bitwidth not in [2, 4]) and packed:
            raise RuntimeError("pack only works with bitsize 2, 4")

@@ -614,7 +644,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                )
                weight, scales, _ = dynamically_quantize_per_channel(
-                    mod.weight.float(),
+                    (
+                        mod.weight.to(dtype=self.precision)
+                        if self.precision
+                        else mod.weight
+                    ),
                    range_min,
                    range_max,
                    torch.int8,
@@ -750,7 +784,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
############################ Source Transform Start #######################


-def get_quant_embedding_transform(args):
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
    if args.embedding_quantize.startswith("torchao:"):
        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
        group_size = int(group_size)
@@ -775,16 +809,22 @@ def _torchao_embedding_quantizer(model):
    else:
        group_size = int(group_size)
    bitwidth = int(bitwidth)
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
    return lambda model: EmbeddingQuantHandler(
        model,
        bitwidth=bitwidth,
        group_size=group_size,
        packed=(bitwidth in [2, 4]),
+        precision=torch_dtype,
    ).quantized_model()


-def get_quant_weight_transform(args, dtype_override, verbose):
-    # If these optional args are None, don't provide them to quantize()
+def get_quant_weight_transform(
+    args,
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
+):
+    # If these optional args are None, don't provide them to quantize().
    quant_args_str = [
        "group_size",
        "calibration_tasks",
@@ -802,7 +842,8 @@ def get_quant_weight_transform(args, dtype_override, verbose):
        quantize,
        **quant_args,
        qmode=args.quantization_mode,
-        activation_dtype=dtype_override,
+        computation_dtype=computation_dtype,
+        checkpoint_dtype=checkpoint_dtype,
        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
        tokenizer_path=(
            Path(path) if (path := args.tokenizer_path) is not None else None
@@ -829,4 +870,28 @@ def _load_torchao_aten_lib(libname):
    torch.ops.load_library(libs[0])


+# We want to compute the actual ops in the computation dtype, since the precision of the
+# quantized linear will initially be the dtype of the checkpoint.
+def set_8da4w_computation_dtype(
+    module: nn.Module, computation_dtype: torch.dtype
+) -> nn.Module:
+
+    from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+    def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
+        """
+        Recursively iterate through the module and set the precision attributes
+        of all Int8DynActInt4WeightLinears.
+        """
+        for _name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                child.precision = dtype
+            else:
+                # Recursively apply to child modules
+                _set_8da4w_computation_dtype(child, dtype)
+
+    _set_8da4w_computation_dtype(module, computation_dtype)
+    return module
+
+
############################ Source Transform End #######################
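
For reference, a minimal usage sketch of the decoupled dtypes after this change, assuming the caller invokes quantize() directly; the DType.fp32 and DType.bf16 enum members used here are illustrative assumptions, not confirmed by this diff:

# Hypothetical sketch: 8da4w-quantize a model whose checkpoint is bf16 while
# running the dequantized compute in fp32. DType member names are assumed.
quantized_model = quantize(
    model,                            # torch.nn.Module to quantize
    qmode="8da4w",                    # int8 dynamic activations, int4 weights
    computation_dtype=DType.fp32,     # dtype ops run in / dequantize to (assumed member)
    checkpoint_dtype=DType.bf16,      # dtype the checkpoint weights were saved in (assumed member)
    group_size=128,
)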