
Commit 1e7856e

Merge pull request #44 from tharapalanivel/logging
Add logging and tests for run_quant.py
2 parents: c56a37b + d94088e

13 files changed: +546, -26 lines


fms_mo/run_quant.py

Lines changed: 131 additions & 25 deletions
@@ -27,10 +27,15 @@

 # Standard
 import logging
+import os
+import sys
 import time
+import traceback

 # Third Party
 from datasets import load_from_disk
+from huggingface_hub.errors import HFValidationError
+from torch.cuda import OutOfMemoryError
 from transformers import AutoTokenizer
 import transformers

@@ -44,9 +49,14 @@
     ModelArguments,
     OptArguments,
 )
+from fms_mo.utils.config_utils import get_json_config
+from fms_mo.utils.error_logging import (
+    INTERNAL_ERROR_EXIT_CODE,
+    USER_ERROR_EXIT_CODE,
+    write_termination_log,
+)
 from fms_mo.utils.import_utils import available_packages
-
-logger = logging.Logger("fms_mo.main")
+from fms_mo.utils.logging_utils import set_log_level


 def quantize(
@@ -70,6 +80,8 @@ def quantize(
         fp8_args (fms_mo.training_args.FP8Arguments): Parameters to use for FP8 quantization
     """

+    logger = set_log_level(opt_args.log_level, "fms_mo.quantize")
+
     logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")

     if opt_args.quant_method == "gptq":
@@ -119,6 +131,8 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
     # Local
     from fms_mo.utils.custom_gptq_models import custom_gptq_classes

+    logger = set_log_level(opt_args.log_level, "fms_mo.run_gptq")
+
     quantize_config = BaseQuantizeConfig(
         bits=gptq_args.bits,
         group_size=gptq_args.group_size,
@@ -178,6 +192,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
     from llmcompressor.modifiers.quantization import QuantizationModifier
     from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

+    logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8")
+
     model = SparseAutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, torch_dtype=model_args.torch_dtype
     )
@@ -204,9 +220,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
     tokenizer.save_pretrained(opt_args.output_dir)


-def main():
-    """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""
-
+def get_parser():
+    """Get the command-line argument parser."""
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             ModelArguments,
@@ -217,20 +232,53 @@ def main():
             FP8Arguments,
         )
     )
+    return parser

-    (
-        model_args,
-        data_args,
-        opt_args,
-        fms_mo_args,
-        gptq_args,
-        fp8_args,
-        _,
-    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

-    logger.debug(
-        "Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
-        "gptq_args %s, fp8_args %s",
+def parse_arguments(parser, json_config=None):
+    """Parses arguments provided either via command-line or JSON config.
+
+    Args:
+        parser: argparse.ArgumentParser
+            Command-line argument parser.
+        json_config: dict[str, Any]
+            Dict of arguments to use with tuning.
+
+    Returns:
+        ModelArguments
+            Arguments pertaining to which model we are going to quantize.
+        DataArguments
+            Arguments pertaining to what data we are going to use for optimization and evaluation.
+        OptArguments
+            Arguments generic to optimization.
+        FMSMOArguments
+            Configuration for PTQ quantization.
+        GPTQArguments
+            Configuration for GPTQ quantization.
+        FP8Arguments
+            Configuration for FP8 quantization.
+    """
+    if json_config:
+        (
+            model_args,
+            data_args,
+            opt_args,
+            fms_mo_args,
+            gptq_args,
+            fp8_args,
+        ) = parser.parse_dict(json_config, allow_extra_keys=True)
+    else:
+        (
+            model_args,
+            data_args,
+            opt_args,
+            fms_mo_args,
+            gptq_args,
+            fp8_args,
+            _,
+        ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    return (
         model_args,
         data_args,
         opt_args,
@@ -239,14 +287,72 @@ def main():
         fp8_args,
     )

-    quantize(
-        model_args=model_args,
-        data_args=data_args,
-        opt_args=opt_args,
-        fms_mo_args=fms_mo_args,
-        gptq_args=gptq_args,
-        fp8_args=fp8_args,
-    )
+
+def main():
+    """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""
+
+    parser = get_parser()
+    logger = logging.getLogger()
+    job_config = get_json_config()
+    # accept arguments via command-line or JSON
+    try:
+        (
+            model_args,
+            data_args,
+            opt_args,
+            fms_mo_args,
+            gptq_args,
+            fp8_args,
+        ) = parse_arguments(parser, job_config)
+
+        logger = set_log_level(opt_args.log_level, __name__)
+
+        logger.debug(f"Input args parsed: \nmodel_args {model_args}, data_args {data_args}, \
+            opt_args {opt_args}, fms_mo_args {fms_mo_args}, gptq_args {gptq_args}, \
+            fp8_args {fp8_args}")
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"Exception raised during optimization. This may be a problem with your input: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+
+    if opt_args.output_dir:
+        os.makedirs(opt_args.output_dir, exist_ok=True)
+        logger.info("Using the output directory at %s", opt_args.output_dir)
+    try:
+        quantize(
+            model_args=model_args,
+            data_args=data_args,
+            opt_args=opt_args,
+            fms_mo_args=fms_mo_args,
+            gptq_args=gptq_args,
+            fp8_args=fp8_args,
+        )
+    except (MemoryError, OutOfMemoryError) as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(f"OOM error during optimization. {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)
+    except FileNotFoundError as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(f"Unable to load file: {e}")
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except HFValidationError as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"There may be a problem with loading the model. Exception: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except (TypeError, ValueError, EnvironmentError) as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"Exception raised during optimization. This may be a problem with your input: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error(traceback.format_exc())
+        write_termination_log(f"Unhandled exception during optimization: {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)


 if __name__ == "__main__":
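
As a quick illustration of the new argument flow, here is a minimal sketch that drives parse_arguments() with a dict, the same path main() takes when get_json_config() returns a JSON config. The config keys shown (model_name_or_path, output_dir, quant_method) are assumptions inferred from attributes used elsewhere in this diff; the authoritative set of keys is defined by the dataclasses passed to HfArgumentParser.

# Sketch (not part of this commit): exercising parse_arguments() with a dict,
# mirroring what main() does when a JSON config is supplied via the environment.
# The keys below are assumptions based on attributes referenced in this diff;
# any required fields of the underlying dataclasses must also be provided.
from fms_mo.run_quant import get_parser, parse_arguments

job_config = {
    "model_name_or_path": "my-org/my-model",  # hypothetical model id
    "output_dir": "output/quantized",
    "quant_method": "gptq",
}

parser = get_parser()
model_args, data_args, opt_args, fms_mo_args, gptq_args, fp8_args = parse_arguments(
    parser, json_config=job_config
)
print(opt_args.quant_method)  # "gptq"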

fms_mo/utils/config_utils.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import base64
+import json
+import os
+import pickle
+
+
+def update_config(config, **kwargs):
+    """Updates config from key-value pairs provided through kwargs"""
+    if isinstance(config, (tuple, list)):
+        for c in config:
+            update_config(c, **kwargs)
+    else:
+        for k, v in kwargs.items():
+            if hasattr(config, k):
+                setattr(config, k, v)
+            elif "." in k:
+                # allow --some_config.some_param=True
+                config_name, param_name = k.split(".")
+                if type(config).__name__ == config_name:
+                    if hasattr(config, param_name):
+                        setattr(config, param_name, v)
+                    else:
+                        # In case of a specialized config we can warn the user
+                        print(f"Warning: {config_name} does not accept parameter: {k}")
+
+
+def get_json_config():
+    """Parses JSON configuration if provided via environment variables
+    FMS_MO_CONFIG_JSON_ENV_VAR or FMS_MO_CONFIG_JSON_PATH.
+
+    FMS_MO_CONFIG_JSON_ENV_VAR is the base64-encoded JSON.
+    FMS_MO_CONFIG_JSON_PATH is the path to the JSON config file.
+
+    Returns: dict or {}
+    """
+    json_env_var = os.getenv("FMS_MO_CONFIG_JSON_ENV_VAR")
+    json_path = os.getenv("FMS_MO_CONFIG_JSON_PATH")
+
+    # accepts either a path to a JSON file or an encoded string config;
+    # the env var takes precedence
+    job_config_dict = {}
+    if json_env_var:
+        job_config_dict = txt_to_obj(json_env_var)
+    elif json_path:
+        with open(json_path, "r", encoding="utf-8") as f:
+            job_config_dict = json.load(f)
+
+    return job_config_dict
+
+
+def txt_to_obj(txt):
+    """Given an encoded byte string, converts to a base64-decoded dict.
+
+    Args:
+        txt: str
+    Returns: dict[str, Any]
+    """
+    base64_bytes = txt.encode("ascii")
+    message_bytes = base64.b64decode(base64_bytes)
+    try:
+        # If the bytes represent a JSON string
+        return json.loads(message_bytes)
+    except UnicodeDecodeError:
+        # Otherwise the bytes are a pickled Python dictionary
+        return pickle.loads(message_bytes)
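
A short sketch of how a caller could hand a job config to get_json_config() through the base64 env-var route added here (the file route via FMS_MO_CONFIG_JSON_PATH works analogously); the config contents are a hypothetical example.

# Sketch: supplying a job config through FMS_MO_CONFIG_JSON_ENV_VAR.
# get_json_config() base64-decodes the value and parses it as JSON via txt_to_obj().
import base64
import json
import os

from fms_mo.utils.config_utils import get_json_config

config = {"quant_method": "fp8", "output_dir": "out"}  # hypothetical keys
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = base64.b64encode(
    json.dumps(config).encode("ascii")
).decode("ascii")

assert get_json_config() == config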

fms_mo/utils/dq_utils.py

Lines changed: 3 additions & 1 deletion
@@ -38,7 +38,9 @@ def config_quantize_smooth_layers(qcfg):
         "granite-20b-code",
         "granite-20b-code",
     ]
-    if any(model in qcfg["model"] for model in llama_architecture):
+    if any(model in qcfg["model"] for model in llama_architecture) or any(
+        model in qcfg["model_type"] for model in llama_architecture
+    ):
         qcfg["qlayer_name_pattern"] = ["model.layers."]
         qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
         qcfg["qskip_layer_name"] = []

fms_mo/utils/error_logging.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import logging
+import os
+
+# USER_ERROR_EXIT_CODE is returned when the process must exit
+# as a result of a user input error. User-related errors should be
+# >= 1 and <= 127 due to how some Kubernetes operators interpret them.
+USER_ERROR_EXIT_CODE = 1
+# INTERNAL_ERROR_EXIT_CODE is returned when training terminates
+# abnormally and it is not clearly the fault of the user.
+# System-level errors should be >= 128 and <= 254.
+INTERNAL_ERROR_EXIT_CODE = 203
+
+
+def write_termination_log(text, log_file="error.log"):
+    """Writes text to the termination log.
+
+    Args:
+        text: str
+        log_file: Optional[str]
+    """
+    log_file = os.environ.get("TERMINATION_LOG_FILE", log_file)
+    try:
+        with open(log_file, "a", encoding="utf-8") as handle:
+            handle.write(text)
+    except Exception as e:  # pylint: disable=broad-except
+        logging.warning(f"Unable to write termination log due to error {e}")
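
For context on how these helpers are meant to be used together, a minimal sketch of the call pattern that run_quant.main() above follows; do_work() is a hypothetical stand-in for the real quantize() call, and TERMINATION_LOG_FILE is optional since write_termination_log() falls back to error.log.

# Sketch: the error-handling pattern used by main() with these helpers.
import logging
import sys
import traceback

from fms_mo.utils.error_logging import USER_ERROR_EXIT_CODE, write_termination_log


def do_work():
    """Hypothetical stand-in for the real quantize() call."""
    raise FileNotFoundError("missing calibration dataset")


try:
    do_work()
except FileNotFoundError as e:
    logging.error(traceback.format_exc())
    write_termination_log(f"Unable to load file: {e}")
    sys.exit(USER_ERROR_EXIT_CODE)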
