Skip to content

Commit 9fdeee8

Browse files
Fix lint
Signed-off-by: Thara Palanivel <[email protected]>
1 parent dfa1829 commit 9fdeee8

File tree

4 files changed

+16
-18
lines changed

4 files changed

+16
-18
lines changed

fms_mo/run_quant.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,15 +308,9 @@ def main():
308308

309309
logger = set_log_level(opt_args.log_level, __name__)
310310

311-
logger.debug(
312-
"Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, gptq_args %s, fp8_args %s",
313-
model_args,
314-
data_args,
315-
opt_args,
316-
fms_mo_args,
317-
gptq_args,
318-
fp8_args,
319-
)
311+
logger.debug(f"Input args parsed: \nmodel_args {model_args}, data_args {data_args}, \
312+
opt_args {opt_args}, fms_mo_args {fms_mo_args}, gptq_args {gptq_args}, \
313+
fp8_args {fp8_args}")
320314
except Exception as e: # pylint: disable=broad-except
321315
logger.error(traceback.format_exc())
322316
write_termination_log(
@@ -342,7 +336,7 @@ def main():
342336
sys.exit(INTERNAL_ERROR_EXIT_CODE)
343337
except FileNotFoundError as e:
344338
logger.error(traceback.format_exc())
345-
write_termination_log("Unable to load file: {}".format(e))
339+
write_termination_log(f"Unable to load file: {e}")
346340
sys.exit(USER_ERROR_EXIT_CODE)
347341
except HFValidationError as e:
348342
logger.error(traceback.format_exc())

fms_mo/utils/config_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def update_config(config, **kwargs):
23+
"""Updates config from key-value pairs provided through kwargs"""
2324
if isinstance(config, (tuple, list)):
2425
for c in config:
2526
update_config(c, **kwargs)

fms_mo/utils/error_logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ def write_termination_log(text, log_file="error.log"):
3838
with open(log_file, "a", encoding="utf-8") as handle:
3939
handle.write(text)
4040
except Exception as e: # pylint: disable=broad-except
41-
logging.warning("Unable to write termination log due to error {}".format(e))
41+
logging.warning(f"Unable to write termination log due to error {e}")

tests/test_run_quant.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ def test_run_train_fails_training_data_path_not_exist():
8585
)
8686

8787

88-
# Note: job_config dict gets modified during process training args
8988
@pytest.fixture(name="job_config", scope="session")
9089
def fixture_job_config():
90+
"""Fixture to get happy path dummy config as a dict, note that job_config dict gets
91+
modified during process training args"""
9192
with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f:
9293
dummy_job_config_dict = json.load(f)
9394
return dummy_job_config_dict
@@ -97,15 +98,16 @@ def fixture_job_config():
9798

9899

99100
def test_parse_arguments(job_config):
101+
"""Test that arg parser can parse json job config correctly"""
100102
parser = get_parser()
101103
job_config_copy = copy.deepcopy(job_config)
102104
(
103105
model_args,
104106
data_args,
105107
opt_args,
106-
fms_mo_args,
107-
gptq_args,
108-
fp8_args,
108+
_,
109+
_,
110+
_,
109111
) = parse_arguments(parser, job_config_copy)
110112
assert str(model_args.torch_dtype) == "torch.bfloat16"
111113
assert data_args.training_data_path == "data_train"
@@ -114,6 +116,7 @@ def test_parse_arguments(job_config):
114116

115117

116118
def test_parse_arguments_defaults(job_config):
119+
"""Test that defaults set in fms_mo/training_args.py are retained"""
117120
parser = get_parser()
118121
job_config_defaults = copy.deepcopy(job_config)
119122
assert "torch_dtype" not in job_config_defaults
@@ -123,10 +126,10 @@ def test_parse_arguments_defaults(job_config):
123126
(
124127
model_args,
125128
data_args,
126-
opt_args,
129+
_,
127130
fms_mo_args,
128-
gptq_args,
129-
fp8_args,
131+
_,
132+
_,
130133
) = parse_arguments(parser, job_config_defaults)
131134
assert str(model_args.torch_dtype) == "torch.bfloat16"
132135
assert model_args.model_revision == "main"

0 commit comments

Comments
 (0)