2727
2828# Standard
2929import logging
30+ import os
31+ import sys
3032import time
33+ import traceback
3134
3235# Third Party
3336from datasets import load_from_disk
37+ from huggingface_hub .errors import HFValidationError
38+ from torch .cuda import OutOfMemoryError
3439from transformers import AutoTokenizer
3540import transformers
3641
4449 ModelArguments ,
4550 OptArguments ,
4651)
52+ from fms_mo .utils .config_utils import get_json_config
53+ from fms_mo .utils .error_logging import (
54+ INTERNAL_ERROR_EXIT_CODE ,
55+ USER_ERROR_EXIT_CODE ,
56+ write_termination_log ,
57+ )
4758from fms_mo .utils .import_utils import available_packages
48-
49- logger = logging .Logger ("fms_mo.main" )
59+ from fms_mo .utils .logging_utils import set_log_level
5060
5161
5262def quantize (
@@ -71,6 +81,8 @@ def quantize(
7181 output_dir (str) Output directory to write to
7282 """
7383
84+ logger = set_log_level (opt_args .log_level , "fms_mo.quantize" )
85+
7486 logger .info (f"{ fms_mo_args } \n { opt_args .quant_method } \n " )
7587
7688 if opt_args .quant_method == "gptq" :
@@ -120,6 +132,8 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
120132 # Local
121133 from fms_mo .utils .custom_gptq_models import custom_gptq_classes
122134
135+ logger = set_log_level (opt_args .log_level , "fms_mo.run_gptq" )
136+
123137 quantize_config = BaseQuantizeConfig (
124138 bits = gptq_args .bits ,
125139 group_size = gptq_args .group_size ,
@@ -179,6 +193,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
179193 from llmcompressor .modifiers .quantization import QuantizationModifier
180194 from llmcompressor .transformers import SparseAutoModelForCausalLM , oneshot
181195
196+ logger = set_log_level (opt_args .log_level , "fms_mo.run_fp8" )
197+
182198 model = SparseAutoModelForCausalLM .from_pretrained (
183199 model_args .model_name_or_path , torch_dtype = model_args .torch_dtype
184200 )
@@ -205,9 +221,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
205221 tokenizer .save_pretrained (opt_args .output_dir )
206222
207223
208- def main ():
209- """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""
210-
def get_parser():
    """Build the argument parser for the quantization entry point.

    Returns:
        transformers.HfArgumentParser
            Parser that yields one dataclass instance per configured argument
            group, in this order: model, data, optimization, FMS-MO (PTQ),
            GPTQ and FP8.
    """
    # NOTE(review): the four middle entries were elided by the diff hunk in the
    # captured source; they are restored from the order that parse_arguments
    # documents and unpacks. Confirm against the real file.
    return transformers.HfArgumentParser(
        dataclass_types=(
            ModelArguments,
            DataArguments,
            OptArguments,
            FMSMOArguments,
            GPTQArguments,
            FP8Arguments,
        )
    )
221237
222- (
223- model_args ,
224- data_args ,
225- opt_args ,
226- fms_mo_args ,
227- gptq_args ,
228- fp8_args ,
229- _ ,
230- ) = parser .parse_args_into_dataclasses (return_remaining_strings = True )
231238
232- logger .debug (
233- "Input args parsed: \n model_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
234- "gptq_args %s, fp8_args %s" ,
def parse_arguments(parser, json_config=None):
    """Parse arguments supplied either as a JSON config dict or on the command line.

    Args:
        parser: transformers.HfArgumentParser
            Parser configured with the six argument dataclasses.
        json_config: dict[str, Any] | None
            Arguments for the quantization job. When falsy (``None`` or an
            empty dict) the command line is parsed instead.

    Returns:
        A 6-tuple, in this order:

        ModelArguments
            Arguments pertaining to which model we are going to quantize.
        DataArguments
            Arguments pertaining to what data we are going to use for
            optimization and evaluation.
        OptArguments
            Arguments generic to optimization.
        FMSMOArguments
            Configuration for PTQ quantization.
        GPTQArguments
            Configuration for GPTQ quantization.
        FP8Arguments
            Configuration for FP8 quantization.
    """
    if json_config:
        # Keys that match no dataclass field are tolerated, not rejected.
        parsed = parser.parse_dict(json_config, allow_extra_keys=True)
    else:
        # The final element is the list of unrecognised command-line strings;
        # it is intentionally discarded.
        *parsed, _ = parser.parse_args_into_dataclasses(
            return_remaining_strings=True
        )

    # Unpack explicitly so an unexpected number of results fails loudly here
    # rather than at a caller's unpacking site.
    (
        model_args,
        data_args,
        opt_args,
        fms_mo_args,
        gptq_args,
        fp8_args,
    ) = parsed

    return (
        model_args,
        data_args,
        opt_args,
        fms_mo_args,
        gptq_args,
        fp8_args,
    )
242290
243- quantize (
244- model_args = model_args ,
245- data_args = data_args ,
246- opt_args = opt_args ,
247- fms_mo_args = fms_mo_args ,
248- gptq_args = gptq_args ,
249- fp8_args = fp8_args ,
250- )
291+
def _log_and_exit(logger, message, exit_code):
    """Log the active traceback, record the termination message, then exit.

    Must be called from inside an ``except`` block so that
    ``traceback.format_exc()`` has an exception to format.
    """
    logger.error(traceback.format_exc())
    write_termination_log(message)
    sys.exit(exit_code)


def main():
    """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques.

    Arguments are taken from the JSON job config when ``get_json_config()``
    returns one, otherwise from the command line. Every failure is logged with
    its traceback, recorded through ``write_termination_log`` and turned into
    a process exit code:

    * ``USER_ERROR_EXIT_CODE`` - config/argument parsing failures, a missing
      file, ``HFValidationError``, ``TypeError``, ``ValueError`` or
      ``EnvironmentError`` (including failure to create the output directory).
    * ``INTERNAL_ERROR_EXIT_CODE`` - out-of-memory conditions and any other
      unhandled exception.
    """
    parser = get_parser()
    # Root logger is used until the requested log level is known.
    logger = logging.getLogger()

    # accept arguments via command-line or JSON
    try:
        # Loading the job config is inside the guarded region so that an
        # unreadable or malformed config is reported like any other bad input.
        job_config = get_json_config()
        (
            model_args,
            data_args,
            opt_args,
            fms_mo_args,
            gptq_args,
            fp8_args,
        ) = parse_arguments(parser, job_config)

        logger = set_log_level(opt_args.log_level, __name__)

        logger.debug(
            "Input args parsed: \n model_args %s, data_args %s, opt_args %s, "
            "fms_mo_args %s, gptq_args %s, fp8_args %s",
            model_args,
            data_args,
            opt_args,
            fms_mo_args,
            gptq_args,
            fp8_args,
        )
    except Exception as e:  # pylint: disable=broad-except
        _log_and_exit(
            logger,
            f"Exception raised during optimization. This may be a problem with your input: {e}",
            USER_ERROR_EXIT_CODE,
        )

    try:
        # Creating the output directory can fail (permissions, read-only
        # filesystem), so it is guarded together with the quantization run.
        if opt_args.output_dir:
            os.makedirs(opt_args.output_dir, exist_ok=True)
            logger.info("Using the output directory at %s", opt_args.output_dir)

        quantize(
            model_args=model_args,
            data_args=data_args,
            opt_args=opt_args,
            fms_mo_args=fms_mo_args,
            gptq_args=gptq_args,
            fp8_args=fp8_args,
        )
    except (MemoryError, OutOfMemoryError) as e:
        _log_and_exit(
            logger,
            f"OOM error during optimization. {e}",
            INTERNAL_ERROR_EXIT_CODE,
        )
    # Listed before the EnvironmentError clause below so the more specific
    # message wins.
    except FileNotFoundError as e:
        _log_and_exit(logger, f"Unable to load file: {e}", USER_ERROR_EXIT_CODE)
    except HFValidationError as e:
        _log_and_exit(
            logger,
            f"There may be a problem with loading the model. Exception: {e}",
            USER_ERROR_EXIT_CODE,
        )
    except (TypeError, ValueError, EnvironmentError) as e:
        _log_and_exit(
            logger,
            f"Exception raised during optimization. This may be a problem with your input: {e}",
            USER_ERROR_EXIT_CODE,
        )
    except Exception as e:  # pylint: disable=broad-except
        _log_and_exit(
            logger,
            f"Unhandled exception during optimization: {e}",
            INTERNAL_ERROR_EXIT_CODE,
        )
251363
252364
253365if __name__ == "__main__" :
0 commit comments