@@ -72,7 +72,9 @@ def quantize( # noqa C901
 
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
-        return WeightOnlyInt8QuantHandler(model).quantized_model()
+        return WeightOnlyInt8QuantHandler(
+            model, precision=torch_dtype
+        ).quantized_model()
     elif qmode.startswith("torchao:fpa"):
         pattern = r"torchao:fpa(\d+)w"
         matches = re.findall(pattern, qmode)
@@ -85,7 +87,7 @@ def quantize( # noqa C901
             model = (
                 UIntxWeightOnlyLinearQuantizer(
                     device="mps",
-                    precision=torch.float32,
+                    precision=torch_dtype,
                     groupsize=group_size,
                     bitwidth=bitwidth,
                 )
@@ -107,7 +109,7 @@ def quantize( # noqa C901
         with torch.no_grad():
             model = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
-                precision=torch.float32,
+                precision=torch_dtype,
                 groupsize=group_size,
                 bitwidth=bitwidth,
                 has_weight_zeros=False,
@@ -346,6 +348,7 @@ def __init__(
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         group_size: Optional[int] = None,
+        precision: torch.dtype = torch.float32,
     ):
         self.mod = mod
         self.group_size = group_size
@@ -354,6 +357,7 @@ def __init__(
             self.bitwidth = 8
         else:
             self.bitwidth = bitwidth
+        self.precision = precision
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
@@ -389,7 +393,7 @@ def create_quantized_state_dict(self) -> Dict:
 
                 # print(f"expanded weight shape {input_weight.shape}")
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    input_weight,
+                    input_weight.to(dtype=self.precision),
                     range_min,
                     range_max,
                     torch.int8,
@@ -574,6 +578,7 @@ def __init__(
         bitwidth: int = 8,
         group_size: Optional[int] = None,
         packed=False,
+        precision: Optional[torch.dtype] = None,
     ):
         if isinstance(packed, str):
             packed = packed == "True"
@@ -582,6 +587,8 @@ def __init__(
         self.group_size = group_size
         self.bitwidth = bitwidth
         self.packed = packed
+        # Dtype of the weights right before quantization.
+        self.precision = precision
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
@@ -612,7 +619,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    mod.weight,
+                    (
+                        mod.weight.to(dtype=self.precision)
+                        if self.precision
+                        else mod.weight
+                    ),
                     range_min,
                     range_max,
                     torch.int8,
@@ -748,7 +759,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args):
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
     if args.embedding_quantize.startswith("torchao:"):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
@@ -774,11 +785,13 @@ def _torchao_embedding_quantizer(model):
     else:
         group_size = int(group_size)
     bitwidth = int(bitwidth)
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
     return lambda model: EmbeddingQuantHandler(
         model,
         bitwidth=bitwidth,
         group_size=group_size,
         packed=(bitwidth in [2, 4]),
+        precision=torch_dtype,
     ).quantized_model()
 
 
@@ -831,25 +844,32 @@ def _load_torchao_aten_lib(libname):
 # We want to compute the actual ops in the dtype of the dtype_override,
 # since the precision of the quantized linear will initially be the dtype of the
 # checkpoint, not the dtype_override.
-# TODO(#8652): this is a temporary solution for until we can support the new ao,
-# quantize_ api, which apparently can support different dtypes at quantization and
-# computation.
-def _set_quantized_computation_dtype(module: nn.Module, dtype: torch.dtype):
-    """
-    Recursively iterate through the module and set the dtype/precision attributes
-    of all Int8DynActInt4WeightLinear and QuantizedGroupEmbedding submodules to 'fp32'.
-    """
-    for name, child in module.named_children():
-        if isinstance(child, Int8DynActInt4WeightLinear):
-            # Change the precision attribute to 'fp32'
-            child.precision = dtype
-            print(f"Changed precision of {name} to {dtype}")
-        elif isinstance(child, QuantizedGroupEmbedding):
-            child.dtype = dtype
-            print(f"Changed precision of {name} to {dtype}")
-        else:
-            # Recursively apply to child modules
-            _set_quantized_computation_dtype(child, dtype)
+def _set_quantized_computation_dtype(
+    module: nn.Module, dtype: torch.dtype
+) -> nn.Module:
+    def _set_quantized_computation_dtype_rec(
+        module: nn.Module, dtype: torch.dtype
+    ) -> None:
+        """
+        Recursively iterate through the module and set the dtype/precision attributes
+        of all Int8DynActInt4WeightLinear and QuantizedGroupEmbedding submodules to 'fp32'.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                # Change the precision attribute to 'fp32'
+                child.precision = dtype
+                print(f"Changed precision of {name} to {dtype}")
+            elif isinstance(child, QuantizedGroupEmbedding):
+                child.dtype = dtype
+                print(f"Changed precision of {name} to {dtype}")
+            elif isinstance(child, WeightOnlyInt8Linear):
+                child.dtype = dtype
+            else:
+                # Recursively apply to child modules
+                _set_quantized_computation_dtype_rec(child, dtype)
+
+    _set_quantized_computation_dtype_rec(module, dtype)
+    return module
 
 
 ############################ Source Transform End #######################
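
Taken together, the diff threads the checkpoint dtype into the weight-only and embedding quant handlers at quantization time, and _set_quantized_computation_dtype now also covers WeightOnlyInt8Linear and returns the module so it can be chained. Below is a minimal sketch of the intended call pattern, assuming the helpers above are importable; the toy model is a hypothetical stand-in for a real checkpoint and is not part of this commit:

import torch
import torch.nn as nn

# Hypothetical toy model in place of a Llama checkpoint; any module with
# nn.Linear / nn.Embedding children illustrates the flow.
model = nn.Sequential(nn.Embedding(128, 64), nn.Linear(64, 64))

# Quantize linear weights, performing quantization in the checkpoint dtype,
# mirroring the call added in quantize() above.
model = WeightOnlyInt8QuantHandler(model, precision=torch.float32).quantized_model()

# Then retarget the quantized modules to compute in the override dtype;
# the helper now returns the module, so the result can be chained.
model = _set_quantized_computation_dtype(model, torch.float16)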