Skip to content

Commit d186184

Browse files
committed
Merge branch 'main' into helenn-dev-rl-training-graphs-test
2 parents b4f23e2 + e4b18f7 commit d186184

File tree

8 files changed

+172
-46
lines changed

8 files changed

+172
-46
lines changed

.github/actions/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ runs:
126126
IS_CI_WORKLOAD: ${{ inputs.is_ci_workload }}
127127
run: |
128128
PR_NUMBER=${{ fromJSON(steps.get-pr-info.outputs.pr-info || '{}').number }}
129-
HAS_RUN_FUNCTIONAL_TESTS_LABEL=$(gh pr view $PR_NUMBER --json labels | jq '[.labels[].name] | any(. == "Run functional tests")')
129+
HAS_RUN_FUNCTIONAL_TESTS_LABEL=$(gh pr view $PR_NUMBER --json labels | jq '[.labels[].name] | any(. == "Run functional tests")') || echo "$IS_CI_WORKLOAD"
130130
HAS_RUN_FUNCTIONAL_TESTS_LABEL=${HAS_RUN_FUNCTIONAL_TESTS_LABEL:-$IS_CI_WORKLOAD}
131131
echo "main=$HAS_RUN_FUNCTIONAL_TESTS_LABEL" | tee -a $GITHUB_OUTPUT
132132

.github/oncall_schedule.json

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
[
2-
{
3-
"user": "Phlip79",
4-
"date": "2026-01-07"
5-
},
62
{
73
"user": "BoxiangW",
84
"date": "2026-01-14"
@@ -26,5 +22,29 @@
2622
{
2723
"user": "janEbert",
2824
"date": "2026-02-18"
25+
},
26+
{
27+
"user": "maanug-nv",
28+
"date": "2026-02-25"
29+
},
30+
{
31+
"user": "BoxiangW",
32+
"date": "2026-03-04"
33+
},
34+
{
35+
"user": "Phlip79",
36+
"date": "2026-03-11"
37+
},
38+
{
39+
"user": "asolergi-nv",
40+
"date": "2026-03-18"
41+
},
42+
{
43+
"user": "dimapihtar",
44+
"date": "2026-03-25"
45+
},
46+
{
47+
"user": "gautham-kollu",
48+
"date": "2026-04-01"
2949
}
3050
]

.github/workflows/cicd-main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ jobs:
406406
is_unit_test: "true"
407407
PAT: ${{ secrets.PAT }}
408408
container-image: ${{ env.container-registry }}/megatron-lm:${{ github.sha }}
409+
is_ci_workload: ${{ needs.pre-flight.outputs.is_ci_workload }}
409410

410411
cicd-parse-integration-tests:
411412
runs-on: ubuntu-latest

megatron/core/transformer/attention.py

Lines changed: 79 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
import copy
3+
import inspect
34
from abc import ABC, abstractmethod
45
from dataclasses import dataclass
56
from typing import NoReturn, Optional, Tuple, Union
@@ -578,6 +579,74 @@ def flash_decode(
578579
)
579580
return out
580581

582+
def _flash_attention_3_forward_wrapper(
583+
self,
584+
q: Tensor,
585+
k: Tensor,
586+
v: Tensor,
587+
max_seqlen_q,
588+
max_seqlen_k,
589+
cu_seqlens_q,
590+
seqlens_k,
591+
block_table,
592+
softmax_scale,
593+
):
594+
"""
595+
Wrapper for calling the FA3 _flash_attn_forward function.
596+
Handles argument conversion for different versions of the _flash_attn_forward API.
597+
"""
598+
candidate_kwargs = {
599+
"q": q,
600+
"k": k,
601+
"v": v,
602+
"k_new": None,
603+
"v_new": None,
604+
"qv": None,
605+
"out": None,
606+
"out_": None,
607+
"cu_seqlens_q": cu_seqlens_q,
608+
"cu_seqlens_k": None,
609+
"cu_seqlens_k_new": None,
610+
"seqused_q": None,
611+
"seqused_k": seqlens_k,
612+
"max_seqlen_q": max_seqlen_q,
613+
"max_seqlen_k": max_seqlen_k,
614+
"page_table": block_table,
615+
"kv_batch_idx": None,
616+
"leftpad_k": None,
617+
"rotary_cos": None,
618+
"rotary_sin": None,
619+
"seqlens_rotary": None,
620+
"q_descale": None,
621+
"k_descale": None,
622+
"v_descale": None,
623+
"softmax_scale": softmax_scale,
624+
"causal": True,
625+
"attention_chunk": 0,
626+
"softcap": 0.0,
627+
"window_size": (-1, -1),
628+
"window_size_left": -1,
629+
"window_size_right": -1,
630+
"rotary_interleaved": True,
631+
"scheduler_metadata": None,
632+
"num_splits": 0 if not self.batch_invariant_mode else 1,
633+
"pack_gqa": None,
634+
"sm_margin": 0,
635+
}
636+
637+
# Parse the expected argument names from the function signature
638+
if inspect.isfunction(_flash_attn_forward):
639+
sig = inspect.signature(_flash_attn_forward)
640+
else:
641+
assert isinstance(_flash_attn_forward, torch._library.custom_ops.CustomOpDef)
642+
sig = inspect.signature(_flash_attn_forward._init_fn)
643+
valid_kwargs = set(sig.parameters.keys())
644+
final_kwargs = {k: candidate_kwargs[k] for k in valid_kwargs if k in candidate_kwargs}
645+
646+
output_total, *unused = _flash_attn_forward(**final_kwargs)
647+
648+
return output_total
649+
581650
def flash_decode_and_prefill(
582651
self,
583652
q: Tensor,
@@ -619,40 +688,16 @@ def flash_decode_and_prefill(
619688
if HAVE_FA3:
620689
# TODO(ksanthanam): Replace with call to flash_attn_varlen_func once
621690
# it accepts block_table
622-
output_total, *unused = _flash_attn_forward(
623-
q=q,
624-
k=k,
625-
v=v,
626-
k_new=None,
627-
v_new=None,
628-
qv=None,
629-
out=None,
630-
cu_seqlens_q=cu_seqlens_q,
631-
cu_seqlens_k=None,
632-
cu_seqlens_k_new=None,
633-
seqused_q=None,
634-
seqused_k=seqlens_k,
635-
max_seqlen_q=max_seqlen_q,
636-
max_seqlen_k=max_seqlen_k,
637-
page_table=block_table,
638-
kv_batch_idx=None,
639-
leftpad_k=None,
640-
rotary_cos=None,
641-
rotary_sin=None,
642-
seqlens_rotary=None,
643-
q_descale=None,
644-
k_descale=None,
645-
v_descale=None,
646-
softmax_scale=softmax_scale,
647-
causal=True,
648-
window_size=(-1, -1),
649-
attention_chunk=0,
650-
softcap=0.0,
651-
rotary_interleaved=True,
652-
scheduler_metadata=None,
653-
num_splits=0 if not self.batch_invariant_mode else 1,
654-
pack_gqa=None,
655-
sm_margin=0,
691+
output_total = self._flash_attention_3_forward_wrapper(
692+
q,
693+
k,
694+
v,
695+
max_seqlen_q,
696+
max_seqlen_k,
697+
cu_seqlens_q,
698+
seqlens_k,
699+
block_table,
700+
softmax_scale,
656701
)
657702
else:
658703
assert (

megatron/training/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,8 @@ def validate_args(args, defaults={}):
914914
if args.save_retain_interval is not None:
915915
assert args.save_retain_interval > 0
916916
assert args.save_retain_interval % args.save_interval == 0
917+
if args.log_memory_interval is not None:
918+
assert args.log_memory_interval % args.log_interval == 0
917919
# Mixed precision checks.
918920
if args.fp16_lm_cross_entropy:
919921
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
@@ -2369,6 +2371,10 @@ def _add_training_args(parser):
23692371
'with larger models, sequences, and batch sizes.')
23702372
group.add_argument('--log-interval', type=int, default=100,
23712373
help='Report loss and timing interval.')
2374+
group.add_argument('--log-memory-interval', type=int, default=None,
2375+
help='Report memory interval.')
2376+
group.add_argument('--log-device-memory-used', action='store_true',
2377+
help='Log device memory used (as reported by nvidia-smi).')
23722378
group.add_argument('--tensorboard-dir', type=str, default=None,
23732379
help='Write TensorBoard logs to this directory.')
23742380
group.add_argument('--no-masked-softmax-fusion',

megatron/training/checkpointing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
has_nvidia_modelopt = False
6565

6666
_CHECKPOINT_VERSION = None
67+
_LOADED_ITERATION = None
6768

6869
logger = getLogger(__name__)
6970
_NON_PERSISTENT_CKPT_SUBDIR = 'non_persistent'
@@ -81,6 +82,22 @@ def get_checkpoint_version():
8182
return _CHECKPOINT_VERSION
8283

8384

85+
def set_loaded_iteration(value):
86+
"""Set the iteration that was loaded from checkpoint.
87+
88+
This is stored separately from args to avoid polluting the checkpoint
89+
with runtime state (args is saved in checkpoints).
90+
"""
91+
global _LOADED_ITERATION
92+
_LOADED_ITERATION = value
93+
94+
95+
def get_loaded_iteration():
96+
"""Get the iteration that was loaded from checkpoint, or None if no checkpoint was loaded."""
97+
global _LOADED_ITERATION
98+
return _LOADED_ITERATION
99+
100+
84101
def check_checkpoint_args(checkpoint_args):
85102
"""Ensure fixed arguments for a model are the same for the input
86103
arguments and the one retrieved from checkpoint."""
@@ -1132,6 +1149,10 @@ def _load_base_checkpoint(
11321149
if getattr(args, "ckpt_step", None):
11331150
iteration = args.ckpt_step
11341151

1152+
# Record the iteration loaded (stored separately from args to avoid
1153+
# polluting checkpoints, since args is saved in checkpoints).
1154+
set_loaded_iteration(iteration)
1155+
11351156
if non_persistent_iteration != -1: # there is a non-persistent checkpoint
11361157
if non_persistent_iteration >= iteration:
11371158
return _load_non_persistent_base_checkpoint(

megatron/training/training.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def set_startup_timestamps(program_start=None, main_entry=None):
102102
from megatron.training.checkpointing import load_checkpoint
103103
from megatron.training.checkpointing import save_checkpoint
104104
from megatron.training.checkpointing import checkpoint_exists
105+
from megatron.training.checkpointing import get_loaded_iteration
105106
from megatron.core.full_cuda_graph import FullCudaGraphWrapper
106107
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper
107108
from megatron.core.transformer.enums import CudaGraphScope
@@ -1899,6 +1900,8 @@ def training_log(
18991900
writer.add_scalar('max_attention_logit', max_attention_logit, iteration)
19001901
if wandb_writer:
19011902
wandb_writer.log({'max_attention_logit': max_attention_logit}, iteration)
1903+
1904+
# Log MoE metrics.
19021905
if args.num_experts is not None:
19031906
moe_loss_scale = 1 / get_num_microbatches()
19041907
track_names = []
@@ -1930,12 +1933,15 @@ def training_log(
19301933
mtp_num_layers=args.mtp_num_layers,
19311934
pg_collection=pg_collection,
19321935
)
1936+
1937+
# Log MTP metrics.
19331938
if args.mtp_num_layers is not None:
19341939
mtp_loss_scale = 1 / get_num_microbatches()
19351940
MTPLossLoggingHelper.track_mtp_metrics(
19361941
mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
19371942
)
1938-
# Track sparse attention indexer loss
1943+
1944+
# Track sparse attention indexer loss.
19391945
if args.dsa_indexer_loss_coeff is not None and args.dsa_indexer_loss_coeff > 0:
19401946
indexer_loss_scale = 1 / get_num_microbatches()
19411947
DSAIndexerLossLoggingHelper.track_indexer_metrics(
@@ -1945,6 +1951,8 @@ def training_log(
19451951
wandb_writer=wandb_writer,
19461952
total_loss_dict=total_loss_dict,
19471953
)
1954+
1955+
# Dump memory snapshot and print metrics to stdout.
19481956
if iteration % args.log_interval == 0 or is_first_iteration:
19491957
if args.record_memory_history and (is_last_rank() or torch.distributed.get_backend() == 'fake'):
19501958
snapshot = torch.cuda.memory._snapshot()
@@ -2026,16 +2034,22 @@ def training_log(
20262034
total_loss_dict[skipped_iters_key] = 0
20272035
total_loss_dict[nan_iters_key] = 0
20282036
print_rank_last(log_string)
2037+
reported_memory_in_this_iteration = False
20292038
if report_memory_flag:
20302039
# Report memory after optimizer state has been initialized.
20312040
if torch.distributed.get_rank() == 0:
20322041
num_microbatches = get_num_microbatches()
20332042
report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
20342043
report_memory(f'(after {iteration} iterations)')
2035-
if iteration > 1:
2044+
reported_memory_in_this_iteration = True
2045+
loaded_iteration = max(get_loaded_iteration() or 0, 0)
2046+
if iteration > (loaded_iteration + 1):
20362047
# Make sure the memory after the second iteration is reported to include optimizer state memory.
20372048
report_memory_flag = False
2038-
# Write timers to wandb, don't reset the counts
2049+
if args.log_memory_interval is not None and iteration % args.log_memory_interval == 0 and \
2050+
not reported_memory_in_this_iteration:
2051+
report_memory(f'(after {iteration} iterations)')
2052+
# Write timers to wandb, don't reset the counts.
20392053
if args.log_timers_to_tensorboard:
20402054
timers.write(timers_to_log, writer, iteration, normalizer=args.log_interval, reset=False)
20412055
timers.write(timers_to_log, wandb_writer, iteration, normalizer=args.log_interval, reset=False)
@@ -2095,6 +2109,9 @@ def force_param_sync(model_chunks: list[DDP]) -> None:
20952109
assert isinstance(model_chunk, DDP)
20962110
model_chunk.start_param_sync(force_sync=True)
20972111

2112+
# Only report memory for first 3 checkpoint saves.
2113+
num_checkpoints_memory_reported = 0
2114+
MAX_NUM_CHECKPOINTS_MEMORY_REPORTED = 3
20982115

20992116
def save_checkpoint_and_time(
21002117
iteration,
@@ -2122,6 +2139,14 @@ def save_checkpoint_and_time(
21222139
one_logger_utils.track_e2e_metrics()
21232140
if should_disable_forward_pre_hook(args):
21242141
force_param_sync(model)
2142+
2143+
global num_checkpoints_memory_reported, MAX_NUM_CHECKPOINTS_MEMORY_REPORTED
2144+
should_report_memory = num_checkpoints_memory_reported < MAX_NUM_CHECKPOINTS_MEMORY_REPORTED
2145+
2146+
if should_report_memory:
2147+
# Track memory before checkpoint save.
2148+
report_memory(f"(before save_checkpoint for iteration {iteration})")
2149+
# Save checkpoint.
21252150
save_checkpoint(
21262151
iteration,
21272152
model,
@@ -2133,6 +2158,11 @@ def save_checkpoint_and_time(
21332158
train_data_iterator=train_data_iterator,
21342159
preprocess_common_state_dict_fn=preprocess_common_state_dict,
21352160
)
2161+
if should_report_memory:
2162+
# Track memory after checkpoint save.
2163+
report_memory(f"(after save_checkpoint for iteration {iteration})")
2164+
num_checkpoints_memory_reported += 1
2165+
21362166
if args.fp8:
21372167
# Run garbage collection after checkpoint saving to free memory from
21382168
# dequantized bf16 tensors that were temporarily created during fp8

megatron/training/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,15 @@ def logical_and_across_model_parallel_group(input: bool) -> bool:
276276

277277
def report_memory(name):
278278
"""Simple GPU memory report."""
279+
args = get_args()
279280
mega_bytes = 1024.0 * 1024.0
280281
string = name + ' memory (MB)'
281-
string += ' | allocated: {}'.format(torch.cuda.memory_allocated() / mega_bytes)
282-
string += ' | max allocated: {}'.format(torch.cuda.max_memory_allocated() / mega_bytes)
283-
string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
284-
string += ' | max reserved: {}'.format(torch.cuda.max_memory_reserved() / mega_bytes)
282+
string += f" | allocated: {torch.cuda.memory_allocated() / mega_bytes:.2f}"
283+
string += f" | max allocated: {torch.cuda.max_memory_allocated() / mega_bytes:.2f}"
284+
string += f" | reserved: {torch.cuda.memory_reserved() / mega_bytes:.2f}"
285+
string += f" | max reserved: {torch.cuda.max_memory_reserved() / mega_bytes:.2f}"
286+
if args.log_device_memory_used:
287+
string += f" | total device memory used: {torch.cuda.device_memory_used() / mega_bytes:.2f}"
285288
if mpu.get_data_parallel_rank() == 0:
286289
print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
287290

0 commit comments

Comments
 (0)