2727
2828# Standard
2929import logging
30+ import os
31+ import sys
3032import time
33+ import traceback
3134
3235# Third Party
3336from datasets import load_from_disk
37+ from huggingface_hub .errors import HFValidationError
38+ from torch .cuda import OutOfMemoryError
3439from transformers import AutoTokenizer
3540import transformers
3641
4449 ModelArguments ,
4550 OptArguments ,
4651)
52+ from fms_mo .utils .config_utils import get_json_config
53+ from fms_mo .utils .error_logging import (
54+ INTERNAL_ERROR_EXIT_CODE ,
55+ USER_ERROR_EXIT_CODE ,
56+ write_termination_log ,
57+ )
4758from fms_mo .utils .import_utils import available_packages
48-
49- logger = logging .Logger ("fms_mo.main" )
59+ from fms_mo .utils .logging_utils import set_log_level
5060
5161
5262def quantize (
@@ -70,6 +80,8 @@ def quantize(
7080 fp8_args (fms_mo.training_args.FP8Arguments): Parameters to use for FP8 quantization
7181 """
7282
83+ logger = set_log_level (opt_args .log_level , "fms_mo.quantize" )
84+
7385 logger .info (f"{ fms_mo_args } \n { opt_args .quant_method } \n " )
7486
7587 if opt_args .quant_method == "gptq" :
@@ -119,6 +131,8 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
119131 # Local
120132 from fms_mo .utils .custom_gptq_models import custom_gptq_classes
121133
134+ logger = set_log_level (opt_args .log_level , "fms_mo.run_gptq" )
135+
122136 quantize_config = BaseQuantizeConfig (
123137 bits = gptq_args .bits ,
124138 group_size = gptq_args .group_size ,
@@ -178,6 +192,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
178192 from llmcompressor .modifiers .quantization import QuantizationModifier
179193 from llmcompressor .transformers import SparseAutoModelForCausalLM , oneshot
180194
195+ logger = set_log_level (opt_args .log_level , "fms_mo.run_fp8" )
196+
181197 model = SparseAutoModelForCausalLM .from_pretrained (
182198 model_args .model_name_or_path , torch_dtype = model_args .torch_dtype
183199 )
@@ -204,9 +220,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
204220 tokenizer .save_pretrained (opt_args .output_dir )
205221
206222
207- def main ():
208- """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""
209-
223+ def get_parser ():
224+ """Get the command-line argument parser."""
210225 parser = transformers .HfArgumentParser (
211226 dataclass_types = (
212227 ModelArguments ,
@@ -217,20 +232,53 @@ def main():
217232 FP8Arguments ,
218233 )
219234 )
235+ return parser
220236
221- (
222- model_args ,
223- data_args ,
224- opt_args ,
225- fms_mo_args ,
226- gptq_args ,
227- fp8_args ,
228- _ ,
229- ) = parser .parse_args_into_dataclasses (return_remaining_strings = True )
230237
231- logger .debug (
232- "Input args parsed: \n model_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
233- "gptq_args %s, fp8_args %s" ,
238+ def parse_arguments (parser , json_config = None ):
239+ """Parses arguments provided either via command-line or JSON config.
240+
241+ Args:
242+ parser: argparse.ArgumentParser
243+ Command-line argument parser.
244+ json_config: dict[str, Any]
245+ Dict of arguments to use with tuning.
246+
247+ Returns:
248+ ModelArguments
249+ Arguments pertaining to which model we are going to quantize.
250+ DataArguments
251+ Arguments pertaining to what data we are going to use for optimization and evaluation.
252+ OptArguments
253+ Arguments generic to optimization.
254+ FMSMOArguments
255+ Configuration for PTQ quantization.
256+ GPTQArguments
257+ Configuration for GPTQ quantization.
258+ FP8Arguments
259+ Configuration for FP8 quantization.
260+ """
261+ if json_config :
262+ (
263+ model_args ,
264+ data_args ,
265+ opt_args ,
266+ fms_mo_args ,
267+ gptq_args ,
268+ fp8_args ,
269+ ) = parser .parse_dict (json_config , allow_extra_keys = True )
270+ else :
271+ (
272+ model_args ,
273+ data_args ,
274+ opt_args ,
275+ fms_mo_args ,
276+ gptq_args ,
277+ fp8_args ,
278+ _ ,
279+ ) = parser .parse_args_into_dataclasses (return_remaining_strings = True )
280+
281+ return (
234282 model_args ,
235283 data_args ,
236284 opt_args ,
@@ -239,14 +287,72 @@ def main():
239287 fp8_args ,
240288 )
241289
242- quantize (
243- model_args = model_args ,
244- data_args = data_args ,
245- opt_args = opt_args ,
246- fms_mo_args = fms_mo_args ,
247- gptq_args = gptq_args ,
248- fp8_args = fp8_args ,
249- )
290+
291+ def main ():
292+ """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""
293+
294+ parser = get_parser ()
295+ logger = logging .getLogger ()
296+ job_config = get_json_config ()
297+ # accept arguments via command-line or JSON
298+ try :
299+ (
300+ model_args ,
301+ data_args ,
302+ opt_args ,
303+ fms_mo_args ,
304+ gptq_args ,
305+ fp8_args ,
306+ ) = parse_arguments (parser , job_config )
307+
308+ logger = set_log_level (opt_args .log_level , __name__ )
309+
310+ logger .debug (f"Input args parsed: \n model_args { model_args } , data_args { data_args } , \
311+ opt_args { opt_args } , fms_mo_args { fms_mo_args } , gptq_args { gptq_args } , \
312+ fp8_args { fp8_args } " )
313+ except Exception as e : # pylint: disable=broad-except
314+ logger .error (traceback .format_exc ())
315+ write_termination_log (
316+ f"Exception raised during optimization. This may be a problem with your input: { e } "
317+ )
318+ sys .exit (USER_ERROR_EXIT_CODE )
319+
320+ if opt_args .output_dir :
321+ os .makedirs (opt_args .output_dir , exist_ok = True )
322+ logger .info ("Using the output directory at %s" , opt_args .output_dir )
323+ try :
324+ quantize (
325+ model_args = model_args ,
326+ data_args = data_args ,
327+ opt_args = opt_args ,
328+ fms_mo_args = fms_mo_args ,
329+ gptq_args = gptq_args ,
330+ fp8_args = fp8_args ,
331+ )
332+ except (MemoryError , OutOfMemoryError ) as e :
333+ logger .error (traceback .format_exc ())
334+ write_termination_log (f"OOM error during optimization. { e } " )
335+ sys .exit (INTERNAL_ERROR_EXIT_CODE )
336+ except FileNotFoundError as e :
337+ logger .error (traceback .format_exc ())
338+ write_termination_log (f"Unable to load file: { e } " )
339+ sys .exit (USER_ERROR_EXIT_CODE )
340+ except HFValidationError as e :
341+ logger .error (traceback .format_exc ())
342+ write_termination_log (
343+ f"There may be a problem with loading the model. Exception: { e } "
344+ )
345+ sys .exit (USER_ERROR_EXIT_CODE )
346+ except (TypeError , ValueError , EnvironmentError ) as e :
347+ logger .error (traceback .format_exc ())
348+ write_termination_log (
349+ f"Exception raised during optimization. This may be a problem with your input: { e } "
350+ )
351+ sys .exit (USER_ERROR_EXIT_CODE )
352+ except Exception as e : # pylint: disable=broad-except
353+ logger .error (traceback .format_exc ())
354+ write_termination_log (f"Unhandled exception during optimization: { e } " )
355+ sys .exit (INTERNAL_ERROR_EXIT_CODE )
250356
251357
252358if __name__ == "__main__" :
0 commit comments