@@ -72,12 +72,37 @@ def quantize( # noqa C901
7272 if qmode == "int8" :
7373 # Add quantization mode options here: group size, bit width, etc.
7474 return WeightOnlyInt8QuantHandler (model ).quantized_model ()
75- elif qmode .startswith ("torchao:" ):
75+ elif qmode .startswith ("torchao:fpa" ):
76+ pattern = r"torchao:fpa(\d+)w"
77+ matches = re .findall (pattern , qmode )
78+ assert len (matches ) == 1 , f"Expected 1 match for pattern but got { len (matches )} "
79+ bitwidth = int (matches [0 ][0 ])
80+ _load_torchao_aten_lib (
81+ libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten"
82+ )
83+ from torchao .experimental .quant_api import UIntxWeightOnlyLinearQuantizer
84+
85+ with torch .no_grad ():
86+ model = (
87+ UIntxWeightOnlyLinearQuantizer (
88+ device = "mps" ,
89+ precision = torch .float32 ,
90+ groupsize = group_size ,
91+ bitwidth = bitwidth ,
92+ )
93+ .quantize (model )
94+ .to ("cpu" )
95+ )
96+
97+ if verbose :
98+ print ("quantized model:" , model )
99+ return model
100+ elif qmode .startswith ("torchao:8da" ):
76101 pattern = r"torchao:8da(\d+)w"
77102 matches = re .findall (pattern , qmode )
78103 assert len (matches ) == 1 , f"Expected 1 match for pattern but got { len (matches )} "
79104 bitwidth = int (matches [0 ][0 ])
80- _load_torchao_ops_aten ( )
105+ _load_torchao_aten_lib ( libname = "libtorchao_ops_aten" )
81106 from torchao .experimental .quant_api import Int8DynActIntxWeightLinearQuantizer
82107
83108 with torch .no_grad ():
@@ -729,7 +754,7 @@ def get_quant_embedding_transform(args):
729754 bitwidth , group_size = args .embedding_quantize .split (":" )[1 ].split ("," )
730755 group_size = int (group_size )
731756 bitwidth = int (bitwidth )
732- _load_torchao_ops_aten ( )
757+ _load_torchao_aten_lib ( libname = "libtorchao_ops_aten" )
733758 from torchao .experimental .quant_api import IntxWeightEmbeddingQuantizer
734759
735760 def _torchao_embedding_quantizer (model ):
@@ -785,15 +810,15 @@ def get_quant_weight_transform(args, dtype_override, verbose):
785810 )
786811
787812
788- def _load_torchao_ops_aten ( ):
813+ def _load_torchao_aten_lib ( libname ):
789814 import glob
790815 import os
791816
792817 libs = glob .glob (
793818 os .path .abspath (
794819 os .path .join (
795820 os .environ .get ("CMAKE_INSTALL_PREFIX" , "" ),
796- "lib/libtorchao_ops_aten .*" ,
821+ f "lib/{ libname } .*" ,
797822 )
798823 )
799824 )