Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
1efb851
Add --moe-use-device-initiated-grouped-gemm to allow token_per_expert…
QiZhangNV Nov 3, 2025
64041fc
Initial change for packed offloading
vasunvidia Nov 17, 2025
bbfcef2
Bug fix
Nov 17, 2025
169f9a5
Mem Opt
vasunvidia Nov 17, 2025
ac9dd93
Handle MXFP8Tensor offload
Nov 20, 2025
d6dbc99
Enable Packed offloading to CPU pinned memory with PACKED_OFFLOAD_CPU=1
Nov 20, 2025
35d3c06
Enable activation truncation for first step
Nov 21, 2025
8e5857c
Overflow check and assert
Nov 22, 2025
b2e77eb
Check in temporary solution for detecting overflow in receiving buffer
nanz-nv Nov 22, 2025
2947a02
Reconstruct the stash buffer into a 2D structure
nanz-nv Nov 23, 2025
a45b7fe
Refactor the code to check overflow in HybridEP receiving buffer
nanz-nv Nov 24, 2025
fe504bd
Use CPU offloading context manager as a WAR for now to WAR the proble…
nanz-nv Nov 24, 2025
93fb183
Add support for paged stashing
nanz-nv Nov 25, 2025
ed07de6
Add the feature of speculative CE stashing
nanz-nv Nov 26, 2025
7ab5f9b
Fix PP schedule
Nov 26, 2025
95bc80f
Use common buffer across VP for paged stashing
vasunvidia Nov 26, 2025
8a593b0
Disable Packed Offloading for validation
Nov 27, 2025
db2b1ac
Fix perf issue in packed stash/pop kernels
nanz-nv Nov 27, 2025
22beaa5
Minor fix for tensor allocation and padding requirement on budget
nanz-nv Dec 7, 2025
3fba433
Packed/paged offloading is current not stream-safe. Need to put stash…
nanz-nv Dec 7, 2025
63a7dca
add new hybrid ep
Autumn1998 Dec 9, 2025
f9f2c7b
Remove the overflow check in framework because it is now done by hybr…
nanz-nv Dec 10, 2025
34438d9
Fix one merge conflict
nanz-nv Dec 10, 2025
0d55288
Code cleanup
vasunvidia Dec 11, 2025
da20523
Add second autograd to avoid triple buffering
vasunvidia Dec 12, 2025
a219d7d
Avoid unnecessary wait_stream for reload in case of 1f1b
vasunvidia Dec 12, 2025
0bede5b
Check in dynamic-shape-aware SwiGLU triton kernel
nanz-nv Dec 18, 2025
837503d
Major cleanup and refactor
nanz-nv Dec 18, 2025
922689a
Check in paged_stash.py that was omitted in the previous commit
nanz-nv Dec 18, 2025
25e1f82
Remove d2d page feature for now
nanz-nv Dec 18, 2025
de34d7b
Update added arguments and add compatibility check
nanz-nv Dec 18, 2025
1bf3e43
refine overflow check
nanz-nv Dec 18, 2025
db0b5c9
Fixing lint issues
nanz-nv Dec 19, 2025
3186b20
Minor refactor
vasunvidia Jan 8, 2026
1ab150b
Add unit test for Paged Stashing
vasunvidia Jan 9, 2026
4ecacac
Initial check in of a) force load imbalance b) log overload factors
nanz-nv Jan 12, 2026
e650d60
make overload factor logging work for cuda graph
nanz-nv Jan 19, 2026
a792a43
1. allocate stashing buffer based on avg token count if STASH_BUFFER_…
nanz-nv Jan 22, 2026
57b9714
Reenable overlapping of stashing kernels
nanz-nv Jan 23, 2026
639509d
Remove a buggy/redundant reset
nanz-nv Feb 3, 2026
5a0267f
Cleanup moe-expert-rank-capacity-factor argument.
vasunvidia Feb 9, 2026
b8ee0e7
Update moe_use_device_initiated_grouped_gemm check for paged stashing…
vasunvidia Feb 21, 2026
2c143d4
support use-dynamic-comp-stream
Wohox Mar 12, 2026
0c2da52
Revert "remove encoder_and_decoder from enums (#3406)"
nanz-nv Mar 17, 2026
5e83aa0
Remove the WAR of running warmup on a side stream
nanz-nv Mar 17, 2026
d053991
Reapply "remove encoder_and_decoder from enums (#3406)"
nanz-nv Mar 17, 2026
c62a865
Fix for data_iterator type check in Paged Stashing fallback
vasunvidia Mar 18, 2026
5ee817c
Change to support eager-mode fallback for validation
vasunvidia Mar 18, 2026
485dd7e
Revert "Check in dynamic-shape-aware SwiGLU triton kernel"
nanz-nv Mar 18, 2026
4ed4853
Fixed some minor issues
nanz-nv Mar 18, 2026
a6875f1
Fix the unit test
nanz-nv Mar 18, 2026
acda6d1
Initial commit for spill to cpu feature
nanz-nv Mar 14, 2026
6ddc49b
Move paged stashing knobs from env vars to transformer_config knobs
nanz-nv Mar 18, 2026
7c7ab96
Refactor the knobs a bit so it is more intuitive
nanz-nv Mar 18, 2026
2dc0c53
Use get_attr_wrapped_model util to access moe and mtp layers
vasunvidia Mar 18, 2026
bfb9dd4
Refactor the unit test for paged stashing
nanz-nv Mar 20, 2026
cd89d4f
Clean up after rebase
nanz-nv Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions megatron/core/full_cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import gc
import torch

from megatron.core.tensor_parallel.random import get_all_rng_states
Expand Down Expand Up @@ -98,10 +99,18 @@ class FullCudaGraphWrapper:
cuda_graph = {'training': None, 'validation': None}
result = {'training': None, 'validation': None}

def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1):
def __init__(
    self,
    forward_backward_func,
    cuda_graph_warmup_steps=1,
    moe_paged_stash=False,
    moe_expert_rank_capacity_factor=None,
):
    """Wrap a schedule function for full-iteration CUDA graph capture/replay.

    Args:
        forward_backward_func: schedule function whose execution is captured
            into a CUDA graph and subsequently replayed.
        cuda_graph_warmup_steps (int): number of eager iterations to run
            before attempting graph capture.
        moe_paged_stash (bool): enable MoE paged-stash support.
            # NOTE(review): only stored here; consumed elsewhere — confirm usage.
        moe_expert_rank_capacity_factor: when not None, enables the
            receiving-buffer overflow check in speculative_cuda_graph_check().
    """
    self.forward_backward_func = forward_backward_func
    # Static input buffers so replayed graphs see data at fixed addresses.
    self.static_loader = StaticBufferLoader()
    self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
    self.moe_paged_stash = moe_paged_stash
    self.moe_expert_rank_capacity_factor = moe_expert_rank_capacity_factor

def data_read(self, data_iterator, model, training, num_microbatches):
"""Read all microbatch inputs from Dataloader and copy to static buffers."""
Expand Down Expand Up @@ -180,19 +189,48 @@ def __call__(self, *args, **kwargs):
torch.cuda.synchronize()
torch.distributed.barrier()
logger.info(f'CUDA graph capture done for {training_str}!!!')

if FullCudaGraphWrapper.cuda_graph[training_str] is None:
FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs)
else:
FullCudaGraphWrapper.cuda_graph[training_str].replay()

self.next_iter(training_str)
return FullCudaGraphWrapper.result[training_str]

def speculative_cuda_graph_check(self, model):
    """Check speculatively-executed MoE modules for receiving-buffer overflow.

    Only active when a fixed expert-rank capacity factor is configured;
    raises if any token dispatcher reports that its receiving buffer went
    over budget on this rank.

    Args:
        model: iterable of model chunks (one per virtual pipeline stage).

    Raises:
        Exception: if any dispatcher's check_over_budget() reports overflow.
    """
    if self.moe_expert_rank_capacity_factor is not None:
        # Check if there is any overflow in the receiving buffer
        over_budget = torch.zeros(1, dtype=torch.bool, device='cuda')
        for model_chunk in model:
            # NOTE(review): assumes a fixed two-level wrapper (module.module)
            # and a GPT-style decoder.layers layout — confirm for other models.
            for layer in model_chunk.module.module.decoder.layers:
                mlp = layer.mlp
                # Not every layer is an MoE layer; only dispatchers exposing
                # check_over_budget() participate in the check.
                if hasattr(mlp, 'token_dispatcher') and hasattr(
                    mlp.token_dispatcher, 'check_over_budget'
                ):
                    over_budget |= mlp.token_dispatcher.check_over_budget()
        # .item() forces a device->host sync; this check runs outside capture.
        if over_budget.item():
            raise Exception(f"Rank {torch.distributed.get_rank()} overbudget")

def curr_iter(self, stage):
    """Return the current iteration counter for *stage* ('training'/'validation')."""
    counters = FullCudaGraphWrapper.curr_iteration
    return counters[stage]

def next_iter(self, stage):
    """Advance the iteration counter for *stage* ('training'/'validation') by one."""
    counters = FullCudaGraphWrapper.curr_iteration
    counters[stage] = counters[stage] + 1

def reset_cuda_graph(self, stage=None):
    """Reset captured CUDA graph(s) and per-stage bookkeeping.

    Drops the captured graph, cached result, and iteration counter for the
    selected stage(s), then forces a garbage-collection pass so the graph's
    memory pool can actually be released.

    Args:
        stage: 'training', 'validation', or None to reset both.
    """
    # The two stages previously had duplicated branch bodies; a single loop
    # keeps them in sync. Rebinding the dict entry to None is sufficient to
    # drop the last reference to the graph (an explicit `del` of the key
    # followed by reassignment is redundant).
    for key in ('training', 'validation'):
        if stage is None or stage == key:
            FullCudaGraphWrapper.cuda_graph[key] = None
            FullCudaGraphWrapper.result[key] = None
            FullCudaGraphWrapper.curr_iteration[key] = 0
    # Graph objects hold device memory until collected; reclaim it now.
    gc.collect()
1 change: 1 addition & 0 deletions megatron/core/fusions/fused_bias_swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,4 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False):

# bias_swiglu_impl = BiasSwiGLUFunction.apply
# swiglu_impl = SwiGLUFunction.apply

8 changes: 8 additions & 0 deletions megatron/core/model_parallel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,14 @@ class ModelParallelConfig:
in 1f1b phase of pipelining or non-pipelining schedule.
"""

use_dynamic_comp_stream: bool = False
"""Use dynamic computation stream selection instead of binding to the default stream.
When enabled, get_comp_stream() returns torch.cuda.current_stream() at call time,
allowing CUDA graph capture and replay on non-default streams. This is required for
full-iteration CUDA graph with 1f1b EP overlap where the capture stream differs
from the default stream.
"""

delay_wgrad_compute: bool = False
"""Delay the weight gradient computation to improve batch-level communication overlapping"""

Expand Down
24 changes: 15 additions & 9 deletions megatron/core/models/common/model_chunk_schedule_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_comp_stream,
)
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.paged_stash import paged_stash_set_last_layer


class ModelChunkState:
Expand Down Expand Up @@ -63,8 +64,8 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
event (torch.cuda.Event):
record CUDA event across multiple nodes on different streams for synchronization.
chunk_state (ModelChunkState): model state shared in the model chunk.
comp_stream (torch.cuda.Stream): CUDA stream for computation.
comm_stream (torch.cuda.Stream): CUDA stream for communication.
comp_stream (Callable): Func that returns CUDA stream for computation.
comm_stream (Callable): Func that returns CUDA stream for communication.
extra_args (dict): extra arguments for the layer.

The event and chunk_state are binded to the TransformerModelChunkSchedulePlan
Expand Down Expand Up @@ -317,9 +318,6 @@ def __init__(
self.post_process = None
self.vp_stage = model.vp_stage

comp_stream = get_comp_stream()
comm_stream = get_comm_stream()

# save the inputs of model.forward() to ModelChunkState
self._model_chunk_state.input_ids = input_ids
self._model_chunk_state.position_ids = position_ids
Expand All @@ -338,18 +336,22 @@ def __init__(
self._model_chunk_state.attention_bias = None

# build preprocess
self.pre_process = PreProcessNode(model, self._model_chunk_state, self._event, comp_stream)
self.pre_process = PreProcessNode(
model, self._model_chunk_state, self._event, get_comp_stream
)

# build layer schedule plan for each layer.
# The methods to obtain layers are different for MTP so we need the other build plan for
# MTP. Also, this can help annotate MTP layer so that it can know where MTP is.
self._build_layer_schedule_plan(model.decoder, comp_stream, comm_stream)
self._build_layer_schedule_plan(getattr(model, "mtp", None), comp_stream, comm_stream)
self._build_layer_schedule_plan(model.decoder, get_comp_stream, get_comm_stream)
self._build_layer_schedule_plan(
getattr(model, "mtp", None), get_comp_stream, get_comm_stream
)

# build post process
if model.post_process:
self.post_process = PostProcessNode(
model, self._model_chunk_state, self._event, comp_stream
model, self._model_chunk_state, self._event, get_comp_stream
)

def _build_layer_schedule_plan(self, module, comp_stream, comm_stream):
Expand Down Expand Up @@ -479,6 +481,8 @@ def run(
f_layer = f_schedule_plan.get_layer(i)
b_layer = b_schedule_plan.pop_layer()
torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
if f_layer.layer.config.moe_paged_stash:
paged_stash_set_last_layer(i == f_num_layers - 1)
f_input, b_grad = TransformerLayerSchedulePlan.run(
f_layer,
b_layer,
Expand All @@ -505,6 +509,8 @@ def run(
for i in range(overlapped_layers, f_num_layers):
f_layer = f_schedule_plan.get_layer(i)
torch.cuda.nvtx.range_push(f"layer_{i}f")
if f_layer.layer.config.moe_paged_stash:
paged_stash_set_last_layer(i == f_num_layers - 1)
f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input)
torch.cuda.nvtx.range_pop()

Expand Down
4 changes: 3 additions & 1 deletion megatron/core/models/gpt/fine_grained_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import weakref
from contextlib import nullcontext
from functools import partial
from typing import Optional
from typing import Callable, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -330,6 +330,8 @@ def backward_dw(self):
"""Computes the weight gradients for the transformer layer node."""
if not self.delay_wgrad_compute:
return
if isinstance(self.stream, Callable):
self.stream = self.stream()
with torch.cuda.stream(self.stream):
torch.cuda.nvtx.range_push(f"{self.name} wgrad")
for module in self.bwd_dw_callables:
Expand Down
12 changes: 12 additions & 0 deletions megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler
from megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
mtp_on_this_rank,
Expand Down Expand Up @@ -473,6 +474,12 @@ def preprocess_for_fine_grained_offloading(self):
off_interface.mark_not_offloadable(param)
self.disable_param_offloading = False

def preprocess_for_paged_stash(self):
    """Preprocess for paged stash."""
    cfg = self.config
    return paged_stash_init_chunk_handler(
        vp_size=cfg.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage
    )

def forward(
self,
input_ids: Tensor,
Expand Down Expand Up @@ -505,6 +512,9 @@ def forward(
if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()

if self.config.moe_paged_stash:
self.preprocess_for_paged_stash()

inference_context = deprecate_inference_params(inference_context, inference_params)

preproc_output = self._preprocess(
Expand Down Expand Up @@ -745,6 +755,8 @@ def build_schedule_plan(

if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()
if self.config.moe_paged_stash:
self.preprocess_for_paged_stash()

from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan

Expand Down
13 changes: 9 additions & 4 deletions megatron/core/pipeline_parallel/combined_1f1b.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@

from megatron.core.enums import Fp8Recipe
from megatron.core.fp8_utils import get_fp8_context
from megatron.core.pipeline_parallel.utils import AbstractSchedulePlan, ScheduleNode, set_streams
from megatron.core.pipeline_parallel.utils import (
AbstractSchedulePlan,
ScheduleNode,
get_comp_stream,
set_streams,
)
from megatron.core.utils import get_attr_wrapped_model

# Types
Expand Down Expand Up @@ -47,7 +52,7 @@ def combined_1f1b_schedule_for_no_pipelining(
Phases 4: 4th microbatch backward
"""

set_streams()
set_streams(use_dynamic_comp_stream=config.use_dynamic_comp_stream)
# The forward step for the first microbatch is executed alone, no a2a overlapping
output_tensor, num_tokens, _ = combined_forward_backward_step(
forward_step_func,
Expand Down Expand Up @@ -173,7 +178,7 @@ def combined_1f1b_schedule_for_interleaved_pipelining():
# backward_step_helper_postprocess()
"""

set_streams()
set_streams(use_dynamic_comp_stream=config.use_dynamic_comp_stream)
# forward prepare
f_model_chunk_id = None
f_microbatch_id = None
Expand Down Expand Up @@ -405,7 +410,7 @@ def forward_backward_step():
from megatron.core.pipeline_parallel.schedules import forward_step_calc_loss

loss_node = ScheduleNode(
loss_func, torch.cuda.current_stream(), f_schedule_plan.event, name="loss_func"
loss_func, get_comp_stream, f_schedule_plan.event, name="loss_func"
)
loss_func = loss_node.forward
output_tensor, num_tokens = forward_step_calc_loss(
Expand Down
7 changes: 7 additions & 0 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.paged_stash import paged_stash_reset
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
Expand Down Expand Up @@ -590,6 +591,8 @@ def forward_backward_no_pipelining(
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
Expand Down Expand Up @@ -1049,6 +1052,8 @@ def forward_backward_pipelining_with_interleaving(
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"

paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

Expand Down Expand Up @@ -2232,6 +2237,8 @@ def forward_backward_pipelining_without_interleaving(
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None:
Expand Down
Loading