
Commit f0b55bf

fix: change logging level to info and print flat arguments (#582)
Signed-off-by: Dushyant Behl <[email protected]>
1 parent be16db6 commit f0b55bf

File tree

3 files changed: +91 -48 lines changed

  tests/utils/test_logging.py
  tuning/sft_trainer.py
  tuning/utils/logging.py

tests/utils/test_logging.py

Lines changed: 10 additions & 10 deletions
@@ -36,9 +36,9 @@ def test_set_log_level_for_logger_default():

     with mock.patch.dict(os.environ, {}, clear=True):
         train_args = copy.deepcopy(TRAIN_ARGS)
-        training_args, logger = set_log_level(train_args)
-        assert logger.getEffectiveLevel() == logging.WARNING
-        assert training_args.log_level == "passive"
+        logger, log_level = set_log_level(level=train_args.log_level)
+        assert logger.getEffectiveLevel() == logging.INFO
+        assert log_level == "info"


 def test_set_log_level_for_logger_with_env_var():

@@ -48,10 +48,10 @@ def test_set_log_level_for_logger_with_env_var():
     """

     with mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True):
-        train_args_env = copy.deepcopy(TRAIN_ARGS)
-        training_args, logger = set_log_level(train_args_env)
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        logger, log_level = set_log_level(level=train_args.log_level)
         assert logger.getEffectiveLevel() == logging.INFO
-        assert training_args.log_level == "info"
+        assert log_level == "info"


 def test_set_log_level_for_logger_with_set_verbosity_and_cli():

@@ -64,9 +64,9 @@ def test_set_log_level_for_logger_with_set_verbosity_and_cli():
     with mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True):
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.log_level = "error"
-        training_args, logger = set_log_level(train_args)
+        logger, log_level = set_log_level(level=train_args.log_level)
         assert logger.getEffectiveLevel() == logging.ERROR
-        assert training_args.log_level == "error"
+        assert log_level == "error"


 def test_set_log_level_for_logger_with_env_var_and_cli():

@@ -79,6 +79,6 @@ def test_set_log_level_for_logger_with_env_var_and_cli():
     with mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True):
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.log_level = "error"
-        training_args, logger = set_log_level(train_args)
+        logger, log_level = set_log_level(level=train_args.log_level)
         assert logger.getEffectiveLevel() == logging.ERROR
-        assert training_args.log_level == "error"
+        assert log_level == "error"
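
Reviewer note: the tests above reflect the reworked API. set_log_level no longer takes the full training-arguments object, only an optional logger name and a level string, and it returns the logger together with the resolved level, which now defaults to INFO rather than WARNING. A minimal usage sketch of that new call pattern, not part of this diff, assuming fms-hf-tuning is installed and no LOG_LEVEL or TRANSFORMERS_VERBOSITY overrides are set:

# Sketch only: mirrors the call pattern exercised by the updated tests.
from tuning.utils.logging import set_log_level

# "passive" means no explicit CLI choice; on the main process it now resolves to INFO.
logger, log_level = set_log_level(level="passive")
assert log_level == "info"
logger.info("logger ready at level %s", log_level)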

tuning/sft_trainer.py

Lines changed: 24 additions & 22 deletions
@@ -63,7 +63,7 @@
     USER_ERROR_EXIT_CODE,
     write_termination_log,
 )
-from tuning.utils.logging import set_log_level
+from tuning.utils.logging import pretty_print_args, set_log_level


 def train(

@@ -120,7 +120,9 @@ def train(
         Tuple: Instance of SFTTrainer , some metadata in a dict
                Metadata contains information on number of added tokens while tuning.
     """
-    train_args, logger = set_log_level(train_args, "sft_trainer_train")
+    logger, train_args.log_level = set_log_level(
+        logger_name="sft_trainer_train", level=train_args.log_level
+    )
     USE_ALORA = False
     try:
         # Third Party

@@ -529,7 +531,6 @@ def get_parser():
         choices=["pt", "lora", "alora", None, "none"],
         default="none",
     )
-
     parser.add_argument(
         "--exp_metadata",
         type=str,

@@ -603,7 +604,6 @@ def parse_arguments(parser, json_config=None):
            raise ValueError(
                "invocation_string is not passed required for aLoRA usage"
            )
-
    else:
        (
            model_args,

@@ -687,26 +687,28 @@ def main():
        ) = parse_arguments(parser, job_config)

        # Function to set log level for python native logger and transformers training logger
-        training_args, logger = set_log_level(training_args, __name__)
+        logger, training_args.log_level = set_log_level(
+            logger_name=__name__, level=training_args.log_level
+        )

-        logger.info(
-            "Flat arguments parsed: \
-            model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
-            tune_config %s, quantized_lora_config %s, fusedops_kernels_config %s, \
-            attention_and_distributed_packing_config, %s,\
-            fast_moe_config %s, tracker_config %s, exp_metadata %s",
-            model_args,
-            data_args,
-            training_args,
-            trainer_controller_args,
-            tune_config,
-            quantized_lora_config,
-            fusedops_kernels_config,
-            attention_and_distributed_packing_config,
-            fast_moe_config,
-            tracker_configs,
-            exp_metadata,
+        logger.info("fms-hf-tuning execution start")
+        args_dump = pretty_print_args(
+            {
+                "Model Arguments": model_args,
+                "Data Arguments": data_args,
+                "Training Arguments": training_args,
+                "Tune Config": tune_config,
+                "QLoRA Config": quantized_lora_config,
+                "Tracker Config": tracker_configs,
+                "AADP (fms-acceleration) Config": attention_and_distributed_packing_config,
+                "Fused Ops Kernels Config": fusedops_kernels_config,
+                "Fast MoE Config": fast_moe_config,
+                "Trainer Controller Config": trainer_controller_args,
+                "Extra Metadata": exp_metadata,
+            }
         )
+        logger.info(args_dump)
+
     except Exception as e: # pylint: disable=broad-except
         logger.error(traceback.format_exc())
         write_termination_log(
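
Reviewer note: for context, a short sketch of how the new pretty_print_args helper (defined in tuning/utils/logging.py below) is consumed in main(): it takes a dict mapping a section title to an argument object (a dataclass or plain dict) and returns one formatted string that is then logged at INFO level. The _ExampleArgs dataclass and its field values are hypothetical stand-ins, not classes from this repository:

# Sketch only: _ExampleArgs is a hypothetical stand-in for ModelArguments and friends.
from dataclasses import dataclass

from tuning.utils.logging import pretty_print_args

@dataclass
class _ExampleArgs:
    model_name_or_path: str = "my-base-model"
    torch_dtype: str = "bfloat16"

# Sections with a falsy value (e.g. None) are skipped by the helper.
dump = pretty_print_args({"Model Arguments": _ExampleArgs(), "Extra Metadata": None})
print(dump)  # each field printed left-aligned under its section banner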

tuning/utils/logging.py

Lines changed: 57 additions & 16 deletions
@@ -13,24 +13,33 @@
 # limitations under the License.

 # Standard
+from typing import Dict
 import logging
 import os

+# Third Party
+from accelerate.state import PartialState
+import datasets
+import transformers

-def set_log_level(train_args, logger_name=None):
+DEFAULT_LOG_LEVEL_MAIN = "INFO"
+DEFAULT_LOG_LEVEL_WORKERS = "WARNING"
+
+
+def set_log_level(logger_name="fms-hf-tuning", level=None):
     """Set log level of python native logger and TF logger via argument from CLI or env variable.

     Args:
-        train_args
-            Training arguments for training model.
         logger_name
             Logger name with which the logger is instantiated.
+        level
+            Requested level of the logger

     Returns:
-        train_args
-            Updated training arguments for training model.
         train_logger
             Logger with updated effective log level
+        level
+            Level of the logger initialized
     """

     # Clear any existing handlers if necessary

@@ -39,26 +48,58 @@ def set_log_level(train_args, logger_name=None):

     # Configure Python native logger and transformers log level
     # If CLI arg is passed, assign same log level to python native logger
-    log_level = "WARNING"
-    if train_args.log_level != "passive":
-        log_level = train_args.log_level
-
-    # If CLI arg not is passed and env var LOG_LEVEL is set,
-    # assign same log level to both logger
+    lowest_log_level = DEFAULT_LOG_LEVEL_MAIN
+    if level != "passive":
+        lowest_log_level = level
     elif os.environ.get("LOG_LEVEL"):
-        log_level = os.environ.get("LOG_LEVEL")
-        train_args.log_level = (
-            log_level.lower()
+        # If CLI arg not is passed and env var LOG_LEVEL is set,
+        # assign same log level to both logger
+        lowest_log_level = (
+            os.environ.get("LOG_LEVEL").lower()
             if not os.environ.get("TRANSFORMERS_VERBOSITY")
            else os.environ.get("TRANSFORMERS_VERBOSITY")
         )

+    state = PartialState()
+    rank = state.process_index
+
+    log_on_all = os.environ.get("LOG_ON_ALL_PROCESSES")
+    if log_on_all:
+        log_level = lowest_log_level or DEFAULT_LOG_LEVEL_MAIN
+    else:
+        if state.is_local_main_process:
+            log_level = lowest_log_level or DEFAULT_LOG_LEVEL_MAIN
+            datasets.utils.logging.set_verbosity_warning()
+            transformers.utils.logging.set_verbosity_info()
+        else:
+            log_level = DEFAULT_LOG_LEVEL_WORKERS
+            datasets.utils.logging.set_verbosity_error()
+            transformers.utils.logging.set_verbosity_error()
+
+    log_format = f"Rank-{rank} [%(levelname)s]:%(filename)s:%(funcName)s: %(message)s"
+
     logging.basicConfig(
-        format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper()
+        format=log_format,
+        level=log_level.upper(),
     )

     if logger_name:
         train_logger = logging.getLogger(logger_name)
     else:
         train_logger = logging.getLogger()
-    return train_args, train_logger
+
+    return train_logger, log_level.lower()
+
+
+def pretty_print_args(args: Dict):
+    dump = "\n========================= Flat Arguments =========================\n"
+    for name, arg in args.items():
+        if arg:
+            dump += f"---------------------------- {name} -----------------------\n"
+            if hasattr(arg, "__dict__"):
+                arg = vars(arg)
+            max_len = max(len(k) for k in arg.keys())
+            for k, v in sorted(arg.items()):
+                dump += f"    {k:<{max_len}} : {v}\n"
+    dump += "========================= Arguments Done =========================\n"
+    return dump
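
Reviewer note: based only on the code above, a hedged sketch of the resolution order the new set_log_level applies. The variable names LOG_LEVEL, TRANSFORMERS_VERBOSITY and LOG_ON_ALL_PROCESSES come from this diff; the single-process launch is an assumption, and under accelerate/torchrun the same logic runs per rank via PartialState:

# Sketch only, assuming a plain single-process run with a clean environment.
import os

from tuning.utils.logging import set_log_level

# 1. An explicit level (anything other than "passive") wins.
logger, level = set_log_level(logger_name="example", level="error")
assert level == "error"

# 2. With level == "passive", LOG_LEVEL is consulted; if TRANSFORMERS_VERBOSITY
#    is also set, it takes precedence over LOG_LEVEL.
os.environ["LOG_LEVEL"] = "debug"
logger, level = set_log_level(logger_name="example", level="passive")
assert level == "debug"

# 3. Non-main ranks fall back to WARNING unless LOG_ON_ALL_PROCESSES is set,
#    in which case every rank logs at the resolved level.
os.environ["LOG_ON_ALL_PROCESSES"] = "1"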

0 commit comments