@@ -102,6 +102,7 @@ def set_startup_timestamps(program_start=None, main_entry=None):
102102from megatron .training .checkpointing import load_checkpoint
103103from megatron .training .checkpointing import save_checkpoint
104104from megatron .training .checkpointing import checkpoint_exists
105+ from megatron .training .checkpointing import get_loaded_iteration
105106from megatron .core .full_cuda_graph import FullCudaGraphWrapper
106107from megatron .core .transformer .cuda_graphs import TECudaGraphHelper
107108from megatron .core .transformer .enums import CudaGraphScope
@@ -1899,6 +1900,8 @@ def training_log(
18991900 writer .add_scalar ('max_attention_logit' , max_attention_logit , iteration )
19001901 if wandb_writer :
19011902 wandb_writer .log ({'max_attention_logit' : max_attention_logit }, iteration )
1903+
1904+ # Log MoE metrics.
19021905 if args .num_experts is not None :
19031906 moe_loss_scale = 1 / get_num_microbatches ()
19041907 track_names = []
@@ -1930,12 +1933,15 @@ def training_log(
19301933 mtp_num_layers = args .mtp_num_layers ,
19311934 pg_collection = pg_collection ,
19321935 )
1936+
1937+ # Log MTP metrics.
19331938 if args .mtp_num_layers is not None :
19341939 mtp_loss_scale = 1 / get_num_microbatches ()
19351940 MTPLossLoggingHelper .track_mtp_metrics (
19361941 mtp_loss_scale , iteration , writer , wandb_writer , total_loss_dict
19371942 )
1938- # Track sparse attention indexer loss
1943+
1944+ # Track sparse attention indexer loss.
19391945 if args .dsa_indexer_loss_coeff is not None and args .dsa_indexer_loss_coeff > 0 :
19401946 indexer_loss_scale = 1 / get_num_microbatches ()
19411947 DSAIndexerLossLoggingHelper .track_indexer_metrics (
@@ -1945,6 +1951,8 @@ def training_log(
19451951 wandb_writer = wandb_writer ,
19461952 total_loss_dict = total_loss_dict ,
19471953 )
1954+
1955+ # Dump memory snapshot and print metrics to stdout.
19481956 if iteration % args .log_interval == 0 or is_first_iteration :
19491957 if args .record_memory_history and (is_last_rank () or torch .distributed .get_backend () == 'fake' ):
19501958 snapshot = torch .cuda .memory ._snapshot ()
@@ -2026,16 +2034,22 @@ def training_log(
20262034 total_loss_dict [skipped_iters_key ] = 0
20272035 total_loss_dict [nan_iters_key ] = 0
20282036 print_rank_last (log_string )
2037+ reported_memory_in_this_iteration = False
20292038 if report_memory_flag :
20302039 # Report memory after optimizer state has been initialized.
20312040 if torch .distributed .get_rank () == 0 :
20322041 num_microbatches = get_num_microbatches ()
20332042 report_theoretical_memory (args , num_microbatches = num_microbatches , verbose = True )
20342043 report_memory (f'(after { iteration } iterations)' )
2035- if iteration > 1 :
2044+ reported_memory_in_this_iteration = True
2045+ loaded_iteration = max (get_loaded_iteration () or 0 , 0 )
2046+ if iteration > (loaded_iteration + 1 ):
20362047 # Make sure the memory after the second iteration is reported to include optimizer state memory.
20372048 report_memory_flag = False
2038- # Write timers to wandb, don't reset the counts
2049+ if args .log_memory_interval is not None and iteration % args .log_memory_interval == 0 and \
2050+ not reported_memory_in_this_iteration :
2051+ report_memory (f'(after { iteration } iterations)' )
2052+ # Write timers to wandb, don't reset the counts.
20392053 if args .log_timers_to_tensorboard :
20402054 timers .write (timers_to_log , writer , iteration , normalizer = args .log_interval , reset = False )
20412055 timers .write (timers_to_log , wandb_writer , iteration , normalizer = args .log_interval , reset = False )
@@ -2095,6 +2109,9 @@ def force_param_sync(model_chunks: list[DDP]) -> None:
20952109 assert isinstance (model_chunk , DDP )
20962110 model_chunk .start_param_sync (force_sync = True )
20972111
2112+ # Only report memory for first 3 checkpoint saves.
2113+ num_checkpoints_memory_reported = 0
2114+ MAX_NUM_CHECKPOINTS_MEMORY_REPORTED = 3
20982115
20992116def save_checkpoint_and_time (
21002117 iteration ,
@@ -2122,6 +2139,14 @@ def save_checkpoint_and_time(
21222139 one_logger_utils .track_e2e_metrics ()
21232140 if should_disable_forward_pre_hook (args ):
21242141 force_param_sync (model )
2142+
2143+ global num_checkpoints_memory_reported , MAX_NUM_CHECKPOINTS_MEMORY_REPORTED
2144+ should_report_memory = num_checkpoints_memory_reported < MAX_NUM_CHECKPOINTS_MEMORY_REPORTED
2145+
2146+ if should_report_memory :
2147+ # Track memory before checkpoint save.
2148+ report_memory (f"(before save_checkpoint for iteration { iteration } )" )
2149+ # Save checkpoint.
21252150 save_checkpoint (
21262151 iteration ,
21272152 model ,
@@ -2133,6 +2158,11 @@ def save_checkpoint_and_time(
21332158 train_data_iterator = train_data_iterator ,
21342159 preprocess_common_state_dict_fn = preprocess_common_state_dict ,
21352160 )
2161+ if should_report_memory :
2162+ # Track memory after checkpoint save.
2163+ report_memory (f"(after save_checkpoint for iteration { iteration } )" )
2164+ num_checkpoints_memory_reported += 1
2165+
21362166 if args .fp8 :
21372167 # Run garbage collection after checkpoint saving to free memory from
21382168 # dequantized bf16 tensors that were temporarily created during fp8