@@ -73,69 +73,24 @@ def quantize(  # noqa C901
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode.startswith("torchao:"):
-        import glob
-        import os
-
-        libs = glob.glob(
-            os.path.abspath(
-                os.path.join(
-                    os.environ.get("CMAKE_INSTALL_PREFIX", ""),
-                    "lib/libtorchao_ops_aten.*",
-                )
-            )
-        )
-        assert (
-            len(libs) == 1
-        ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
-        logging.info(f"Loading custom ops library: {libs[0]}")
-        torch.ops.load_library(libs[0])
-
-        logging.warning(
-            "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
-        )
-        embedding_pattern = r"emb.(\d+),(\d+)"
-        linear_pattern = r"lin8da.(\d+),(\d+)"
-
-        matches = re.findall(linear_pattern, qmode)
-        if matches:
-            assert (
-                len(matches) == 1
-            ), f"Expected 1 match for linear_pattern but got {len(matches)}"
-            bitwidth = int(matches[0][0])
-            groupsize = int(matches[0][1])
-            from torchao.experimental.quant_api import (
-                Int8DynActIntxWeightLinearQuantizer,
-            )
-
-            with torch.no_grad():
-                model = Int8DynActIntxWeightLinearQuantizer(
-                    device="cpu",
-                    precision=torch_dtype,
-                    groupsize=groupsize,
-                    bitwidth=bitwidth,
-                    has_weight_zeros=False,
-                ).quantize(model)
-
-        matches = re.findall(embedding_pattern, qmode)
-        if matches:
-            assert (
-                len(matches) == 1
-            ), f"Expected 1 match for embedding_pattern but got {len(matches)}"
-            bitwidth = int(matches[0][0])
-            groupsize = int(matches[0][1])
-            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
-
-            with torch.no_grad():
-                model = IntxWeightEmbeddingQuantizer(
-                    device="cpu",
-                    precision=torch_dtype,
-                    bitwidth=bitwidth,
-                    groupsize=groupsize,
-                ).quantize(model)
+        pattern = r"torchao:8da(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+
+        with torch.no_grad():
+            model = Int8DynActIntxWeightLinearQuantizer(
+                device="cpu",
+                precision=torch.float32,
+                groupsize=group_size,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)

         if verbose:
             print("quantized model:", model)
-
         return model
     elif qmode == "8da4w":
         # Check for required args
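
For reference, a minimal sketch, outside the diff, of how the new torchao qmode string is parsed, assuming a hypothetical value "torchao:8da6w" (8-bit dynamically quantized activations, 6-bit weights):

import re

qmode = "torchao:8da6w"  # hypothetical qmode value
matches = re.findall(r"torchao:8da(\d+)w", qmode)
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
# With a single capture group, re.findall returns plain strings, so the
# captured bitwidth is the string "6" here.
bitwidth = int(matches[0])
print(bitwidth)  # -> 6

Note that the hunk above indexes matches[0][0]; since findall returns plain strings for a single capture group, that picks the first character of the capture, which is equivalent only while bitwidths stay single-digit.
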
@@ -760,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:


 def get_quant_embedding_transform(args):
+    if args.embedding_quantize.startswith("torchao:"):
+        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+        group_size = int(group_size)
+        bitwidth = int(bitwidth)
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+
+        def _torchao_embedding_quantizer(model):
+            with torch.no_grad():
+                model = IntxWeightEmbeddingQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=bitwidth,
+                    groupsize=group_size,
+                ).quantize(model)
+            return model
+
+        return _torchao_embedding_quantizer
+
     bitwidth, group_size = args.embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
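
Similarly, a hedged sketch of the new torchao embedding spec parsing added above; "torchao:4,32" is a hypothetical --embedding_quantize value (4-bit weights, group size 32):

spec = "torchao:4,32"  # hypothetical argument value
bitwidth, group_size = spec.split(":")[1].split(",")
bitwidth, group_size = int(bitwidth), int(group_size)
print(bitwidth, group_size)  # -> 4 32
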
@@ -801,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )


+def _load_torchao_ops_aten():
+    import glob
+    import os
+
+    libs = glob.glob(
+        os.path.abspath(
+            os.path.join(
+                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
+                "lib/libtorchao_ops_aten.*",
+            )
+        )
+    )
+    assert (
+        len(libs) == 1
+    ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+
+
 ############################ Source Transform End #######################
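
The new _load_torchao_ops_aten helper resolves the custom ops library relative to CMAKE_INSTALL_PREFIX. A small sketch of the lookup it performs, assuming a hypothetical install prefix /opt/torchao:

import glob
import os

prefix = "/opt/torchao"  # hypothetical CMAKE_INSTALL_PREFIX
pattern = os.path.abspath(os.path.join(prefix, "lib/libtorchao_ops_aten.*"))
libs = glob.glob(pattern)
# Expect exactly one match, e.g. libtorchao_ops_aten.so on Linux or
# libtorchao_ops_aten.dylib on macOS; the helper asserts on this before
# handing the path to torch.ops.load_library.
print(libs)
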