11 | 11 | from packaging import version
12 | 12 |
13 | 13 | from swift.llm import git_clone_github
14 |    | -from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run
   | 14 | +from swift.utils import JsonlWriter, get_logger, is_master, is_megatron_available, safe_ddp_context, subprocess_run
15 | 15 |
16 | 16 | logger = get_logger()
17 | 17 |
@@ -60,12 +60,18 @@ def _patch_training_log():
60 | 60 |     from megatron.training.training import num_floating_point_operations
61 | 61 |     from megatron.core.num_microbatches_calculator import get_num_microbatches
62 | 62 |     from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory
   | 63 | +    jsonl_writer = None
63 | 64 |
64 | 65 |     # Code borrowed from NVIDIA/Megatron-LM
65 | 66 |     def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
66 | 67 |                      report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad):
67 | 68 |         """Log training information such as losses, timing, ...."""
   | 69 | +        nonlocal jsonl_writer
68 | 70 |         args = get_args()
   | 71 | +        if is_master() and jsonl_writer is None:
   | 72 | +            logging_path = os.path.join(args.save, 'logging.jsonl')
   | 73 | +            logger.info(f'logging_path: {logging_path}')
   | 74 | +            jsonl_writer = JsonlWriter(logging_path, enable_async=True)
69 | 75 |         timers = get_timers()
70 | 76 |         writer = get_tensorboard_writer()
71 | 77 |         wandb_writer = get_wandb_writer()
@@ -209,6 +215,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
209 | 215 |             mtp_loss_scale = 1 / get_num_microbatches()
210 | 216 |             MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict)
211 | 217 |         if iteration % args.log_interval == 0 or iteration == 1:
    | 218 | +            origin_total_loss_dict = total_loss_dict.copy()
    | 219 | +
212 | 220 |             if args.record_memory_history and is_last_rank():
213 | 221 |                 snapshot = torch.cuda.memory._snapshot()
214 | 222 |                 from pickle import dump
@@ -277,6 +285,26 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
277 | 285 |                 report_memory_flag = False
278 | 286 |             timers.log(timers_to_log, normalizer=args.log_interval)
279 | 287 |
    | 288 | +            if is_master():
    | 289 | +                logs = {}
    | 290 | +                for key in origin_total_loss_dict:
    | 291 | +                    if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]:
    | 292 | +                        avg = origin_total_loss_dict[key].item() / float(
    | 293 | +                            max(1, origin_total_loss_dict[advanced_iters_key]))
    | 294 | +                        logs[key] = round(avg, 8)
    | 295 | +                if grad_norm is not None:
    | 296 | +                    logs['grad_norm'] = round(grad_norm, 8)
    | 297 | +                if params_norm is not None:
    | 298 | +                    logs['params_norm'] = round(params_norm, 8)
    | 299 | +                logs['learning_rate'] = round(learning_rate, 8)
    | 300 | +                logs['elapsed_time_per_iteration'] = round(elapsed_time_per_iteration, 8)
    | 301 | +                if args.log_throughput:
    | 302 | +                    logs['throughput'] = round(throughput, 8)
    | 303 | +                logs['loss_scale'] = round(loss_scale, 8)
    | 304 | +                logs['consumed_samples'] = args.consumed_train_samples
    | 305 | +                logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
    | 306 | +                jsonl_writer.append(logs)
    | 307 | +
280 | 308 |         return report_memory_flag
281 | 309 |
282 | 310 |     training.training_log = training_log
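
For reference, each jsonl_writer.append(logs) call in the patch writes one JSON object per logging interval to logging.jsonl under args.save, on the master rank only. A minimal sketch of reading those records back for inspection; the path below is illustrative, and only the key names visible in the diff are assumed (loss keys mirror whatever total_loss_dict contains):

    import json

    # Illustrative path: args.save as used in the patch, e.g. an output/checkpoint directory.
    logging_path = 'output/logging.jsonl'

    # Each line is one dict appended at a log_interval step by the patched training_log.
    with open(logging_path, 'r', encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]

    for rec in records[-3:]:
        # 'global_step/max_steps' is a string like '100/1000'; loss entries carry the
        # same keys as Megatron's total_loss_dict, rounded to 8 decimal places.
        print(rec.get('global_step/max_steps'), rec.get('learning_rate'), rec.get('loss_scale'))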