 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
+import re
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -70,6 +72,26 @@ def quantize(  # noqa C901
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
+    elif qmode.startswith("torchao:"):
+        pattern = r"torchao:8da(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        # re.findall returns the captured group itself, e.g. "4" for "torchao:8da4w"
+        bitwidth = int(matches[0])
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+
+        with torch.no_grad():
+            model = Int8DynActIntxWeightLinearQuantizer(
+                device="cpu",
+                precision=torch.float32,
+                groupsize=group_size,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)
+
+        if verbose:
+            print("quantized model:", model)
+        return model
     elif qmode == "8da4w":
         # Check for required args
         if group_size is None:
@@ -79,6 +101,7 @@ def quantize(  # noqa C901
         model = Int8DynActInt4WeightQuantizer(
             precision=torch_dtype, groupsize=group_size
         ).quantize(model)
+
         if verbose:
             print("quantized model:", model)
         return model
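
Note on the hunk above: the new branch keys on mode strings of the form torchao:8da<bitwidth>w (8-bit dynamic activations, <bitwidth>-bit weights). Below is a minimal standalone sketch of just the parsing step; parse_torchao_qmode is an illustrative helper name, not part of this diff.

import re

def parse_torchao_qmode(qmode: str) -> int:
    # Pull the weight bitwidth out of a mode string such as "torchao:8da4w".
    matches = re.findall(r"torchao:8da(\d+)w", qmode)
    assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
    return int(matches[0])

print(parse_torchao_qmode("torchao:8da4w"))  # -> 4
print(parse_torchao_qmode("torchao:8da3w"))  # -> 3
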
@@ -692,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:


 def get_quant_embedding_transform(args):
+    if args.embedding_quantize.startswith("torchao:"):
+        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+        group_size = int(group_size)
+        bitwidth = int(bitwidth)
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+
+        def _torchao_embedding_quantizer(model):
+            with torch.no_grad():
+                model = IntxWeightEmbeddingQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=bitwidth,
+                    groupsize=group_size,
+                ).quantize(model)
+            return model
+
+        return _torchao_embedding_quantizer
+
     bitwidth, group_size = args.embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
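
For the embedding path, --embedding_quantize now accepts torchao:<bitwidth>,<group_size> in addition to the existing <bitwidth>,<group_size> form handled by the fall-through code. A small sketch of how the two formats parse; parse_embedding_quantize is an illustrative name, not part of the diff.

def parse_embedding_quantize(spec: str):
    if spec.startswith("torchao:"):
        # New torchao path: "torchao:<bitwidth>,<group_size>", e.g. "torchao:4,32"
        bitwidth, group_size = spec.split(":")[1].split(",")
        return "torchao", int(bitwidth), int(group_size)
    # Existing path: "<bitwidth>,<group_size>", where group_size may be "none"/"None"/"0"
    bitwidth, group_size = spec.split(",")
    group_size = None if group_size in ("none", "None", "0") else int(group_size)
    return "legacy", int(bitwidth), group_size

print(parse_embedding_quantize("torchao:4,32"))  # -> ("torchao", 4, 32)
print(parse_embedding_quantize("8,none"))        # -> ("legacy", 8, None)
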
@@ -733,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )


+def _load_torchao_ops_aten():
+    import glob
+    import os
+
+    libs = glob.glob(
+        os.path.abspath(
+            os.path.join(
+                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
+                "lib/libtorchao_ops_aten.*",
+            )
+        )
+    )
+    assert (
+        len(libs) == 1
+    ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+
+
 ############################ Source Transform End #######################
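
Usage note: _load_torchao_ops_aten locates the torchao ops shared library under $CMAKE_INSTALL_PREFIX/lib, so that variable must point at the torchao install tree. A standalone sketch of the lookup it performs; the prefix value is a hypothetical example.

import glob
import os

prefix = os.environ.get("CMAKE_INSTALL_PREFIX", "")  # e.g. "/opt/torchao" (hypothetical path)
pattern = os.path.abspath(os.path.join(prefix, "lib/libtorchao_ops_aten.*"))
libs = glob.glob(pattern)
print(pattern, libs)  # expect exactly one hit: .so on Linux, .dylib on macOS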