     FP8Arguments,
     GPTQArguments,
     ModelArguments,
+    OptArguments,
 )
 from fms_mo.utils.import_utils import available_packages

 def quantize(
     model_args: ModelArguments,
     data_args: DataArguments,
-    fms_mo_args: FMSMOArguments,
-    gptq_args: GPTQArguments,
-    fp8_args: FP8Arguments,
-    quant_method: str,
-    output_dir: str,
+    opt_args: OptArguments,
+    fms_mo_args: FMSMOArguments = None,
+    gptq_args: GPTQArguments = None,
+    fp8_args: FP8Arguments = None,
 ):
     """Main entry point to quantize a given model with a set of specified hyperparameters

@@ -71,16 +71,17 @@ def quantize(
         output_dir (str) Output directory to write to
     """

-    logging.info(f"{fms_mo_args}\n{quant_method}\n")
-    if quant_method == "gptq":
+    logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")
+
+    if opt_args.quant_method == "gptq":
         if not available_packages["auto_gptq"]:
             raise ImportError(
                 "Quantization method has been selected as gptq but unable to use external library, "
                 "auto_gptq module not found. For more instructions on installing the appropriate "
                 "package, see https://github.com/AutoGPTQ/AutoGPTQ?tab=readme-ov-file#installation"
             )
-        run_gptq(model_args, data_args, gptq_args, output_dir)
-    elif quant_method == "fp8":
+        run_gptq(model_args, data_args, opt_args, gptq_args)
+    elif opt_args.quant_method == "fp8":
         if not available_packages["llmcompressor"]:
             raise ImportError(
                 "Quantization method has been selected as fp8 but unable to use external library, "
@@ -89,16 +90,18 @@ def quantize(
8990 "https://github.com/vllm-project/llm-compressor/tree/"
9091 "main?tab=readme-ov-file#installation"
9192 )
92- run_fp8 (model_args , data_args , fp8_args , output_dir )
93- elif quant_method == "dq" :
94- run_dq (model_args , data_args , fms_mo_args , output_dir )
93+ run_fp8 (model_args , data_args , opt_args , fp8_args )
94+ elif opt_args . quant_method == "dq" :
95+ run_dq (model_args , data_args , opt_args , fms_mo_args )
9596 else :
9697 raise ValueError (
97- "Not a valid quantization technique option. Please choose from: gptq, fp8, dq"
98+ "{} is not a valid quantization technique option. Please choose from: gptq, fp8, dq" .format (
99+ opt_args .quant_method
100+ )
98101 )
99102
100103
101- def run_gptq (model_args , data_args , gptq_args , output_dir ):
104+ def run_gptq (model_args , data_args , opt_args , gptq_args ):
102105 """GPTQ quantizes a given model with a set of specified hyperparameters
103106
104107 Args:
@@ -152,14 +155,16 @@ def run_gptq(model_args, data_args, gptq_args, output_dir):
         cache_examples_on_gpu=gptq_args.cache_examples_on_gpu,
     )

-    logger.info(f"Time to quantize model at {output_dir}: {time.time() - start_time}")
+    logger.info(
+        f"Time to quantize model at {opt_args.output_dir}: {time.time() - start_time}"
+    )

-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_quantized(output_dir, use_safetensors=True)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_quantized(opt_args.output_dir, use_safetensors=True)
+    tokenizer.save_pretrained(opt_args.output_dir)


-def run_fp8(model_args, data_args, fp8_args, output_dir):
+def run_fp8(model_args, data_args, opt_args, fp8_args):
     """FP8 quantizes a given model with a set of specified hyperparameters

     Args:
@@ -192,11 +197,13 @@ def run_fp8(model_args, data_args, fp8_args, output_dir):
         max_seq_length=data_args.max_seq_length,
         num_calibration_samples=data_args.num_calibration_samples,
     )
-    logger.info(f"Time to quantize model at {output_dir}: {time.time() - start_time}")
+    logger.info(
+        f"Time to quantize model at {opt_args.output_dir}: {time.time() - start_time}"
+    )

-    logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-    model.save_pretrained(output_dir)
-    tokenizer.save_pretrained(output_dir)
+    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+    model.save_pretrained(opt_args.output_dir)
+    tokenizer.save_pretrained(opt_args.output_dir)


 def main():
@@ -206,53 +213,41 @@ def main():
         dataclass_types=(
             ModelArguments,
             DataArguments,
+            OptArguments,
             FMSMOArguments,
             GPTQArguments,
             FP8Arguments,
         )
     )

-    parser.add_argument(
-        "--quant_method",
-        type=str.lower,
-        choices=["gptq", "fp8", None, "none", "dq"],
-        default="none",
-    )
-
-    parser.add_argument("--output_dir", type=str)
-
     (
         model_args,
         data_args,
+        opt_args,
         fms_mo_args,
         gptq_args,
         fp8_args,
-        additional,
         _,
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-    quant_method = additional.quant_method
-    output_dir = additional.output_dir

     logger.debug(
-        "Input args parsed: \nmodel_args %s, data_args %s, fms_mo_args %s, "
-        "gptq_args %s, fp8_args %s, quant_method %s, output_dir %s",
+        "Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
+        "gptq_args %s, fp8_args %s",
         model_args,
         data_args,
+        opt_args,
         fms_mo_args,
         gptq_args,
         fp8_args,
-        quant_method,
-        output_dir,
     )

     quantize(
         model_args=model_args,
         data_args=data_args,
+        opt_args=opt_args,
         fms_mo_args=fms_mo_args,
         gptq_args=gptq_args,
         fp8_args=fp8_args,
-        quant_method=quant_method,
-        output_dir=output_dir,
     )

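Note for reviewers: this change moves quant_method and output_dir from hand-written argparse flags into an OptArguments dataclass parsed by HfArgumentParser alongside the other argument groups. A minimal sketch of what that dataclass could look like is below; the field names are taken from the usages in this diff (opt_args.quant_method, opt_args.output_dir), while the defaults and help text are assumptions rather than the repository's actual definition.

# Hypothetical sketch only -- not the repository's actual OptArguments definition.
# Field names mirror the usages in this diff; defaults and help strings are assumed.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class OptArguments:
    quant_method: str = field(
        default="none",
        metadata={"help": "Quantization technique to apply: one of gptq, fp8, dq."},
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory where the quantized model and tokenizer are written."},
    )

Because HfArgumentParser derives --quant_method and --output_dir options from the dataclass fields, the parser.add_argument calls removed above become redundant.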