@@ -98,21 +98,38 @@ def quantize( # noqa C901
9898 matches = re .findall (pattern , qmode )
9999 assert len (matches ) == 1 , f"Expected 1 match for pattern but got { len (matches )} "
100100 bitwidth = int (matches [0 ][0 ])
101- _load_torchao_aten_lib (libname = "libtorchao_ops_aten" )
102- from torchao .experimental .quant_api import Int8DynActIntxWeightLinearQuantizer
101+ # _load_torchao_aten_lib(libname="libtorchao_ops_aten")
102+ # from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
103+ from torchao .experimental .quant_api import int8_dynamic_activation_intx_weight , Int8DynActIntxWeightLinearQuantizer
104+ from torchao .quantization .quant_api import quantize_
105+ from torchao .utils import unwrap_tensor_subclass
106+ from torchao .quantization .granularity import PerRow , PerGroup
103107
104108 with torch .no_grad ():
105- model = Int8DynActIntxWeightLinearQuantizer (
106- device = "cpu" ,
107- precision = torch .float32 ,
108- groupsize = group_size ,
109- bitwidth = bitwidth ,
110- has_weight_zeros = False ,
111- ).quantize (model )
112-
109+ # model = Int8DynActIntxWeightLinearQuantizer(
110+ # device="cpu",
111+ # precision=torch.float32,
112+ # groupsize=group_size,
113+ # bitwidth=bitwidth,
114+ # has_weight_zeros=False,
115+ # ).quantize(model)
116+
117+ quantize_ (model ,
118+ int8_dynamic_activation_intx_weight (
119+ # group_size=group_size,
120+ # nbit=bitwidth,
121+ # has_weight_zeros=False,
122+ weight_dtype = getattr (torch , f"int{ bitwidth } " ),
123+ granularity = PerRow () if group_size == 0 else PerGroup (group_size ),
124+ has_weight_zeros = False ,
125+ ),
126+ )
127+ model = unwrap_tensor_subclass (model )
113128 if verbose :
114129 print ("quantized model:" , model )
115130 return model
131+
132+ return model
116133 elif qmode == "8da4w" :
117134 # Check for required args
118135 if group_size is None :
@@ -752,7 +769,7 @@ def get_quant_embedding_transform(args):
752769 bitwidth , group_size = args .embedding_quantize .split (":" )[1 ].split ("," )
753770 group_size = int (group_size )
754771 bitwidth = int (bitwidth )
755- _load_torchao_aten_lib (libname = "libtorchao_ops_aten" )
772+ # _load_torchao_aten_lib(libname="libtorchao_ops_aten")
756773 from torchao .experimental .quant_api import IntxWeightEmbeddingQuantizer
757774
758775 def _torchao_embedding_quantizer (model ):
0 commit comments