
Commit e965875

Improve logging
Signed-off-by: Thara Palanivel <[email protected]>
Parent: 36b4156

File tree: 3 files changed, +227 −25 lines

fms_mo/run_quant.py

Lines changed: 137 additions & 25 deletions
@@ -27,10 +27,15 @@
 
 # Standard
 import logging
+import os
+import sys
 import time
+import traceback
 
 # Third Party
 from datasets import load_from_disk
+from huggingface_hub.errors import HFValidationError
+from torch.cuda import OutOfMemoryError
 from transformers import AutoTokenizer
 import transformers
 
@@ -44,9 +49,14 @@
     ModelArguments,
     OptArguments,
 )
+from fms_mo.utils.config_utils import get_json_config
+from fms_mo.utils.error_logging import (
+    INTERNAL_ERROR_EXIT_CODE,
+    USER_ERROR_EXIT_CODE,
+    write_termination_log,
+)
 from fms_mo.utils.import_utils import available_packages
-
-logger = logging.Logger("fms_mo.main")
+from fms_mo.utils.logging_utils import set_log_level
 
 
 def quantize(
@@ -71,6 +81,8 @@ def quantize(
         output_dir (str) Output directory to write to
     """
 
+    logger = set_log_level(opt_args.log_level, "fms_mo.quantize")
+
     logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")
 
     if opt_args.quant_method == "gptq":
@@ -120,6 +132,8 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
     # Local
     from fms_mo.utils.custom_gptq_models import custom_gptq_classes
 
+    logger = set_log_level(opt_args.log_level, "fms_mo.run_gptq")
+
     quantize_config = BaseQuantizeConfig(
         bits=gptq_args.bits,
         group_size=gptq_args.group_size,
@@ -179,6 +193,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
     from llmcompressor.modifiers.quantization import QuantizationModifier
     from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
 
+    logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8")
+
     model = SparseAutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, torch_dtype=model_args.torch_dtype
     )
@@ -205,9 +221,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
     tokenizer.save_pretrained(opt_args.output_dir)
 
 
-def main():
-    """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""
-
+def get_parser():
+    """Get the command-line argument parser."""
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             ModelArguments,
@@ -218,20 +233,53 @@ def main():
             FP8Arguments,
         )
     )
+    return parser
 
-    (
-        model_args,
-        data_args,
-        opt_args,
-        fms_mo_args,
-        gptq_args,
-        fp8_args,
-        _,
-    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
-    logger.debug(
-        "Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
-        "gptq_args %s, fp8_args %s",
+def parse_arguments(parser, json_config=None):
+    """Parses arguments provided either via command-line or JSON config.
+
+    Args:
+        parser: argparse.ArgumentParser
+            Command-line argument parser.
+        json_config: dict[str, Any]
+            Dict of arguments to use for quantization.
+
+    Returns:
+        ModelArguments
+            Arguments pertaining to which model we are going to quantize.
+        DataArguments
+            Arguments pertaining to what data we are going to use for optimization and evaluation.
+        OptArguments
+            Arguments generic to optimization.
+        FMSMOArguments
+            Configuration for PTQ quantization.
+        GPTQArguments
+            Configuration for GPTQ quantization.
+        FP8Arguments
+            Configuration for FP8 quantization.
+    """
+    if json_config:
+        (
+            model_args,
+            data_args,
+            opt_args,
+            fms_mo_args,
+            gptq_args,
+            fp8_args,
+        ) = parser.parse_dict(json_config, allow_extra_keys=True)
+    else:
+        (
+            model_args,
+            data_args,
+            opt_args,
+            fms_mo_args,
+            gptq_args,
+            fp8_args,
+            _,
+        ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    return (
         model_args,
         data_args,
         opt_args,
@@ -240,14 +288,78 @@ def main():
         fp8_args,
     )
 
-    quantize(
-        model_args=model_args,
-        data_args=data_args,
-        opt_args=opt_args,
-        fms_mo_args=fms_mo_args,
-        gptq_args=gptq_args,
-        fp8_args=fp8_args,
-    )
+
+def main():
+    """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""
+
+    parser = get_parser()
+    logger = logging.getLogger()
+    job_config = get_json_config()
+    # accept arguments via command-line or JSON
+    try:
+        (
+            model_args,
+            data_args,
+            opt_args,
+            fms_mo_args,
+            gptq_args,
+            fp8_args,
+        ) = parse_arguments(parser, job_config)
+
+        logger = set_log_level(opt_args.log_level, __name__)
+
+        logger.debug(
+            "Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, gptq_args %s, fp8_args %s",
+            model_args,
+            data_args,
+            opt_args,
+            fms_mo_args,
+            gptq_args,
+            fp8_args,
+        )
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"Exception raised during optimization. This may be a problem with your input: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+
+    if opt_args.output_dir:
+        os.makedirs(opt_args.output_dir, exist_ok=True)
+        logger.info("Using the output directory at %s", opt_args.output_dir)
+    try:
+        quantize(
+            model_args=model_args,
+            data_args=data_args,
+            opt_args=opt_args,
+            fms_mo_args=fms_mo_args,
+            gptq_args=gptq_args,
+            fp8_args=fp8_args,
+        )
+    except (MemoryError, OutOfMemoryError) as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(f"OOM error during optimization. {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)
+    except FileNotFoundError as e:
+        logger.error(traceback.format_exc())
+        write_termination_log("Unable to load file: {}".format(e))
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except HFValidationError as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"There may be a problem with loading the model. Exception: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except (TypeError, ValueError, EnvironmentError) as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"Exception raised during optimization. This may be a problem with your input: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error(traceback.format_exc())
+        write_termination_log(f"Unhandled exception during optimization: {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
 
 if __name__ == "__main__":
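With this commit, main() takes its arguments either from the command line or from a JSON config fetched by get_json_config(); argument errors exit with USER_ERROR_EXIT_CODE, while unexpected failures during quantization exit with INTERNAL_ERROR_EXIT_CODE. Below is a minimal sketch of driving the new parse_arguments() path directly. The config keys are assumptions inferred from the dataclass fields used elsewhere in this diff (model_name_or_path, quant_method, output_dir, log_level), not a documented schema:

    # Hypothetical usage sketch; config keys are assumptions based on this diff.
    from fms_mo.run_quant import get_parser, parse_arguments

    parser = get_parser()
    job_config = {
        "model_name_or_path": "facebook/opt-125m",  # any Hugging Face model id
        "quant_method": "gptq",
        "output_dir": "opt-125m-gptq",
        "log_level": "DEBUG",
    }
    # When a JSON config is supplied, parse_arguments() routes it through
    # HfArgumentParser.parse_dict(json_config, allow_extra_keys=True).
    model_args, data_args, opt_args, fms_mo_args, gptq_args, fp8_args = parse_arguments(
        parser, job_config
    )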

fms_mo/utils/error_logging.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import logging
+import os
+
+# The USER_ERROR_EXIT_CODE will be thrown when the process must exit
+# as a result of a user input error. User-related errors should be
+# >= 1 and <= 127 due to how some Kubernetes operators interpret them.
+USER_ERROR_EXIT_CODE = 1
+# The INTERNAL_ERROR_EXIT_CODE will be thrown when the process
+# terminates abnormally and it is not clearly the fault of the user.
+# System-level errors should be >= 128 and <= 254.
+INTERNAL_ERROR_EXIT_CODE = 203
+
+
+def write_termination_log(text, log_file="error.log"):
+    """Writes text to the termination log.
+
+    Args:
+        text: str
+        log_file: Optional[str]
+    """
+    log_file = os.environ.get("TERMINATION_LOG_FILE", log_file)
+    try:
+        with open(log_file, "a", encoding="utf-8") as handle:
+            handle.write(text)
+    except Exception as e:  # pylint: disable=broad-except
+        logging.warning("Unable to write termination log due to error {}".format(e))
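write_termination_log() appends to error.log by default but honors the TERMINATION_LOG_FILE environment variable, so a deployment can point it at a Kubernetes-style termination log without code changes. A small sketch of the calling pattern used in run_quant.py, assuming an illustrative override path:

    import os
    import sys

    from fms_mo.utils.error_logging import USER_ERROR_EXIT_CODE, write_termination_log

    # Illustrative override; without it the log goes to "error.log" in the CWD.
    os.environ["TERMINATION_LOG_FILE"] = "/tmp/termination-log"

    try:
        raise FileNotFoundError("missing calibration dataset")  # simulated failure
    except FileNotFoundError as e:
        write_termination_log(f"Unable to load file: {e}")
        sys.exit(USER_ERROR_EXIT_CODE)  # 1: user-input error, per the ranges above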

fms_mo/utils/logging_utils.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import logging
+import os
+
+
+def set_log_level(log_level=None, logger_name=None):
+    """Set log level of the Python native logger via CLI argument or env variable.
+
+    Args:
+        log_level
+            Log level to set, e.g. "DEBUG". If not given, falls back to the
+            LOG_LEVEL environment variable, defaulting to "WARNING".
+        logger_name
+            Logger name with which the logger is instantiated.
+
+    Returns:
+        logger
+            Logger with the updated effective log level. If no logger_name is
+            given, the root logger is returned.
+    """
+
+    # Clear any existing handlers if necessary
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
+    # Configure Python native logger
+    # If CLI arg is passed, assign same log level to python native logger
+    log_level = log_level or os.environ.get("LOG_LEVEL", "WARNING")
+
+    logging.basicConfig(
+        format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper()
+    )
+
+    logger = logging.getLogger(logger_name) if logger_name else logging.getLogger()
+    return logger
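set_log_level() resolves the level from its argument first and the LOG_LEVEL environment variable second (defaulting to WARNING), then clears the root handlers and reconfigures them via logging.basicConfig(). A short sketch of both paths, with an illustrative logger name:

    import os

    from fms_mo.utils.logging_utils import set_log_level

    # Level passed explicitly, e.g. from the log_level CLI argument:
    logger = set_log_level("DEBUG", "fms_mo.example")
    logger.debug("visible at DEBUG")

    # Level taken from the environment; case-insensitive thanks to .upper():
    os.environ["LOG_LEVEL"] = "info"
    logger = set_log_level(logger_name="fms_mo.example")
    logger.info("visible at INFO")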
