@@ -136,14 +136,14 @@ def quantize( # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.utils import unwrap_tensor_subclass
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        model = unwrap_tensor_subclass(model)
+
+        # TODO: deal with checkpoint / computation dtype decoupling.
 
         if verbose:
             print("quantized model:", model)
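For context, the replacement flow above is torchao's module-level `quantize_` API: it swaps each `nn.Linear` weight for an int8-dynamic-activation / int4-weight tensor subclass in place, and `unwrap_tensor_subclass` then flattens those subclasses back into plain tensors so the model stays exportable. A minimal standalone sketch of that flow (the toy `nn.Sequential` model and `group_size=32` are illustrative, not from this PR):

```python
import torch.nn as nn

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass

# Toy stand-in for the real model; in_features should be divisible by group_size.
model = nn.Sequential(nn.Linear(256, 256))

# quantize_ mutates the model in place, replacing each Linear's weight with an
# int8-dynamic-activation / int4-weight tensor subclass.
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

# Flatten the tensor subclasses into plain tensors so export / compile can
# consume the model.
model = unwrap_tensor_subclass(model)
```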
@@ -698,7 +698,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
 
 
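The `assign=True` change matters because `load_state_dict`'s default copy semantics `copy_()` values into the module's existing parameters, preserving their dtype and tensor class; with `assign=True` (available in recent PyTorch releases) the tensors from the state dict are assigned directly, so the quantized tensors produced by `create_quantized_state_dict` survive the load. A minimal illustration with a toy module (not from this PR):

```python
import torch
import torch.nn as nn

m = nn.Linear(4, 4)  # weight is float32 by default
sd = {
    "weight": torch.zeros(4, 4, dtype=torch.bfloat16),
    "bias": torch.zeros(4, dtype=torch.bfloat16),
}

# Default copy semantics would copy_() values into the existing float32
# parameters, silently upcasting the bf16 checkpoint:
# m.load_state_dict(sd)

# assign=True replaces the parameters with the state-dict tensors themselves,
# preserving their dtype (and any tensor subclass, as in the quantized case).
m.load_state_dict(sd, assign=True)
assert m.weight.dtype == torch.bfloat16
```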