Skip to content

Commit 68a6f80

Browse files
authored
[megatron] add logging jsonl (#4908)
1 parent 7772c6b commit 68a6f80

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

swift/megatron/init.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from packaging import version
1212

1313
from swift.llm import git_clone_github
14-
from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run
14+
from swift.utils import JsonlWriter, get_logger, is_master, is_megatron_available, safe_ddp_context, subprocess_run
1515

1616
logger = get_logger()
1717

@@ -60,12 +60,18 @@ def _patch_training_log():
6060
from megatron.training.training import num_floating_point_operations
6161
from megatron.core.num_microbatches_calculator import get_num_microbatches
6262
from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory
63+
jsonl_writer = None
6364

6465
# Code borrowed from NVIDIA/Megatron-LM
6566
def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
6667
report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad):
6768
"""Log training information such as losses, timing, ...."""
69+
nonlocal jsonl_writer
6870
args = get_args()
71+
if is_master() and jsonl_writer is None:
72+
logging_path = os.path.join(args.save, 'logging.jsonl')
73+
logger.info(f'logging_path: {logging_path}')
74+
jsonl_writer = JsonlWriter(logging_path, enable_async=True)
6975
timers = get_timers()
7076
writer = get_tensorboard_writer()
7177
wandb_writer = get_wandb_writer()
@@ -209,6 +215,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
209215
mtp_loss_scale = 1 / get_num_microbatches()
210216
MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict)
211217
if iteration % args.log_interval == 0 or iteration == 1:
218+
origin_total_loss_dict = total_loss_dict.copy()
219+
212220
if args.record_memory_history and is_last_rank():
213221
snapshot = torch.cuda.memory._snapshot()
214222
from pickle import dump
@@ -277,6 +285,26 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
277285
report_memory_flag = False
278286
timers.log(timers_to_log, normalizer=args.log_interval)
279287

288+
if is_master():
289+
logs = {}
290+
for key in origin_total_loss_dict:
291+
if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]:
292+
avg = origin_total_loss_dict[key].item() / float(
293+
max(1, origin_total_loss_dict[advanced_iters_key]))
294+
logs[key] = round(avg, 8)
295+
if grad_norm is not None:
296+
logs['grad_norm'] = round(grad_norm, 8)
297+
if params_norm is not None:
298+
logs['params_norm'] = round(params_norm, 8)
299+
logs['learning_rate'] = round(learning_rate, 8)
300+
logs['elapsed_time_per_iteration'] = round(elapsed_time_per_iteration, 8)
301+
if args.log_throughput:
302+
logs['throughput'] = round(throughput, 8)
303+
logs['loss_scale'] = round(loss_scale, 8)
304+
logs['consumed_samples'] = args.consumed_train_samples
305+
logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
306+
jsonl_writer.append(logs)
307+
280308
return report_memory_flag
281309

282310
training.training_log = training_log

swift/megatron/trainers/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22

3+
import os
34
import time
45
from abc import ABC, abstractmethod
56
from contextlib import contextmanager
@@ -18,7 +19,7 @@
1819
from megatron.training.checkpointing import load_checkpoint
1920
from packaging import version
2021

21-
from swift.utils import get_logger
22+
from swift.utils import JsonlWriter, get_logger, is_master
2223
from ..utils import adapter_state_dict_context, prepare_mcore_model
2324
from .utils import get_swift_datasets_provider
2425

@@ -30,6 +31,9 @@ class BaseMegatronTrainer(ABC):
3031
def __init__(self, args):
3132
self.args = args
3233
self.stimer = StragglerDetector()
34+
logging_path = os.path.join(args.save, 'logging.jsonl')
35+
logger.info(f'logging_path: {logging_path}')
36+
self.jsonl_writer = JsonlWriter(logging_path, enable_async=True)
3337
self._patch_megatron()
3438

3539
@contextmanager
@@ -305,9 +309,11 @@ def evaluate(self,
305309
timers.log(['evaluate'])
306310

307311
rerun_state_machine.set_mode(rerun_mode)
308-
309-
rerun_state_machine.set_mode(rerun_mode)
310-
312+
if is_master():
313+
logs = {}
314+
for key, val in total_loss_dict.items():
315+
logs[f'eval_{key}'] = round(val.item(), 8)
316+
self.jsonl_writer.append(logs)
311317
return total_loss_dict, collected_non_loss_data, False
312318

313319
def save_checkpoint(self, *args, **kwargs):

0 commit comments

Comments (0)