@@ -57,7 +57,7 @@ def quantize( # noqa C901
 
     Args:
         model: The model to quantize.
-        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        qmode: The quantization mode, e.g. int8, 8da4w.
         computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
             Also the dtype of the rest of the non-quantized components of the model.
         checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
@@ -161,58 +161,6 @@ def quantize( # noqa C901
         if verbose:
             print("quantized model:", model)
         return model
-    elif qmode == "8da4w-gptq":
-        # Check for required args
-        required_args: Optional[Any] = [
-            group_size,
-            calibration_limit,
-            calibration_seq_length,
-        ]
-        if any(arg is None for arg in required_args):
-            raise Exception(
-                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
-            )
-        if calibration_tasks is None:
-            calibration_tasks = ["wikitext"]
-
-        try:
-            # torchao 0.3+
-            from torchao._models._eval import InputRecorder
-        except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
-
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
-
-        if tokenizer_path is None:
-            assert checkpoint_path is not None, "checkpoint_path must be specified"
-            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-
-        inputs = (
-            InputRecorder(  # pyre-fixme[16]
-                tokenizer,
-                calibration_seq_length,
-                None,  # input_prep_func
-                pad_calibration_inputs,
-                model.vocab_size,
-            )
-            .record_inputs(
-                calibration_tasks,
-                calibration_limit,
-            )
-            .get_inputs()
-        )
-
-        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            group_size,
-        )  # TODO: separate computation and checkpoint dtype for GPTQ.
-        model = gptq_quantizer.quantize(model, inputs)
-        return model
     elif qmode == "vulkan_4w":
         from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
 
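
For callers migrating off the removed mode, below is a minimal sketch of invoking quantize() with the remaining 8da4w mode. The argument names (qmode, computation_dtype, checkpoint_dtype, group_size) come from the docstring hunk and the deleted branch above; the concrete values are illustrative assumptions, not defaults taken from this file.

    import torch

    # Hypothetical call site: arg names follow the docstring above; the
    # values (fp32 computation, fp16 checkpoint, group size 256) are
    # assumptions chosen for illustration.
    model = quantize(
        model,
        qmode="8da4w",                    # int8 dynamic activations, int4 weights
        computation_dtype=torch.float32,  # dtype ops run in after dequantization
        checkpoint_dtype=torch.float16,   # dtype the checkpoint was stored in
        group_size=256,                   # weight-quantization group size
    )

Unlike the deleted 8da4w-gptq branch, this path needs no tokenizer or calibration inputs, since plain 8da4w quantizes weights directly rather than correcting them against recorded activations.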