diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 7c11195f33b..faff0c4ab59 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -4,6 +4,7 @@ import logging +import gc import torch from megatron.core.tensor_parallel.random import get_all_rng_states @@ -98,10 +99,18 @@ class FullCudaGraphWrapper: cuda_graph = {'training': None, 'validation': None} result = {'training': None, 'validation': None} - def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1): + def __init__( + self, + forward_backward_func, + cuda_graph_warmup_steps=1, + moe_paged_stash=False, + moe_expert_rank_capacity_factor=None, + ): self.forward_backward_func = forward_backward_func self.static_loader = StaticBufferLoader() self.cuda_graph_warmup_steps = cuda_graph_warmup_steps + self.moe_paged_stash = moe_paged_stash + self.moe_expert_rank_capacity_factor = moe_expert_rank_capacity_factor def data_read(self, data_iterator, model, training, num_microbatches): """Read all microbatch inputs from Dataloader and copy to static buffers.""" @@ -180,15 +189,28 @@ def __call__(self, *args, **kwargs): torch.cuda.synchronize() torch.distributed.barrier() logger.info(f'CUDA graph capture done for {training_str}!!!') - if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: FullCudaGraphWrapper.cuda_graph[training_str].replay() - self.next_iter(training_str) return FullCudaGraphWrapper.result[training_str] + def speculative_cuda_graph_check(self, model): + '''check speculative execution modules''' + if self.moe_expert_rank_capacity_factor is not None: + # Check if there is any overflow in the receiving buffer + over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') + for model_chunk in model: + for layer in model_chunk.module.module.decoder.layers: + mlp = layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr( + 
mlp.token_dispatcher, 'check_over_budget' + ): + over_budget |= mlp.token_dispatcher.check_over_budget() + if over_budget.item(): + raise Exception(f"Rank {torch.distributed.get_rank()} overbudget") + def curr_iter(self, stage): """Return current training/validation iteration.""" return FullCudaGraphWrapper.curr_iteration[stage] @@ -196,3 +218,19 @@ def curr_iter(self, stage): def next_iter(self, stage): """Increment current training/validation iteration.""" FullCudaGraphWrapper.curr_iteration[stage] += 1 + + def reset_cuda_graph(self, stage=None): + """Reset CUDA graph.""" + if stage is None or stage == 'training': + if FullCudaGraphWrapper.cuda_graph['training'] is not None: + del FullCudaGraphWrapper.cuda_graph['training'] + FullCudaGraphWrapper.cuda_graph['training'] = None + FullCudaGraphWrapper.result['training'] = None + FullCudaGraphWrapper.curr_iteration['training'] = 0 + if stage is None or stage == 'validation': + if FullCudaGraphWrapper.cuda_graph['validation'] is not None: + del FullCudaGraphWrapper.cuda_graph['validation'] + FullCudaGraphWrapper.cuda_graph['validation'] = None + FullCudaGraphWrapper.result['validation'] = None + FullCudaGraphWrapper.curr_iteration['validation'] = 0 + gc.collect() \ No newline at end of file diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py index 632470876c9..1161c832d79 100644 --- a/megatron/core/fusions/fused_bias_swiglu.py +++ b/megatron/core/fusions/fused_bias_swiglu.py @@ -253,3 +253,4 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False): # bias_swiglu_impl = BiasSwiGLUFunction.apply # swiglu_impl = SwiGLUFunction.apply + diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 075aa75c76a..d8e66c0e52a 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -275,6 +275,14 @@ class ModelParallelConfig: in 1f1b phase of pipelining or non-pipelining 
schedule. """ + use_dynamic_comp_stream: bool = False + """Use dynamic computation stream selection instead of binding to the default stream. + When enabled, get_comp_stream() returns torch.cuda.current_stream() at call time, + allowing CUDA graph capture and replay on non-default streams. This is required for + full-iteration CUDA graph with 1f1b EP overlap where the capture stream differs + from the default stream. + """ + delay_wgrad_compute: bool = False """Delay the weight gradient computation to improve batch-level communication overlapping""" diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 2e26e5fd1d3..506f290e7d8 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -15,6 +15,7 @@ get_comp_stream, ) from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.moe.paged_stash import paged_stash_set_last_layer class ModelChunkState: @@ -63,8 +64,8 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar event (torch.cuda.Event): record CUDA event across multiple nodes on different streams for synchronization. chunk_state (ModelChunkState): model state shared in the model chunk. - comp_stream (torch.cuda.Stream): CUDA stream for computation. - comm_stream (torch.cuda.Stream): CUDA stream for communication. + comp_stream (Callable): Func that returns CUDA stream for computation. + comm_stream (Callable): Func that returns CUDA stream for communication. extra_args (dict): extra arguments for the layer. 
The event and chunk_state are binded to the TransformerModelChunkSchedulePlan @@ -317,9 +318,6 @@ def __init__( self.post_process = None self.vp_stage = model.vp_stage - comp_stream = get_comp_stream() - comm_stream = get_comm_stream() - # save the inputs of model.forward() to ModelChunkState self._model_chunk_state.input_ids = input_ids self._model_chunk_state.position_ids = position_ids @@ -338,18 +336,22 @@ def __init__( self._model_chunk_state.attention_bias = None # build preprocess - self.pre_process = PreProcessNode(model, self._model_chunk_state, self._event, comp_stream) + self.pre_process = PreProcessNode( + model, self._model_chunk_state, self._event, get_comp_stream + ) # build layer schedule plan for each layer. # The methods to obtain layers are different for MTP so we need the other build plan for # MTP. Also, this can help annotate MTP layer so that it can know where MTP is. - self._build_layer_schedule_plan(model.decoder, comp_stream, comm_stream) - self._build_layer_schedule_plan(getattr(model, "mtp", None), comp_stream, comm_stream) + self._build_layer_schedule_plan(model.decoder, get_comp_stream, get_comm_stream) + self._build_layer_schedule_plan( + getattr(model, "mtp", None), get_comp_stream, get_comm_stream + ) # build post process if model.post_process: self.post_process = PostProcessNode( - model, self._model_chunk_state, self._event, comp_stream + model, self._model_chunk_state, self._event, get_comp_stream ) def _build_layer_schedule_plan(self, module, comp_stream, comm_stream): @@ -479,6 +481,8 @@ def run( f_layer = f_schedule_plan.get_layer(i) b_layer = b_schedule_plan.pop_layer() torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b") + if f_layer.layer.config.moe_paged_stash: + paged_stash_set_last_layer(i == f_num_layers - 1) f_input, b_grad = TransformerLayerSchedulePlan.run( f_layer, b_layer, @@ -505,6 +509,8 @@ def run( for i in range(overlapped_layers, f_num_layers): f_layer = f_schedule_plan.get_layer(i) 
torch.cuda.nvtx.range_push(f"layer_{i}f") + if f_layer.layer.config.moe_paged_stash: + paged_stash_set_last_layer(i == f_num_layers - 1) f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input) torch.cuda.nvtx.range_pop() diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 6658b6363ea..8d1036b5bae 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -3,7 +3,7 @@ import weakref from contextlib import nullcontext from functools import partial -from typing import Optional +from typing import Callable, Optional import torch from torch import Tensor @@ -330,6 +330,8 @@ def backward_dw(self): """Computes the weight gradients for the transformer layer node.""" if not self.delay_wgrad_compute: return + if isinstance(self.stream, Callable): + self.stream = self.stream() with torch.cuda.stream(self.stream): torch.cuda.nvtx.range_push(f"{self.name} wgrad") for module in self.bwd_dw_callables: diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 27b62f91c34..df732ef8d94 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -26,6 +26,7 @@ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer.enums import CudaGraphScope, ModelType from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule +from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, mtp_on_this_rank, @@ -473,6 +474,12 @@ def preprocess_for_fine_grained_offloading(self): off_interface.mark_not_offloadable(param) self.disable_param_offloading = False + def preprocess_for_paged_stash(self): + """Preprocess for paged stash.""" + return paged_stash_init_chunk_handler( + 
vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage + ) + def forward( self, input_ids: Tensor, @@ -505,6 +512,9 @@ def forward( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() + if self.config.moe_paged_stash: + self.preprocess_for_paged_stash() + inference_context = deprecate_inference_params(inference_context, inference_params) preproc_output = self._preprocess( @@ -745,6 +755,8 @@ def build_schedule_plan( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() + if self.config.moe_paged_stash: + self.preprocess_for_paged_stash() from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan diff --git a/megatron/core/pipeline_parallel/combined_1f1b.py b/megatron/core/pipeline_parallel/combined_1f1b.py index 232d9c8cd70..892832059d7 100644 --- a/megatron/core/pipeline_parallel/combined_1f1b.py +++ b/megatron/core/pipeline_parallel/combined_1f1b.py @@ -8,7 +8,12 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp8_utils import get_fp8_context -from megatron.core.pipeline_parallel.utils import AbstractSchedulePlan, ScheduleNode, set_streams +from megatron.core.pipeline_parallel.utils import ( + AbstractSchedulePlan, + ScheduleNode, + get_comp_stream, + set_streams, +) from megatron.core.utils import get_attr_wrapped_model # Types @@ -47,7 +52,7 @@ def combined_1f1b_schedule_for_no_pipelining( Phases 4: 4th microbatch backward """ - set_streams() + set_streams(use_dynamic_comp_stream=config.use_dynamic_comp_stream) # The forward step for the first microbatch is executed alone, no a2a overlapping output_tensor, num_tokens, _ = combined_forward_backward_step( forward_step_func, @@ -173,7 +178,7 @@ def combined_1f1b_schedule_for_interleaved_pipelining(): # backward_step_helper_postprocess() """ - set_streams() + set_streams(use_dynamic_comp_stream=config.use_dynamic_comp_stream) # forward prepare f_model_chunk_id = 
None f_microbatch_id = None @@ -405,7 +410,7 @@ def forward_backward_step(): from megatron.core.pipeline_parallel.schedules import forward_step_calc_loss loss_node = ScheduleNode( - loss_func, torch.cuda.current_stream(), f_schedule_plan.event, name="loss_func" + loss_func, get_comp_stream, f_schedule_plan.event, name="loss_func" ) loss_func = loss_node.forward output_tensor, num_tokens = forward_step_calc_loss( diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index ed3794208f0..03dbcf1f79c 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -22,6 +22,7 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import create_cudagraphs from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.moe.paged_stash import paged_stash_reset from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler from megatron.core.utils import ( drain_embedding_wgrad_compute, @@ -590,6 +591,8 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) + no_sync_func = config.no_sync_func if no_sync_func is None: no_sync_func = contextlib.nullcontext @@ -1049,6 +1052,8 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) + if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -2232,6 +2237,8 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', 
log_level=1).start(barrier=config.barrier_with_L1_time) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) + # Disable async grad reductions no_sync_func = config.no_sync_func if no_sync_func is None: diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index 8f6b25eec32..08da68971ec 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -154,7 +154,7 @@ def __init__( Args: forward_func (callable): Function to execute during the forward pass. - stream (torch.cuda.Stream): The CUDA stream for this node's computation. + stream (Callable): Func that returns CUDA stream for computation. This can be either a 'compute' stream or a 'communicate' stream. - 'compute' stream: Used for computational nodes like attention and experts. - 'communicate' stream: Used for nodes that handle token communication, @@ -198,6 +198,9 @@ def forward(self, inputs=()): return self._forward(*inputs) def _forward(self, *inputs): + # Lazy initialization of stream + if isinstance(self.stream, Callable): + self.stream = self.stream() with self.stream_acquire_context(f"{self.name} forward"): self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs] for i, input in enumerate(self.inputs): @@ -235,6 +238,9 @@ def backward(self, output_grad): return self._backward(*output_grad) def _backward(self, *output_grad): + # Lazy initialization of stream + if isinstance(self.stream, Callable): + self.stream = self.stream() with self.stream_acquire_context(f"{self.name} backward"): outputs = self.output if not isinstance(outputs, tuple): @@ -323,31 +329,56 @@ def run( ... 
+_USE_DYNAMIC_COMP_STREAM = None _COMP_STREAM = None _COMM_STREAM = None -def set_streams(comp_stream=None, comm_stream=None): - """Set the streams for communication and computation""" +def set_streams(comp_stream=None, comm_stream=None, use_dynamic_comp_stream=False): + """Set the streams for communication and computation. + + When use_dynamic_comp_stream is True, get_comp_stream() will return + torch.cuda.current_stream() at call time instead of a cached stream, + which is required for full-iteration CUDA graph capture/replay where + the capture stream differs from the default stream. + """ global _COMP_STREAM global _COMM_STREAM - if _COMP_STREAM is not None and _COMM_STREAM is not None: + global _USE_DYNAMIC_COMP_STREAM + + if _USE_DYNAMIC_COMP_STREAM is None: + _USE_DYNAMIC_COMP_STREAM = use_dynamic_comp_stream + + # Set communication stream + if _COMM_STREAM is None: + if comm_stream is None: + comm_stream = torch.cuda.Stream(device="cuda") + _COMM_STREAM = comm_stream + + # In dynamic mode, comp stream is resolved at call time via current_stream() + if _USE_DYNAMIC_COMP_STREAM: + _COMP_STREAM = None return + if _COMP_STREAM is None: + if comp_stream is None: + comp_stream = torch.cuda.current_stream() + _COMP_STREAM = comp_stream - if comp_stream is None: - comp_stream = torch.cuda.current_stream() - if comm_stream is None: - comm_stream = torch.cuda.Stream(device="cuda") - assert _COMP_STREAM is None - assert _COMM_STREAM is None - _COMP_STREAM = comp_stream - _COMM_STREAM = comm_stream +def reset_streams(): + """Reset all stream state. 
Intended for testing or reinitialisation.""" + global _COMP_STREAM, _COMM_STREAM, _USE_DYNAMIC_COMP_STREAM + _USE_DYNAMIC_COMP_STREAM = None + _COMP_STREAM = None + _COMM_STREAM = None def get_comp_stream(): """Get the stream for computation""" global _COMP_STREAM + global _USE_DYNAMIC_COMP_STREAM + if _USE_DYNAMIC_COMP_STREAM: + return torch.cuda.current_stream() return _COMP_STREAM diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 8168c8ab611..12d7b2998fc 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -6,6 +6,7 @@ from collections.abc import Callable from copy import deepcopy from dataclasses import dataclass +from contextlib import nullcontext from functools import partial from itertools import chain from math import ceil @@ -50,6 +51,11 @@ ProcessGroupCollection, get_align_size_for_quantization, ) +from megatron.core.transformer.moe.paged_stash import ( + get_paged_stash_context, + paged_stash_group_commit, + paged_stash_group_start, +) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( ensure_metadata_has_dp_cp_group, @@ -534,6 +540,36 @@ def backward_dw(self): """ pass +def _pad_unpad(inp, pad): + if pad: + if inp.ndim == 2: + result = torch.nn.functional.pad(inp, (0, 0, 0, 256)) + elif inp.ndim == 1: + result = torch.nn.functional.pad(inp, (0, 256)) + else: + raise ValueError(f"Input dimension {inp.ndim} not supported") + else: + if inp.ndim == 2: + result = inp[:-256, :] + elif inp.ndim == 1: + result = inp[:-256] + else: + raise ValueError(f"Input dimension {inp.ndim} not supported") + assert result.shape[0] == 0 + return result + + +class PadUnpadFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp, pad): + result = _pad_unpad(inp, pad) + ctx.pad = not pad + return result + + @staticmethod + def backward(ctx, grad_out): + return _pad_unpad(grad_out, 
ctx.pad), None + class GroupedLinearFc1Interface(Protocol): """Interface for linear_fc1 module in TEGroupedMLP.""" @@ -698,6 +734,15 @@ def __init__( and "moe_act" in self.config.offload_modules ) + stash_modules = self.config.stash_modules or [] + self.moe_paged_stash_expert_fc1 = ( + self.config.moe_paged_stash and "expert_fc1" in stash_modules + ) + self.moe_paged_stash_moe_act = self.config.moe_paged_stash and "moe_act" in stash_modules + self.moe_paged_stash_expert_fc2 = ( + self.config.moe_paged_stash and "expert_fc2" in stash_modules + ) + self.activation_recompute = ( self.config.recompute_granularity == 'selective' and "moe_act" in self.config.recompute_modules @@ -791,6 +836,8 @@ def _is_fused_impl_supported(self) -> bool: if self.activation_func != F.silu or not self.config.gated_linear_unit: return False # Expected SwiGLU activation + if not self.config.use_transformer_engine_op_fuser: + return False return True def _make_fused_ops(self) -> torch.nn.Module: @@ -904,6 +951,8 @@ def _fused_forward( ) -> torch.Tensor: """Forward pass using Transformer Engine operation fuser API.""" + if self.config.moe_expert_rank_capacity_factor is not None: + assert self.config.moe_router_padding_for_quantization, "moe_expert_rank_capacity_factor requires moe_router_padding_for_quantization" # Construct fused impl if needed # Note: We initialize during the first forward pass in case # the params are modified after the constructor. 
@@ -931,19 +980,45 @@ def _fused_forward( tokens_per_expert = torch.tensor( tokens_per_expert, dtype=torch.int, device=permuted_probs.device ) - - # Call fused impl - output = ops( - permuted_local_hidden_states, - tokens_per_expert, # FC1 - permuted_probs, # Scaled SwiGLU - tokens_per_expert, # FC2 - ) - + # if the number of tokens is 0, pad the hidden states to 256 + apply_pad_unpad = False + if permuted_local_hidden_states.shape[0] == 0 and not torch.cuda.is_current_stream_capturing(): + apply_pad_unpad = True + permuted_local_hidden_states = PadUnpadFunction.apply(permuted_local_hidden_states, True) + permuted_probs = PadUnpadFunction.apply(permuted_probs, True) + + if self.config.moe_paged_stash: + permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states) + max_num_tokens = permuted_local_hidden_states.shape[0] + # Average/expected tokens is a pre-padding estimate used by paged stashing heuristics. + # moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled. 
+ cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None + ) + offload_context = get_paged_stash_context( + name="expert_fc1_fused", + max_num_tokens=max_num_tokens, + num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, + ) + else: + offload_context = nullcontext() + with offload_context: + # Call fused impl + output = ops( + permuted_local_hidden_states, + tokens_per_expert, # FC1 + permuted_probs, # Scaled SwiGLU + tokens_per_expert, # FC2 + ) + if apply_pad_unpad: + output = PadUnpadFunction.apply(output, False) # Remove padding if needed if unpadded_tokens_per_expert is not None: output = self.quantization_unpadding(output, unpadded_tokens_per_expert) - + if self.config.moe_paged_stash: + output = paged_stash_group_commit(output, name="expert_fc1_fused") return output def forward( @@ -1001,9 +1076,28 @@ def forward( with off_interface( self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" ) as permuted_local_hidden_states: - fc1_output, bias_parallel = apply_module(self.linear_fc1)( - permuted_local_hidden_states, tokens_per_expert - ) + if self.config.moe_paged_stash: + permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states) + if self.moe_paged_stash_expert_fc1: + max_num_tokens = permuted_local_hidden_states.shape[0] + # Average/expected tokens is a pre-padding estimate used by paged stashing heuristics. + # moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled. 
+ cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None + ) + offload_context = get_paged_stash_context( + name="expert_fc1", + max_num_tokens=max_num_tokens, + num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, + ) + else: + offload_context = nullcontext() + with offload_context: + fc1_output, bias_parallel = apply_module(self.linear_fc1)( + permuted_local_hidden_states, tokens_per_expert + ) if self.offload_expert_fc1: fc1_output = off_interface.group_commit( fc1_output, @@ -1102,9 +1196,47 @@ def glu(x): ) else: with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: - bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) + if self.moe_paged_stash_moe_act: + max_num_tokens = fc1_output.shape[0] + cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) + if cap_factor is not None and cap_factor > 0 + else None + ) + offload_context = get_paged_stash_context( + name="moe_act", + max_num_tokens=max_num_tokens, + num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, + ) + else: + offload_context = nullcontext() + with offload_context: + bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) + if self.offload_moe_act: + (bias_act_output,) = fine_grained_offloading_group_commit( + bias_act_output, name="moe_act", forced_released_tensors=[fc1_output] + ) - output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert) + if self.moe_paged_stash_expert_fc2: + max_num_tokens = bias_act_output.shape[0] + cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None + ) + offload_context = get_paged_stash_context( + name="expert_fc2", + 
max_num_tokens=max_num_tokens, + num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, + ) + else: + offload_context = nullcontext() + with offload_context: + output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert) + if self.config.moe_paged_stash: + output = paged_stash_group_commit(output, name="expert_fc2") if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index dbcc25a905c..edefe67356c 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -44,6 +44,9 @@ HAVE_TE = False +# MOE logging +_MOE_OVERLOAD_FACTOR_TRACKER = {} + def switch_load_balancing_loss_func( probs: torch.Tensor, tokens_per_expert: torch.Tensor, @@ -953,6 +956,224 @@ def clear_aux_losses_tracker() -> None: get_moe_metrics_tracker().clear() +def get_overload_factor_tracker(): + """Return the overload factor tracker.""" + global _MOE_OVERLOAD_FACTOR_TRACKER + return _MOE_OVERLOAD_FACTOR_TRACKER + + +class SaveOverloadFactorFunction(torch.autograd.Function): + """Autograd function to save overload factor data for forward and backward passes.""" + + @staticmethod + def forward(ctx, tensor, routing_map, layer_number, num_local_experts): + """Forward pass: save overload factor data with 'fwd' label. + + Args: + tensor: A tensor in the autograd graph (e.g., probs) to pass through. + routing_map: The routing map tensor [num_tokens, num_experts]. + layer_number: Layer index (1-based). + num_local_experts: Number of experts per EP rank. + + Returns: + tensor unchanged (pass-through). 
+ """ + if layer_number is None: + return tensor + + # Compute local tokens per expert + local_tokens_per_expert = routing_map.sum(dim=0).detach().float() + + # Group by EP rank: reshape [num_experts] -> [ep_size, num_local_experts] and sum + num_experts = local_tokens_per_expert.shape[0] + ep_size = num_experts // num_local_experts + local_tokens_per_ep_rank = local_tokens_per_expert.view(ep_size, num_local_experts).sum(dim=1) + + # Save to tracker + tracker = get_overload_factor_tracker() + if "fwd" not in tracker: + tracker["fwd"] = {} + if "fwd_bwd" not in tracker: + tracker["fwd_bwd"] = [] + + layer_idx = layer_number - 1 # Convert to 0-based index + if layer_idx not in tracker["fwd"]: + tracker["fwd"][layer_idx] = [] + + tracker["fwd"][layer_idx].append(local_tokens_per_ep_rank) + tracker["fwd_bwd"].append(local_tokens_per_ep_rank) + + # Save for backward + ctx.save_for_backward(local_tokens_per_ep_rank) + + return tensor + + @staticmethod + def backward(ctx, grad_output): + """Backward pass: append negated tokens to fwd_bwd tracker.""" + if ctx.saved_tensors: + (local_tokens_per_ep_rank,) = ctx.saved_tensors + tracker = get_overload_factor_tracker() + if "fwd_bwd" in tracker: + tracker["fwd_bwd"].append(-local_tokens_per_ep_rank) + return grad_output, None, None, None + + +def save_overload_factor_to_tracker( + tensor: torch.Tensor, + routing_map: torch.Tensor, + layer_number: int, + num_local_experts: int, + tp_ep_group: torch.distributed.ProcessGroup, + dp_group: torch.distributed.ProcessGroup, +): + """Save local tokens per EP rank and track token counts for activation memory analysis. + + This function wraps SaveOverloadFactorFunction to: + 1. Store data for overload factor computation (done in get_overload_factors_for_logging()) + 2. Track tokens in forward/backward for activation memory analysis + + Args: + tensor: A tensor in the autograd graph (e.g., probs) - passed through unchanged. + routing_map: The routing map tensor [num_tokens, num_experts]. 
+ layer_number: Layer index (1-based). + num_local_experts: Number of experts per EP rank. + tp_ep_group: The TP x EP group for all-reducing. + dp_group: The DP group for max/avg reduction. + + Returns: + tensor unchanged. + """ + # Set comm groups in tracker (outside autograd function) + tracker = get_overload_factor_tracker() + if "to_clear" in tracker and tracker["to_clear"]: + if "fwd" in tracker: + tracker["fwd"].clear() + if "fwd_bwd" in tracker: + tracker["fwd_bwd"].clear() + tracker.pop("tp_ep_group", None) + tracker.pop("dp_group", None) + tracker.pop("to_clear", None) + + if "tp_ep_group" not in tracker: + tracker["tp_ep_group"] = tp_ep_group + tracker["dp_group"] = dp_group + + return SaveOverloadFactorFunction.apply(tensor, routing_map, layer_number, num_local_experts) + + +def get_overload_factors_for_logging() -> dict: + """Compute overload factors from stored data and return organized metrics for logging. + + This function performs: + 1. All-reduce over TP x EP to get tokens per EP rank within each DP rank + 2. MAX reduction across DP ranks to get max tokens per EP rank + 3. AVG reduction across DP ranks to get avg tokens per EP rank + 4. Computes overload_factor = max_tokens / avg_tokens + + Should be called outside the critical path (e.g., during logging). 
+ + Returns: + dict: A dictionary with structure: + { + "avg_overload_factor": float, + "max_overload_factor": float, + "max_cum_overload_factor": float or None (max cumulative fwd/bwd token count divided by the largest per-step mean over EP ranks, MAX-reduced over the DP group when one is set; None if no fwd/bwd entries were recorded), + } + """ + tracker = get_overload_factor_tracker() + if "fwd" not in tracker or not tracker["fwd"]: + return {} + tp_ep_group = tracker.get("tp_ep_group") + dp_group = tracker.get("dp_group") + + # Collect fwd tensors for overload factor calculation + fwd_tensors = [] + layer_indices = [] # layer_idx for each fwd tensor + + for layer_idx in sorted(tracker["fwd"].keys()): + for local_tokens_per_ep_rank in tracker["fwd"][layer_idx]: + fwd_tensors.append(local_tokens_per_ep_rank) + layer_indices.append(layer_idx) + + if not fwd_tensors: + return {} + + # Stack fwd_bwd tensors (already has fwd positive, bwd negative) + fwd_bwd_tensors = tracker.get("fwd_bwd", []) + fwd_bwd_tensors_stacked = torch.stack(fwd_bwd_tensors, dim=0) if fwd_bwd_tensors else None + # All-reduce across tp_ep_group, cumsum, and find max + max_cum_overload_factor = None + if fwd_bwd_tensors_stacked is not None: + if tp_ep_group is not None: + torch.distributed.all_reduce(fwd_bwd_tensors_stacked, group=tp_ep_group) + cumsum_tokens = fwd_bwd_tensors_stacked.cumsum(dim=0) + max_cumsum_tokens = cumsum_tokens.max().item() + + # Calculate max_cum_overload_factor = max_cumsum_tokens / cumsum_tokens.mean(dim=1).max() + mean_cumsum_max = cumsum_tokens.mean(dim=1).max() + local_max_cum_overload_factor = max_cumsum_tokens / (mean_cumsum_max.item() + 1e-8) + + # Max all-reduce to find global max across DP ranks + if dp_group is not None: + cum_overload_tensor = torch.tensor( + [local_max_cum_overload_factor], device=fwd_bwd_tensors_stacked.device + ) + torch.distributed.all_reduce(cum_overload_tensor, group=dp_group, op=torch.distributed.ReduceOp.MAX) + max_cum_overload_factor = cum_overload_tensor.item() + else: + max_cum_overload_factor = local_max_cum_overload_factor + all_tensors = fwd_tensors + # Stack all tensors and do 
all-reduce over TP x EP + # Shape: [num_entries, ep_size] + stacked_tensors = torch.stack(all_tensors, dim=0) + if tp_ep_group is not None: + torch.distributed.all_reduce(stacked_tensors, group=tp_ep_group) + + # Now reduce across DP ranks: get both MAX and AVG + if dp_group is not None: + # Clone for max reduction + max_tokens_per_ep_rank = stacked_tensors.clone() + torch.distributed.all_reduce( + max_tokens_per_ep_rank, group=dp_group, op=torch.distributed.ReduceOp.MAX + ) + # AVG reduction for mean tokens + avg_tokens_per_ep_rank = stacked_tensors.clone() + torch.distributed.all_reduce( + avg_tokens_per_ep_rank, group=dp_group, op=torch.distributed.ReduceOp.AVG + ) + else: + max_tokens_per_ep_rank = stacked_tensors + avg_tokens_per_ep_rank = stacked_tensors + + # Compute two overload factors for each entry: + # 1. avg_overload_factor = max(avg_tokens) / mean(avg_tokens) - based on AVG across DP + # 2. max_overload_factor = max(max_tokens) / mean(max_tokens) - based on MAX across DP + avg_max_tokens = avg_tokens_per_ep_rank.max(dim=1).values # [num_entries] + avg_mean_tokens = avg_tokens_per_ep_rank.float().mean(dim=1) # [num_entries] + avg_overload_factors = (avg_max_tokens / (avg_mean_tokens + 1e-8)) # [num_entries] + + max_max_tokens = max_tokens_per_ep_rank.max(dim=1).values # [num_entries] + max_mean_tokens = max_tokens_per_ep_rank.float().mean(dim=1) # [num_entries] + max_overload_factors = (max_max_tokens / (max_mean_tokens + 1e-8)) # [num_entries] + + # Compute overall statistics + # avg_overload_factor uses mean, max_overload_factor uses max + result = { + "avg_overload_factor": avg_overload_factors.mean().item(), + "max_overload_factor": max_overload_factors.max().item(), + "max_cum_overload_factor": max_cum_overload_factor, + } + + return result + + +def clear_overload_factor_tracker(): + """Clear the overload factor tracker.""" + tracker = get_overload_factor_tracker() + tracker["to_clear"] = True + + @deprecated( version="0.16", 
removal_version="0.18", alternative="get_moe_metrics_tracker()._sync_metrics()" ) @@ -995,6 +1216,7 @@ def track_moe_metrics( writer: Optional["SummaryWriter"] = None, wandb_writer: Optional["wandb.Run"] = None, total_loss_dict: Optional[dict[str, torch.Tensor]] = None, + overload_dict=None, per_layer_logging: bool = False, force_initialize: bool = False, track_names: Optional[List[str]] = None, @@ -1007,6 +1229,27 @@ def track_moe_metrics( Deprecated: Use get_moe_metrics_tracker().report() directly. """ + # Log overload factor metrics + overload_metrics = get_overload_factors_for_logging() + if overload_metrics: + if overload_dict is not None: + overload_dict.update(overload_metrics) + if writer is not None: + if "avg_overload_factor" in overload_metrics: + writer.add_scalar("moe/avg_overload_factor", overload_metrics["avg_overload_factor"], iteration) + if "max_overload_factor" in overload_metrics: + writer.add_scalar("moe/max_overload_factor", overload_metrics["max_overload_factor"], iteration) + if "max_cum_overload_factor" in overload_metrics and overload_metrics["max_cum_overload_factor"] is not None: + writer.add_scalar("moe/max_cum_overload_factor", overload_metrics["max_cum_overload_factor"], iteration) + if wandb_writer: + if "avg_overload_factor" in overload_metrics: + wandb_writer.log({"moe/avg_overload_factor": overload_metrics["avg_overload_factor"]}, iteration) + if "max_overload_factor" in overload_metrics: + wandb_writer.log({"moe/max_overload_factor": overload_metrics["max_overload_factor"]}, iteration) + if "max_cum_overload_factor" in overload_metrics and overload_metrics["max_cum_overload_factor"] is not None: + wandb_writer.log({"moe/max_cum_overload_factor": overload_metrics["max_cum_overload_factor"]}, iteration) + clear_overload_factor_tracker() + return get_moe_metrics_tracker().report( loss_scale=loss_scale, iteration=iteration, diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py new 
class PagedStashBuffer:
    """
    Page-granular stash storage with an optional pinned host spill area.

    Each buffer is one flat 2-D tensor of shape
    [num_pages * page_size, hidden_size]. Free pages are tracked per buffer with
    head/tail/capacity tensors of two entries each: index 0 = CUDA, index 1 = host.
    """

    def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype, num_tokens_host=0):
        """
        Args:
            num_tokens: Token capacity requested for the CUDA buffer (rounded up to whole pages)
            hidden_size: Number of elements stored per token
            page_size: Number of tokens per page
            device: Device holding the CUDA buffer and all free-list state
            overflow: Overflow flag tensor (the same tensor object is kept, not copied)
            dtype: Element type of the buffers
            num_tokens_host: If > 0, also allocate a pinned host buffer of this many
                tokens (rounded up to whole pages) for spillover.
        """

        def _pages_for(token_count):
            # Ceiling division: a partially filled page still needs a full page.
            return (token_count + page_size - 1) // page_size

        index_kwargs = dict(dtype=torch.int64, device=device)

        self.hidden_size = hidden_size
        self.page_size = page_size
        self.device = device
        self.dtype = dtype
        self.overflow = overflow  # GPU flag (shared)

        # CUDA buffer
        cuda_pages = _pages_for(num_tokens)
        self.num_cuda_pages = cuda_pages
        self.total_cuda_tokens = cuda_pages * page_size
        self.cuda_buffer = torch.empty(
            (self.total_cuda_tokens, hidden_size), dtype=dtype, device=device
        )

        # Host buffer (pinned), only when spillover capacity was requested
        host_pages = _pages_for(num_tokens_host) if num_tokens_host > 0 else 0
        self.num_host_pages = host_pages
        self.total_host_tokens = host_pages * page_size
        self.host_buffer = None
        if host_pages > 0:
            self.host_buffer = torch.empty(
                (self.total_host_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True
            )

        # Free-list cursors, all on `device` so kernels can read them.
        # Initially every page is free: head = 0, tail = capacity = page count.
        page_counts = [cuda_pages, host_pages]
        self.free_list_head = torch.zeros(2, **index_kwargs)
        self.free_list_tail = torch.tensor(page_counts, **index_kwargs)
        self.free_list_capacity = torch.tensor(page_counts, **index_kwargs)

        # Free-list arrays: page IDs of each buffer (empty when there is no host buffer)
        self.free_list_cuda = torch.arange(cuda_pages, **index_kwargs)
        self.free_list_host = torch.arange(host_pages, **index_kwargs)

        # Snapshots of the initial state so reset() only copies and never allocates
        self._reset_tail = self.free_list_tail.clone()
        self._reset_free_list_cuda = self.free_list_cuda.clone()
        self._reset_free_list_host = self.free_list_host.clone() if host_pages > 0 else None

    def reset(self):
        """Restore the initial all-pages-free state in place (no new allocations)."""
        self.free_list_head.zero_()
        self.free_list_tail.copy_(self._reset_tail)
        self.free_list_cuda.copy_(self._reset_free_list_cuda)
        if self._reset_free_list_host is not None:
            self.free_list_host.copy_(self._reset_free_list_host)

    def __repr__(self):
        page_info = f"num_cuda_pages={self.num_cuda_pages}, num_host_pages={self.num_host_pages}, "
        shape_info = (
            f"page_size={self.page_size}, hidden_size={self.hidden_size}, "
            f"device={self.device}, dtype={self.dtype}"
        )
        return f"PagedStashBuffer({page_info}{shape_info})"
spill = 0 + dst_ptr = cuda_dst_ptr + free_list_ptr = free_list_cuda_ptr + head = head_cuda + cap = cap_cuda + new_head_cuda = head_cuda + required_pages + new_head_host = head_host + + if overflow == 1: + return + + # Only when CUDA is full: load host state and maybe switch to host + if not use_cuda: + tail_host = tl.load(free_list_tail_ptr + 1) + cap_host = tl.load(free_list_capacity_ptr + 1) + use_host = HAS_HOST_BUFFER == 1 and (tail_host - head_host) >= required_pages + if use_host: + spill = 1 + dst_ptr = host_dst_ptr + free_list_ptr = free_list_host_ptr + head = head_host + cap = cap_host + new_head_cuda = head_cuda + new_head_host = head_host + required_pages + else: + if pid == 0: + tl.store(overflow_ptr, 1) + tl.store(spilled_to_host_ptr, 1) + tl.store(new_free_list_head_ptr, head_cuda) + tl.store(new_free_list_head_ptr + 1, head_host) + return + + if pid == 0: + tl.store(spilled_to_host_ptr, spill) + + # Copy loop: strided over tokens + token_idx = pid + while token_idx < num_tokens: + page_slot = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + free_list_idx = (head + page_slot) % cap + page_id = tl.load(free_list_ptr + free_list_idx) + if token_in_page == 0: + tl.store(page_record_ptr + page_slot, page_id) + dst_token_idx = page_id * PAGE_SIZE + token_in_page + + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + token_idx_i64 = token_idx.to(tl.int64) + dst_token_idx_i64 = dst_token_idx.to(tl.int64) + src_base = src_ptr + token_idx_i64 * HIDDEN_SIZE + dst_base = dst_ptr + dst_token_idx_i64 * HIDDEN_SIZE + + if need_mask: + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + for iter in 
range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + token_idx += num_blocks + + if pid == 0: + tl.store(new_free_list_head_ptr, new_head_cuda) + tl.store(new_free_list_head_ptr + 1, new_head_host) + + +@triton.jit +def _paged_stash_pop_kernel( + cuda_src_ptr, + host_src_ptr, + dst_ptr, + num_tokens_ptr, + page_record_ptr, + spilled_to_host_ptr, # 0 = read from CUDA, 1 = read from host + overflow_ptr, + free_list_cuda_ptr, + free_list_host_ptr, + free_list_tail_ptr, # shape (2,) + free_list_capacity_ptr, + new_free_list_tail_ptr, # Output: shape (2,) updated tails + PAGE_SIZE: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Reload tokens from paged stash; CUDA path fast, host path when spilled_to_host.""" + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load overflow first (get in flight early); branch on it only before any write + overflow = tl.load(overflow_ptr) + + num_tokens = tl.load(num_tokens_ptr) + spill = tl.load(spilled_to_host_ptr) + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + + # Common case: load only CUDA state (and tail_host for output when spill=0) + tail_cuda = tl.load(free_list_tail_ptr) + tail_host = tl.load(free_list_tail_ptr + 1) + cap_cuda = tl.load(free_list_capacity_ptr) + + # Assume CUDA path + src_ptr = cuda_src_ptr + free_list_ptr = free_list_cuda_ptr + tail = tail_cuda + cap = cap_cuda + new_tail_cuda = tail_cuda + required_pages + new_tail_host = tail_host + + # Only when spilled to host: load host state and switch + if spill == 1: + cap_host = tl.load(free_list_capacity_ptr + 1) + if cap_host == 0: + return + src_ptr = host_src_ptr + free_list_ptr = free_list_host_ptr + tail = tail_host + cap = cap_host + new_tail_cuda = tail_cuda + new_tail_host = tail_host + required_pages + + if overflow == 1: + return + + token_idx = pid + while token_idx < 
class PagedTensor:
    """
    A tensor whose token rows can be parked in, and restored from, a PagedStashBuffer.

    The instance keeps the bookkeeping the stash kernels need: the token count
    (as a one-element tensor), the list of pages the data was written to
    (`page_record`) and whether the data went to the CUDA or the host buffer
    (`spilled_to_host`).
    """

    def __init__(
        self,
        tensor,
        num_tokens_tensor=None,
        avg_num_tokens: int = None,
        vp_stage=None,
        original_shape=None,
        schedule_layer_no=None,
        layer_name=None,
        max_num_tokens=None,
        hidden_size=None,
        page_size=64,
    ):
        """
        Args:
            tensor: The tensor to store
            num_tokens_tensor: One-element tensor with the actual number of tokens
                (required; it is cloned, so later changes to the argument are not seen)
            avg_num_tokens: Optional expected/average token count, stored as given
            vp_stage: Virtual pipeline stage
            original_shape: Shape to restore on reload; defaults to `tensor.shape`
            schedule_layer_no: Schedule entry this tensor belongs to
            layer_name: Name of the layer; a name containing 'columnwise_scale_inv'
                makes the kernels work on token counts divided by SCALE_INV_BLOCK_SIZE
            max_num_tokens: Maximum number of tokens (required; sizes `page_record`)
            hidden_size: Hidden size
            page_size: Number of tokens per page
        """
        self._tensor = tensor
        self._original_tensor = None
        assert (
            num_tokens_tensor is not None
            and isinstance(num_tokens_tensor, torch.Tensor)
            and num_tokens_tensor.numel() == 1
        )
        self.num_tokens_tensor = num_tokens_tensor.clone()
        self.avg_num_tokens = avg_num_tokens
        self.vp_stage = vp_stage
        self.schedule_layer_no = schedule_layer_no
        self.layer_name = layer_name
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.page_size = page_size

        # Original tensor information
        self.original_shape = list(tensor.shape) if original_shape is None else original_shape
        self.element_size = tensor.element_size()
        self.dtype = tensor.dtype
        self.device = tensor.device

        # Calculate number of pages needed
        self.max_num_pages = (self.max_num_tokens + page_size - 1) // page_size  # Ceiling division

        # Page record: stores which pages are being used for this tensor
        self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device)
        # Set by copy kernel: 0 = data in CUDA stash, 1 = data in host (pinned) stash
        self.spilled_to_host = torch.zeros(1, dtype=torch.int64, device=self.device)

    @property
    def schedule_layer(self):
        """Get the schedule layer."""
        return self.schedule_layer_no

    def _kernel_token_counts(self):
        """Return (num_tokens_tensor, max_num_tokens) in the units the kernels copy."""
        # layer_name defaults to None; treat that as "not a scale-inv tensor"
        # instead of failing on the substring test.
        if 'columnwise_scale_inv' in (self.layer_name or ''):
            return (
                self.num_tokens_tensor // SCALE_INV_BLOCK_SIZE,
                self.max_num_tokens // SCALE_INV_BLOCK_SIZE,
            )
        return self.num_tokens_tensor, self.max_num_tokens

    def offload_to_stash(self, paged_stash_buffer: "PagedStashBuffer", max_blocks=2048):
        """Offload the paged tensor to paged stash buffer (CUDA or host if CUDA full).

        The source tensor is kept alive in `_original_tensor` after the call;
        `_tensor` is set to None.
        """
        self._tensor = self._tensor.contiguous()
        if self.num_tokens_tensor.dim() == 0:
            self.num_tokens_tensor = self.num_tokens_tensor.reshape(1)
        num_tokens_tensor, max_num_tokens = self._kernel_token_counts()

        grid = (min(max_num_tokens, max_blocks),)

        # Seed the kernel's output with the current heads. The copy kernel returns
        # without writing this output when the overflow flag is already set, so an
        # uninitialized buffer here would put garbage into free_list_head below.
        # NOTE(review): the same applies if an empty grid is not launched -- confirm.
        new_free_list_head = paged_stash_buffer.free_list_head.clone()

        has_host = 1 if paged_stash_buffer.host_buffer is not None else 0
        # Without a host buffer the kernel still needs a valid pointer argument.
        host_dst = (
            paged_stash_buffer.host_buffer
            if paged_stash_buffer.host_buffer is not None
            else paged_stash_buffer.cuda_buffer
        )

        _paged_stash_copy_kernel[grid](
            self._tensor.view(paged_stash_buffer.cuda_buffer.dtype),
            paged_stash_buffer.cuda_buffer,
            host_dst,
            num_tokens_tensor,
            paged_stash_buffer.free_list_cuda,
            paged_stash_buffer.free_list_host,
            paged_stash_buffer.free_list_head,
            paged_stash_buffer.free_list_tail,
            paged_stash_buffer.free_list_capacity,
            self.page_record,
            paged_stash_buffer.overflow,
            self.spilled_to_host,
            new_free_list_head,
            PAGE_SIZE=self.page_size,
            HIDDEN_SIZE=self.hidden_size,
            BLOCK_SIZE=GLOBAL_BLOCK_SIZE,
            HAS_HOST_BUFFER=has_host,
        )
        # Heads are published only after the kernel, which read the old values.
        paged_stash_buffer.free_list_head.copy_(new_free_list_head)
        self._original_tensor = self._tensor
        self._tensor = None

    def reload_from_stash(self, paged_stash_buffer: "PagedStashBuffer", max_blocks=2048):
        """Reload the paged tensor from paged stash buffer (CUDA or host from spilled_to_host)."""
        self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device)
        num_tokens_tensor, max_num_tokens = self._kernel_token_counts()

        grid = (min(max_num_tokens, max_blocks),)

        # Seed with the current tails: the pop kernel has early-return paths
        # (overflow set, or host spill with zero host capacity) that leave this
        # output untouched, and it is copied into free_list_tail unconditionally.
        new_free_list_tail = paged_stash_buffer.free_list_tail.clone()

        host_src = (
            paged_stash_buffer.host_buffer
            if paged_stash_buffer.host_buffer is not None
            else paged_stash_buffer.cuda_buffer
        )
        _paged_stash_pop_kernel[grid](
            paged_stash_buffer.cuda_buffer,
            host_src,
            self._tensor.view(paged_stash_buffer.cuda_buffer.dtype),
            num_tokens_tensor,
            self.page_record,
            self.spilled_to_host,
            paged_stash_buffer.overflow,
            paged_stash_buffer.free_list_cuda,
            paged_stash_buffer.free_list_host,
            paged_stash_buffer.free_list_tail,
            paged_stash_buffer.free_list_capacity,
            new_free_list_tail,
            PAGE_SIZE=self.page_size,
            HIDDEN_SIZE=self.hidden_size,
            BLOCK_SIZE=GLOBAL_BLOCK_SIZE,
        )

        paged_stash_buffer.free_list_tail.copy_(new_free_list_tail)
class PP_PostScheduleFunction(torch.autograd.Function):
    """
    Identity autograd function that advances the stash manager's pp schedule.

    forward() registers the step with the stash manager and, once the schedule
    is 'captured', starts stashing this step's queued tensors and/or reloading
    the tensors of the upcoming backward step. backward() registers the matching
    backward step and makes the current stream wait for outstanding stash/reload
    work. The tensor and its gradient pass through unchanged.
    """

    @staticmethod
    def forward(ctx, tensor, stash_manager):  # after forward
        # pylint: disable=missing-function-docstring
        ctx.stash_manager = stash_manager
        ctx.vp_stage = stash_manager.current_vp_stage
        # No virtual pipeline stage set: treat as stage 0.
        if ctx.vp_stage is None:
            ctx.vp_stage = 0
        # Stages are passed 1-based so that a negated value can denote backward.
        ctx.layer_no, ctx.microbatch_no = stash_manager.update_pp_schedule(ctx.vp_stage + 1)

        # Initiate stash for current layer and reload for next layer
        if stash_manager.status == 'captured':
            current_schedule_layer = stash_manager.get_schedule_layer(
                ctx.vp_stage + 1, ctx.layer_no, ctx.microbatch_no
            )
            # Peek at the step that follows this one in the recorded schedule.
            next_schedule_layer = ctx.stash_manager._pp_schedule[
                ctx.stash_manager.current_schedule_index + 1
            ]
            # A next entry equal to -current is the backward of this very step;
            # in that case nothing is stashed and the queue is just dropped.
            if current_schedule_layer != -next_schedule_layer:
                # Start stash for current layer
                ctx.stash_manager.stash_paged_tensors(current_schedule_layer)
                if next_schedule_layer < 0:
                    # reload for next backward layer
                    ctx.stash_manager.reload_paged_tensors(-next_schedule_layer, no_wait=True)
            else:
                ctx.stash_manager.remove_paged_tensor_from_stash()

        ctx.stash_manager.current_schedule_index += 1
        # return the identical tensor
        return tensor

    @staticmethod
    def backward(ctx, *grad_output):  # before backward
        # pylint: disable=missing-function-docstring
        # Register the backward step using negated stage/layer/microbatch values.
        if ctx.vp_stage is not None:
            ctx.stash_manager.update_pp_schedule(
                -(ctx.vp_stage + 1), -ctx.layer_no, -ctx.microbatch_no
            )
        ctx.stash_manager.current_schedule_index += 1
        current_stream = torch.cuda.current_stream()

        ctx.stash_manager.wait_for_stash_to_complete()
        # If a reload was launched on the unpack stream, the current stream must
        # wait for it before the reloaded tensors are consumed.
        if ctx.stash_manager._unpack_stream_status == 'reloading':
            current_stream.wait_stream(ctx.stash_manager.unpack_stream)
            ctx.stash_manager._unpack_stream_status = 'idle'

        # NOTE(review): forward takes two inputs (tensor, stash_manager) but three
        # values are returned here; the surplus trailing None looks unintended -- confirm.
        return grad_output + (None, None)
= None + self.current_layer = None + self.current_microbatch = None + self.current_schedule_index = None + + # Track max tokens needed across all vp_stages grouped by dtype and hidden_size + self.max_tokens_across_vp_stages = None + self.temp_tokens_across_vp_stages = None + # Track max tokens computed from avg_num_tokens (heuristic) across all vp_stages + self.max_avg_tokens_across_vp_stages = None + self.temp_avg_tokens_across_vp_stages = None + + self.num_tokens_tensor = None + self.max_num_tokens = None + # Optional hint: expected/average number of tokens (e.g., pre-padding estimate) + self.avg_num_tokens = None + self.stash_buffers = None + self.overflow = None + self.device = None + + # Page size for paged memory management (default; overwritten from config in paged_stash_reset) + self.page_size = 64 + + @property + def pack_stream(self): + """Get the pack stream.""" + return self._pack_stream + + @property + def unpack_stream(self): + """Get the unpack stream.""" + return self._unpack_stream + + def set_current_layer_name(self, name): + """Set the current layer name.""" + self._current_layer_name = name + + def get_schedule_layer(self, vp_stage, layer_no, microbatch_no): + """Get the schedule layer.""" + return vp_stage * 1000000 + layer_no * 1000 + microbatch_no + + def add_paged_tensor_to_stash(self, paged_tensor): + """Add a paged tensor to the stash list.""" + if self.status == 'captured': + self.paged_tensors_to_stash.append(paged_tensor) + else: + pass + + def remove_paged_tensor_from_stash(self): + """Remove all paged tensors from the stash list.""" + if self.status == 'captured': + while len(self.paged_tensors_to_stash) > 0: + paged_tensor = self.paged_tensors_to_stash.pop(0) + assert ( + len(self.paged_tensors_to_stash) == 0 + ), f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" + else: + pass + + def stash_paged_tensors(self, pp_schedule_layer): + """Stash the paged tensors.""" + current_stream = torch.cuda.current_stream() + 
self.pack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.pack_stream): + if self.status == 'captured': + self._pack_stream_status = 'stashing' + if pp_schedule_layer not in self.paged_tensors_to_reload: + self.paged_tensors_to_reload[pp_schedule_layer] = [] + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, ( + f"paged_tensors_to_reload {pp_schedule_layer} is not empty " + f"{self.paged_tensors_to_reload[pp_schedule_layer]}" + ) + while len(self.paged_tensors_to_stash) > 0: + paged_tensor = self.paged_tensors_to_stash.pop(0) + stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] + paged_tensor.offload_to_stash(stash_buffer) + self.paged_tensors_to_reload[pp_schedule_layer].append(paged_tensor) + self.paged_tensors_stash_in_progress.append(paged_tensor) + else: + pass + assert ( + len(self.paged_tensors_to_stash) == 0 + ), f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" + + def wait_for_stash_to_complete(self): + """Wait for stash to complete.""" + current_stream = torch.cuda.current_stream() + if self._pack_stream_status == 'stashing': + current_stream.wait_stream(self.pack_stream) + self._pack_stream_status = 'idle' + + # Deallocate original tensor after stash is complete + while len(self.paged_tensors_stash_in_progress) > 0: + paged_tensor = self.paged_tensors_stash_in_progress.pop(0) + paged_tensor._original_tensor = None + + def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): + """Reload the paged tensors.""" + # Avoid waiting for main stream if reload is immediately after stash + # since stash is already waiting for main stream + if not no_wait or self.unpack_stream != self.pack_stream: + current_stream = torch.cuda.current_stream() + self.unpack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.unpack_stream): + if self.status == 'captured': + self._unpack_stream_status = 'reloading' + count = 0 + for item in self.paged_tensors_to_reload: + if 
len(self.paged_tensors_to_reload[item]) > 0: + count += 1 + while len(self.paged_tensors_to_reload[pp_schedule_layer]) > 0: + paged_tensor = self.paged_tensors_to_reload[pp_schedule_layer].pop(0) + stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] + paged_tensor.reload_from_stash(stash_buffer) + else: + pass + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, ( + f"paged_tensors_to_reload {pp_schedule_layer} is not empty " + f"{self.paged_tensors_to_reload[pp_schedule_layer]}" + ) + + def allocate_stash_buffers( + self, + stash_buffer_size_factor_cuda: float = 1.10, + stash_buffer_size_factor_cpu: float = 0.0, + ): + """Allocate stash buffers organized by [dtype][hidden_size].""" + self.stash_buffers = {} + self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) + + cuda_factor = stash_buffer_size_factor_cuda + cpu_factor = stash_buffer_size_factor_cpu + + # Both factors use the same sign convention: + # - positive: size based on avg_num_tokens-derived maxima + # - negative: size based on actual num_tokens-derived maxima (legacy behavior) + # Scale is always abs(factor). For CPU, 0 means no host buffer. + if cuda_factor >= 0: + max_tokens_dict = self.max_avg_tokens_across_vp_stages + cuda_scale = cuda_factor + else: + max_tokens_dict = self.max_tokens_across_vp_stages + cuda_scale = -cuda_factor + + # Fallback safety: if avg-based dict is not available/populated yet, use actual-max dict. 
+ if not max_tokens_dict: + max_tokens_dict = self.max_tokens_across_vp_stages + + if cpu_factor > 0: + host_tokens_dict = self.max_avg_tokens_across_vp_stages or self.max_tokens_across_vp_stages + cpu_scale = cpu_factor + elif cpu_factor < 0: + host_tokens_dict = self.max_tokens_across_vp_stages + cpu_scale = -cpu_factor + else: + host_tokens_dict = None + cpu_scale = 0.0 + + for dtype, hidden_size in max_tokens_dict: + if dtype not in self.stash_buffers: + self.stash_buffers[dtype] = {} + assert hidden_size not in self.stash_buffers[dtype] + num_tokens = int(max_tokens_dict[dtype, hidden_size] * cuda_scale) + num_tokens_host = ( + int(host_tokens_dict[dtype, hidden_size] * cpu_scale) + if host_tokens_dict is not None and (dtype, hidden_size) in host_tokens_dict + else 0 + ) + buf_dtype = torch.uint8 if dtype in [torch.float8_e4m3fn, torch.float8_e8m0fnu] else dtype + self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( + num_tokens, + hidden_size, + self.page_size, + self.device, + self.overflow, + buf_dtype, + num_tokens_host=num_tokens_host, + ) + sb = self.stash_buffers[dtype][hidden_size] + msg = f'allocate_stash_buffers cuda: {sb.cuda_buffer.shape}' + if sb.host_buffer is not None: + msg += f' host: {sb.host_buffer.shape}' + print(f'{msg} dtype={sb.dtype} ({dtype})') + + def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): + """Update the pp schedule.""" + if self._pp_schedule is None: + self._pp_schedule = [] + + assert self.vp_size is not None + if layer_no is None: + # forward pass + vp_stage_index = vp_stage - 1 + layer_no = self.current_layer[vp_stage_index] + self.current_layer[vp_stage_index] += 1 + microbatch_no = self.current_microbatch[vp_stage_index] + + if self.status == 'capture': + self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) + num_tokens = self.num_tokens_tensor.item() + + expected = self.get_schedule_layer(vp_stage, layer_no, microbatch_no) + actual = 
self._pp_schedule[self.current_schedule_index] + assert actual == expected, f"schedule {actual} != {expected}" + + return layer_no, microbatch_no + + + def update_model_chunk(self, vp_stage_index): + """Update layer=1, increment microbatch of new vp vp_stage.""" + if self.current_layer is None: + # current layer and microbatch for each vp stage for forward pass + self.current_layer = [1 for _ in range(self.vp_size)] + self.current_microbatch = [0 for _ in range(self.vp_size)] + self.current_layer[vp_stage_index] = 1 + self.current_microbatch[vp_stage_index] += 1 + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """ + Hook called when autograd saves a tensor for backward pass. + Returns a tag to identify the tensor later. + """ + # Handle 0-dim tensors (torch.Size([])) - they have no size(0) + if ( + self.max_num_tokens is None + or tensor.dim() == 0 + or not hasattr(tensor, 'grouped_name') + or (tensor.size(0) != self.max_num_tokens and (tensor.logical_shape is None or tensor.logical_shape[0] != self.max_num_tokens)) + ): + return tensor.detach() + + assert isinstance(tensor, torch.Tensor), f"tensor is not a torch.Tensor {type(tensor)}" + #if hasattr(tensor, 'grouped_name'): + # print (f'on_save_for_backward {self.status} tensor: num_tokens: {self.num_tokens_tensor.item()}-{type(tensor)}-{tensor.shape}-{tensor.dtype}-{hex(tensor.data_ptr())}-grouped: name: {tensor.grouped_name} element_size: {tensor.element_size()} logical_shape: {tensor.logical_shape if hasattr(tensor, 'logical_shape') else None}') + #else: + # print (f'on_save_for_backward {self.status} tensor: num_tokens: {self.num_tokens_tensor.item()}-{type(tensor)}-{tensor.shape}-{tensor.dtype}-{hex(tensor.data_ptr())} element_size: {tensor.element_size()} logical_shape: {tensor.logical_shape if hasattr(tensor, 'logical_shape') else None}') + + original_shape = tensor.shape + grouped_name = tensor.grouped_name + tensor = tensor.flatten() + dtype = tensor.dtype + columnwise_scale_inv = 
'columnwise_scale_inv' in grouped_name + hidden_size = tensor.numel() // (self.max_num_tokens if not columnwise_scale_inv else self.max_num_tokens // SCALE_INV_BLOCK_SIZE) + + if self.max_tokens_across_vp_stages is None: + self.max_tokens_across_vp_stages = {} + self.temp_tokens_across_vp_stages = {} + self.max_avg_tokens_across_vp_stages = {} + self.temp_avg_tokens_across_vp_stages = {} + + avg_num_tokens = None + if self.status == 'capture': + + self.num_tokens = self.num_tokens_tensor.item() + actual_num_tokens = self.num_tokens // SCALE_INV_BLOCK_SIZE if columnwise_scale_inv else self.num_tokens + + avg_num_tokens = ( + int(self.avg_num_tokens) if self.avg_num_tokens is not None else None + ) + + if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: + self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 + + self.temp_tokens_across_vp_stages[dtype, hidden_size] += actual_num_tokens + self.max_tokens_across_vp_stages[dtype, hidden_size] = max( + self.max_tokens_across_vp_stages[dtype, hidden_size], + self.temp_tokens_across_vp_stages[dtype, hidden_size], + ) + + # Track avg tokens across vp stages (if provided) using the same accumulation model. 
+ if avg_num_tokens is not None: + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] += (avg_num_tokens if not columnwise_scale_inv else avg_num_tokens // SCALE_INV_BLOCK_SIZE) + self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = max( + self.max_avg_tokens_across_vp_stages[dtype, hidden_size], + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size], + ) + + # Since capture stage does not use CUDA graph, we can truncate + # the saved tensor to actual num_tokens + new_size = (actual_num_tokens * hidden_size,) + + tensor_truncated = torch.empty(new_size, dtype=dtype, device=tensor.device) + tensor_truncated.copy_(tensor[: actual_num_tokens * hidden_size]) + tensor = tensor_truncated + + tensor.grouped_name = grouped_name + paged_tensor = PagedTensor( + tensor, + num_tokens_tensor=self.num_tokens_tensor, + avg_num_tokens=avg_num_tokens, + vp_stage=self.current_vp_stage, + original_shape=original_shape, + schedule_layer_no=( + self._pp_schedule[self.current_schedule_index] + if self._pp_schedule is not None + and self.current_schedule_index < len(self._pp_schedule) + else None + ), + layer_name=tensor.grouped_name, + max_num_tokens=self.max_num_tokens, + hidden_size=hidden_size, + page_size=self.page_size, + ) + + if self.status == 'captured': + self.add_paged_tensor_to_stash(paged_tensor) + return paged_tensor + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """ + Hook called when autograd retrieves a saved tensor during backward pass. + Returns the actual tensor (potentially reloading from CPU). 
+ """ + if isinstance(saved_state, (PagedTensor)): + columnwise_scale_inv = 'columnwise_scale_inv' in saved_state.layer_name + if self.status == 'capture': + num_tokens = saved_state.num_tokens_tensor.item() + key = (saved_state.dtype, saved_state.hidden_size) + if key in self.temp_tokens_across_vp_stages: + self.temp_tokens_across_vp_stages[key] -= (num_tokens if not columnwise_scale_inv else num_tokens // SCALE_INV_BLOCK_SIZE) + if ( + saved_state.avg_num_tokens is not None + and key in self.temp_avg_tokens_across_vp_stages + ): + self.temp_avg_tokens_across_vp_stages[key] -= (int(saved_state.avg_num_tokens) if not columnwise_scale_inv else int(saved_state.avg_num_tokens) // SCALE_INV_BLOCK_SIZE) + + # Handle 1-byte tensors (torch.uint8) + dtype = saved_state._tensor.dtype + if saved_state._tensor.element_size() == 1: + saved_state._tensor = saved_state._tensor.view(torch.uint8) + + # Pad the tensor to the max number of tokens + # check if the tensor is 1D + assert saved_state._tensor.ndim == 1, f"saved_state._tensor.ndim is not 1 {saved_state._tensor.ndim}" + npad = (self.max_num_tokens - num_tokens) * saved_state.hidden_size + if columnwise_scale_inv: + npad = npad // SCALE_INV_BLOCK_SIZE + pad = (0, npad) + saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad).view(dtype) + + assert ( + saved_state._tensor is not None + ), f"saved_state._tensor is None {saved_state._tensor}" + + # Record cross-stream usage (important when tensor was produced on another stream). 
+ if isinstance(saved_state._tensor, torch.Tensor) and saved_state._tensor.is_cuda: + saved_state._tensor.record_stream(torch.cuda.current_stream()) + + return saved_state._tensor.view(saved_state.original_shape) + + return saved_state + + +class PagedStashContext: + """Wrapper context manager that adds custom enter/exit behavior around saved_tensors_hooks.""" + + def __init__(self, stash_manager): + self.stash_manager = stash_manager + self.saved_tensors_context = torch.autograd.graph.saved_tensors_hooks( + stash_manager.on_save_for_backward, stash_manager.on_get_saved_tensor + ) + + def __enter__(self): + from megatron.core.extensions.transformer_engine import cpu_offload + + if cpu_offload is not None: + cpu_offload.CPUOffloadEnabled = True + # Call the underlying context manager's __enter__ + result = self.saved_tensors_context.__enter__() + + # Add more custom logic after entering if needed + return result + + def __exit__(self, *args: Any): + # Call the underlying context manager's __exit__ + result = self.saved_tensors_context.__exit__(*args) + from megatron.core.extensions.transformer_engine import cpu_offload + + if cpu_offload is not None: + cpu_offload.CPUOffloadEnabled = False + return result + + +def paged_stash_group_start(tensor): + """Mark the start of a layer group and prepare for stash/reload.""" + rank = torch.distributed.get_rank() + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return tensor + return PP_PreScheduleFunction.apply(tensor, stash_manager) + + +def get_paged_stash_context( + name=None, + max_num_tokens=None, + num_tokens_tensor=None, + avg_num_tokens=None, +): + """Get the paged stash context""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return nullcontext() + stash_manager.max_num_tokens = max_num_tokens + stash_manager.avg_num_tokens = avg_num_tokens + assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) + 
stash_manager.num_tokens_tensor = num_tokens_tensor + stash_manager.set_current_layer_name(name) if name is not None else None + pack_unpack_context = PagedStashContext(stash_manager) + return pack_unpack_context + + +def paged_stash_group_commit(tensor, name=None): + """Mark the end of a layer group and prepare for stash/reload.""" + rank = torch.distributed.get_rank() + stash_manager = PagedStashManager.get_instance() + stash_manager.device = tensor.device + if not stash_manager.enabled: + return tensor + return PP_PostScheduleFunction.apply(tensor, stash_manager) + + +def paged_stash_init_chunk_handler(vp_size, vp_stage): + """Initialize the chunk handler, called at the start of a microbatch forward pass.""" + stash_manager = PagedStashManager.get_instance() + stash_manager.vp_size = vp_size if vp_size is not None else 1 + stash_manager.current_vp_stage = vp_stage if vp_stage is not None else 0 + stash_manager.update_model_chunk(stash_manager.current_vp_stage) + +def paged_stash_set_last_layer(is_last_layer=False): + """Set the last layer flag.""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return + stash_manager._last_layer = is_last_layer + +def paged_stash_reset(enabled=True, config=None): + """Reset the chunk handler, called at the start of a training iteration. + + config: optional TransformerConfig; if provided, stash_buffer_size_factor_cuda/cpu and + moe_paged_stash_page_size are read from it. Otherwise defaults to 1.10 (CUDA), 0.0 (CPU). 
+ """ + stash_manager = PagedStashManager.get_instance() + stash_manager.enabled = enabled + stash_manager.iteration += 1 + if config is not None: + stash_manager.page_size = config.moe_paged_stash_page_size + # current layer and microbatch for each vp stage for forward pass + stash_manager.current_schedule_index = 0 + + if not enabled: + return + + if stash_manager.status == 'begin': + stash_manager.status = 'capture' + elif stash_manager.status == 'capture': + stash_manager.status = 'captured' + print (f'schedule {stash_manager._pp_schedule}') + cuda_factor = config.stash_buffer_size_factor_cuda if config is not None else 1.10 + cpu_factor = config.stash_buffer_size_factor_cpu if config is not None else 0.0 + stash_manager.allocate_stash_buffers( + stash_buffer_size_factor_cuda=cuda_factor, + stash_buffer_size_factor_cpu=cpu_factor, + ) + elif stash_manager.status == 'captured': + pass + + if stash_manager.status == 'captured': + if not torch.cuda.is_current_stream_capturing(): + overflow = stash_manager.overflow.item() + assert overflow == 0, f"PagedStashManager overflow!!!" 
+ + for dtype in stash_manager.stash_buffers.keys(): + for hidden_size in stash_manager.stash_buffers[dtype].keys(): + stash_manager.stash_buffers[dtype][hidden_size].reset() + stash_manager.overflow.zero_() + stash_manager.current_layer = [1 for _ in range(stash_manager.vp_size)] + stash_manager.current_microbatch = [0 for _ in range(stash_manager.vp_size)] + assert ( + len(stash_manager.paged_tensors_to_stash) == 0 + ), f"paged_tensors_to_stash is not empty {stash_manager.paged_tensors_to_stash}" + assert len(stash_manager.paged_tensors_stash_in_progress) == 0, ( + f"paged_tensors_stash_in_progress is not empty " + f"{stash_manager.paged_tensors_stash_in_progress}" + ) + +def check_paged_stash_overflow(): + """Check if paged stash overflow""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled or stash_manager.overflow is None: + return torch.zeros(1, dtype=torch.bool, device='cuda') + overflow = stash_manager.overflow.ne(0) + return overflow + +class PagedStashRunner: + """Runner for paged stash""" + + def __init__(self, config, copy_main_params, model, optimizer, forward_backward_func): + self.stash_manager = PagedStashManager.get_instance() + self.config = config + self.copy_main_params = copy_main_params + self.model = model + self.optimizer = optimizer + self.forward_backward_func = forward_backward_func + self.moe_layers = [] + for model_chunk in self.model: + model_with_decoder = get_attr_wrapped_model( + model_chunk, "decoder", allow_none=False, return_model_obj=True + ) + for layer in model_with_decoder.decoder.layers: + mlp = layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher, 'check_over_budget' + ): + self.moe_layers.append(mlp) + if model_with_decoder.mtp_process: + for layer in model_with_decoder.mtp.layers: + mlp = layer.mtp_model_layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher, 'check_over_budget' + ): + self.moe_layers.append(mlp) + print 
(f"PagedStashRunner: Moe layers {len(self.moe_layers)}!!!") + + def data_read(self, data_iterator, model, training, num_microbatches): + """Read all microbatch inputs from Dataloader and copy to static buffers.""" + data_iterator_saved = [] + if not isinstance(model, list) or len(model) == 1: + assert not isinstance(data_iterator, list) or len(data_iterator) == 1 + iterator0 = data_iterator if not isinstance(data_iterator, list) else data_iterator[0] + data_list = [] + if iterator0 is not None: + for b in range(num_microbatches): + data_list.append(next(iterator0)) + data_iterator_saved.append(data_list) + data_list = [iter(data_list)] + else: + data_list.append(None) + else: + assert isinstance(data_iterator, list) and len(data_iterator) == len(model) + data_list = [] + for i in range(len(model)): + if data_iterator[i] is not None: + data_list_i = [] + for b in range(num_microbatches): + data_list_i.append(next(data_iterator[i])) + data_iterator_saved.append(iter(data_list_i)) + data_list.append(iter(data_list_i)) + else: + data_list.append(None) + return data_iterator_saved, data_list + + def check_moe_overflow(self): + # check for paged stash overflow + overflow = check_paged_stash_overflow() + # check for token dispatcher overflow + for mlp in self.moe_layers: + overbudget = mlp.token_dispatcher.check_over_budget() + overflow |= overbudget + + overflow_int = overflow.to(torch.int32) + torch.distributed.all_reduce(overflow_int, op=torch.distributed.ReduceOp.SUM) + overflow_int = overflow_int.item() + return overflow_int + + def prepare_for_rerun(self, is_training=True): + """Prepare for rerun""" + print (f"!!!!!!!! 
Attempting to run without expert_rank_capacity_factor padding") + # check for token dispatcher overflow + for mlp in self.moe_layers: + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher._comm_manager, 'moe_expert_rank_capacity_factor' + ): + mlp.token_dispatcher._comm_manager.moe_expert_rank_capacity_factor = None + mlp.token_dispatcher._comm_manager.over_budget.fill_(0) + self.stash_manager.overflow.zero_() + self.config.moe_paged_stash = False + + # Set grad to zero. + for model_chunk in self.model: + model_chunk.zero_grad_buffer() + if self.optimizer is not None: + self.optimizer.zero_grad() + + #_handle_mxfp8_param_buffer_copy + if self.copy_main_params: + def _try_copy_main_params(opt): + if isinstance(opt, DistributedOptimizer) and hasattr(opt, 'shard_fp32_from_float16_groups'): + opt._copy_main_params_to_param_buffer() + # Handle both ChainedOptimizer and direct DistributedOptimizer cases + # Note: FSDP's DistributedOptimizer doesn't have shard_fp32_from_float16_groups, + # so we check for this attribute before calling _copy_main_params_to_param_buffer + if self.optimizer is not None: + if hasattr(self.optimizer, 'chained_optimizers'): + for optim_instance in self.optimizer.chained_optimizers: + _try_copy_main_params(optim_instance) + else: + _try_copy_main_params(self.optimizer) + + # Delete the CUDA graph + if isinstance(self.forward_backward_func, FullCudaGraphWrapper): + self.forward_backward_func.reset_cuda_graph(stage='training' if is_training else 'validation') + + def __call__(self, *args, **kwargs): + """Run the paged stash""" + assert len(args) == 0, 'forward_backward_func does not accept positional args' + assert all( + [ + kwarg in kwargs + for kwarg in [ + 'model', + 'data_iterator', + 'num_microbatches', + 'seq_length', + 'forward_only', + ] + ] + ) + model = kwargs['model'] + num_microbatches = kwargs['num_microbatches'] + + training = not kwargs['forward_only'] + data_iterator = kwargs['data_iterator'] + 
saved_moe_paged_stash = self.config.moe_paged_stash + num_tries = 0 + while True: + assert num_tries < 2, f"PagedStashRunner: num_tries {num_tries} exceeded max attempts!!!" + num_tries += 1 + data_iterator, data_list = self.data_read(data_iterator, model, training, num_microbatches) + kwargs['data_iterator'] = data_list + result = self.forward_backward_func(*args, **kwargs) + + overflow_int = self.check_moe_overflow() + # if no overflow, set the expert_rank_capacity_factor to the original value + if overflow_int == 0: + for mlp in self.moe_layers: + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher._comm_manager, 'moe_expert_rank_capacity_factor' + ): + mlp.token_dispatcher._comm_manager.moe_expert_rank_capacity_factor = mlp.token_dispatcher.config.moe_expert_rank_capacity_factor + self.config.moe_paged_stash = saved_moe_paged_stash + break + + # if overflow, set the expert_rank_capacity_factor to None + print(f"MBridge train(): PagedStashManager or Token Dispatcher overflow detected across ranks {overflow_int} config.moe_paged_stash: {self.config.moe_paged_stash}!!!") + self.prepare_for_rerun(is_training=training) + return result diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index c9a2a469531..e2dde33c9a3 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -17,6 +17,7 @@ compute_routing_scores_for_aux_loss, get_tokens_per_expert_and_token_count, router_gating_linear, + save_overload_factor_to_tracker, sinkhorn, switch_load_balancing_loss_func, topk_routing_with_score_function, @@ -53,6 +54,8 @@ def __init__( self.cp_group = pg_collection.cp self.tp_cp_group = pg_collection.tp_cp self.tp_dp_cp_group = pg_collection.tp_dp_cp + self.tp_ep_group = pg_collection.tp_ep + self.expt_dp_group = pg_collection.expt_dp # Initialize the gate weights. # TODO: Add support for GPU initialization, which requires updating the golden values. 
@@ -703,7 +706,20 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No logits, self.config.moe_router_force_biased, self.layer_number ) - probs, routing_map = self.routing(logits, padding_mask=padding_mask) + probs, routing_map = self.routing(logits) + # Log overload factor if enabled + if self.config.log_overload_factor: + # Compute num_local_experts from config and EP size + ep_size = self.tp_ep_group.size() // self.tp_group.size() + num_local_experts = self.config.num_moe_experts // ep_size + probs = save_overload_factor_to_tracker( + tensor=probs, + routing_map=routing_map, + layer_number=self.layer_number, + num_local_experts=num_local_experts, + tp_ep_group=self.tp_ep_group, + dp_group=self.expt_dp_group, + ) return probs, routing_map diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 62e7ff41b87..67f742212f4 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -78,6 +78,7 @@ def __init__( self.tp_size = utils.get_pg_size(self.tp_group) self.tp_rank = utils.get_pg_rank(self.tp_group) self.ep_size = utils.get_pg_size(self.ep_group) + self.ep_rank = utils.get_pg_rank(self.ep_group) # Attributes that need to be captured in cudagraph. These attributes are returned # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. @@ -1022,10 +1023,23 @@ def __init__( "https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep." 
) + self.moe_expert_rank_capacity_factor = self.config.moe_expert_rank_capacity_factor + self.over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') + def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): num_tokens = routing_map.shape[0] self.routing_map = routing_map.reshape(num_tokens, self.num_experts) self.token_probs = probs.reshape(num_tokens, self.num_experts) + + if self.moe_expert_rank_capacity_factor is not None: + pad_multiple = get_align_size_for_quantization(self.config) + budget = int( + routing_map.shape[0] + * self.config.moe_router_topk + * self.moe_expert_rank_capacity_factor + ) + budget += -budget % pad_multiple + self.num_permuted_tokens = budget # Compute the capacity for each expert at the drop_and_pad mode if self.drop_and_pad: num_out_tokens = num_tokens * self.config.moe_router_topk @@ -1070,12 +1084,16 @@ def dispatch( pad_multiple=self.pad_multiple, ) ) + if self.moe_expert_rank_capacity_factor is not None: + over_budget = self.handle[8] != 0 # this is overflow_flag + self.over_budget |= over_budget - if not self.drop_and_pad: - self.tokens_per_expert = tokens_per_expert + if self.num_permuted_tokens is None: + self.tokens_per_expert = tokens_per_expert.to(torch.int64) # self.num_permuted_tokens is necessary to allocate the output tensor for permute self.num_permuted_tokens = self.tokens_per_expert.sum() - + if self.moe_expert_rank_capacity_factor is not None: + self.tokens_per_expert = tokens_per_expert.to(torch.int64) return dispatched_hidden def combine( @@ -1430,6 +1448,7 @@ def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) - .expand(-1, -1, self.tp_size, -1) .reshape(num_local_tokens, world_size, self.num_local_experts) ).contiguous() + return routing_map, probs @jit_fuser @@ -1550,3 +1569,10 @@ def combine_postprocess(self, hidden_states: torch.Tensor): The final MoE layer output reshaped to its original dimensions. 
""" return hidden_states.view(self.hidden_shape) + + def check_over_budget(self): + """Check if the dispatcher has exceeded its budget.""" + if hasattr(self._comm_manager, 'over_budget'): + return self._comm_manager.over_budget + else: + return None diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 7b8a764b813..c71c9377542 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -685,9 +685,10 @@ def process_mtp_loss( mtp_loss = compute_language_model_loss(mtp_labels, mtp_logits) mtp_loss = loss_mask * mtp_loss if is_training: + # Safe divide without sync: mask numerator when num_tokens==0, divide by clamp(min=1) mtp_loss_for_log = ( - torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0) - ) + torch.sum(mtp_loss) * (num_tokens > 0).to(mtp_loss.dtype) + ) / num_tokens.clamp(min=1) MTPLossLoggingHelper.save_loss_to_tracker( mtp_loss_for_log, mtp_layer_number, diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index e9bd52f34b4..2af2eeb9a8f 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -892,7 +892,6 @@ def forward( mhc_manager.is_last_layer_in_recompute_block = ( mhc_is_last_in_recompute_block[l_no] ) - with self.offload_context, inner_quantization_context: hidden_states, context = layer( hidden_states=hidden_states, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 7ec5636ab87..fa1c0e1e215 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -679,6 +679,11 @@ class TransformerConfig(ModelParallelConfig): If negative, generates bias once per layer and reuses it (abs value is std). 
This is an experimental feature for benchmarking purposes.""" + log_overload_factor: bool = False + """If True, log MoE router overload factors: avg_overload_factor, max_overload_factor + (load imbalance across EP ranks), and max_cum_overload_factor (peak cumulative tokens + ratio for forward/backward memory analysis).""" + moe_grouped_gemm: bool = False """When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped @@ -770,10 +775,13 @@ class TransformerConfig(ModelParallelConfig): block interleaved format. Instead of interpreting the input tensor as a concatenation of gates and linear units, it will be interpreted as alternating blocks of gates and linear units. - This data format is experimental and primarily intended to enable advanced fused kernels.""" + moe_expert_rank_capacity_factor: Optional[float] = None + """moe_expert_rank_capacity_factor (float): The capacity factor for each expert, None means no token + will be dropped. The default is None.""" + ################## # Context Parallel ################## @@ -989,6 +997,28 @@ class TransformerConfig(ModelParallelConfig): min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" + moe_paged_stash: bool = False + """If True, enable paged stash for MoE expert activations.""" + + moe_paged_stash_page_size: int = 64 + """Number of tokens per page for paged stash memory management.""" + + stash_modules: Optional[list[str]] = None + """The MoE submodules to stash activations for. + choices: "expert_fc1", "moe_act", "expert_fc2". + "expert_fc1": stash the input of the expert fc1 part. + "moe_act": stash the input of the moe activation part. + "expert_fc2": stash the input of the expert fc2 part. + """ + + stash_buffer_size_factor_cuda: float = 1.10 + """Scale factor for paged stash CUDA buffer allocation. 
Sign selects sizing: positive = avg-based, + negative = actual-max. Magnitude is headroom (e.g. 1.10 = 10%).""" + + stash_buffer_size_factor_cpu: float = 0.0 + """Scale factor for paged stash host buffer. 0 disables host buffer. Same sign convention as + stash_buffer_size_factor_cuda: positive = avg-based, negative = actual-max; scale = abs(factor).""" + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more @@ -1242,6 +1272,18 @@ def __post_init__(self): "moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity" ) + if self.moe_expert_rank_capacity_factor is not None: + if not self.use_transformer_engine_op_fuser: + raise ValueError( + "moe_expert_rank_capacity_factor requires " + "use_transformer_engine_op_fuser to be enabled." + ) + if self.moe_flex_dispatcher_backend != "hybridep": + raise ValueError( + "moe_expert_rank_capacity_factor requires moe_flex_dispatcher_backend to be " + "'hybridep'." + ) + if self.cpu_offloading and ( self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers ): @@ -1443,7 +1485,37 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - + if self.moe_paged_stash:# vasu + assert ( + self.stash_modules is not None and len(self.stash_modules) > 0 + ), "stash_modules must be specified when moe_paged_stash is enabled." + allowed_modules = {"expert_fc1", "expert_fc2", "moe_act"} + invalid_modules = set(self.stash_modules) - allowed_modules + assert not invalid_modules, ( + f'Invalid choices for stash_modules: {invalid_modules}. ' + f'Allowed modules are: {allowed_modules}' + ) + assert ( + self.moe_expert_rank_capacity_factor is not None + ), "moe_expert_rank_capacity_factor must be set when moe_paged_stash is enabled." 
+ + # Check that no module is both stashed and offloaded + if self.stash_modules and self.offload_modules: + overlap = set(self.stash_modules) & set(self.offload_modules) + assert not overlap, ( + f"A module cannot be stashed and offloaded at the same time. " + f"Found overlapping modules: {overlap}" + ) + # Check that Full/Selective recompute for MOE not enabled when paged stash is enabled + if self.moe_paged_stash: + if self.recompute_granularity == "full": + raise ValueError( + "Full recompute is not supported when paged stash is enabled." + ) + if self.recompute_granularity == "selective" and "moe" in self.recompute_modules: + raise ValueError( + "Selective recompute for MOE is not supported when paged stash is enabled." + ) if ( self.num_layers_in_first_pipeline_stage is not None or self.num_layers_in_last_pipeline_stage is not None @@ -2052,14 +2124,15 @@ def __post_init__(self): ) if self.cuda_graph_impl != "none": - assert ( - self.cuda_graph_impl == "transformer_engine" - and CudaGraphScope.moe not in self.cuda_graph_scope - and CudaGraphScope.mlp not in self.cuda_graph_scope - ), ( - 'CUDA graph scope on moe and mlp is not ' - 'supported with overlap_moe_expert_parallel_comm' - ) + if self.cuda_graph_impl == "transformer_engine": + assert ( + self.cuda_graph_impl == "transformer_engine" + and CudaGraphScope.moe not in self.cuda_graph_scope + and CudaGraphScope.mlp not in self.cuda_graph_scope + ), ( + 'CUDA graph scope on moe and mlp is not ' + 'supported with overlap_moe_expert_parallel_comm' + ) # Check delay_wgrad_compute compatibility if self.delay_wgrad_compute: @@ -2081,6 +2154,12 @@ def __post_init__(self): 'ep_overlap_early_attn_memory_release' ) + if self.use_dynamic_comp_stream: + assert self.overlap_moe_expert_parallel_comm, ( + 'overlap_moe_expert_parallel_comm must be enabled when enabling ' + 'use_dynamic_comp_stream' + ) + if self.context_parallel_size > 1 and self.cp_comm_type is not None: if isinstance(self.cp_comm_type, list): assert 
len(self.cp_comm_type) == self.num_layers, ( diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index c1bb0f8ac0d..db86e9099ae 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1409,6 +1409,7 @@ def validate_args(args, defaults={}): if is_te_min_version("2.10.0"): assert os.getenv("NVTE_CPU_OFFLOAD_V1", "0") == "1", \ "For fine-grained activation offloading with TE >= 2.10.0, NVTE_CPU_OFFLOAD_V1 should be set to 1 to avoid offloading weights." + assert not args.moe_paged_stash, "Fine-grained activation offloading and paged stash cannot be enabled at the same time" if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." diff --git a/megatron/training/training.py b/megatron/training/training.py index c5715e96aed..9ba8ac4d295 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2044,6 +2044,7 @@ def training_log( wandb_writer.log({'max_attention_logit': max_attention_logit}, iteration) # Log MoE metrics. 
moe_log_string = "" + overload_dict = {} if args.num_experts is not None: moe_loss_scale = 1 / get_num_microbatches() track_names = [] @@ -2066,6 +2067,7 @@ def training_log( iteration=iteration, writer=writer, wandb_writer=wandb_writer, + overload_dict=overload_dict, per_layer_logging=args.moe_per_layer_logging, force_initialize=True, track_names=track_names, @@ -2175,6 +2177,13 @@ def training_log( total_loss_dict[advanced_iters_key] = 0 total_loss_dict[skipped_iters_key] = 0 total_loss_dict[nan_iters_key] = 0 + if overload_dict: + if "avg_overload_factor" in overload_dict: + log_string += f' avg overload factor: {overload_dict["avg_overload_factor"]:.3f} |' + if "max_overload_factor" in overload_dict: + log_string += f' max overload factor: {overload_dict["max_overload_factor"]:.3f} |' + if "max_cum_overload_factor" in overload_dict and overload_dict["max_cum_overload_factor"] is not None: + log_string += f' max cum overload factor: {overload_dict["max_cum_overload_factor"]:.3f} |' print_rank_last(log_string) reported_memory_in_this_iteration = False if report_memory_flag: @@ -2716,7 +2725,11 @@ def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs): # Wrap forward_backward_func for Full iteration CUDA graph forward_backward_func = get_forward_backward_func() if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: - forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + forward_backward_func = FullCudaGraphWrapper( + forward_backward_func, + cuda_graph_warmup_steps=args.cuda_graph_warmup_steps, + moe_expert_rank_capacity_factor=args.moe_expert_rank_capacity_factor, + ) def get_e2e_base_metrics(): """Get base metrics values for one-logger to calculate E2E tracking metrics.""" @@ -3202,7 +3215,11 @@ def evaluate( eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size) forward_backward_func = 
get_forward_backward_func() if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: - forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + forward_backward_func = FullCudaGraphWrapper( + forward_backward_func, + cuda_graph_warmup_steps=args.cuda_graph_warmup_steps, + moe_expert_rank_capacity_factor=args.moe_expert_rank_capacity_factor, + ) if has_nvidia_modelopt: # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 7844b450136..53802726cf8 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -562,7 +562,7 @@ def _broadcast(item): def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) - n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n) _broadcast(n_tensor) if n == 0: diff --git a/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py b/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py index 85586095bd7..cc56b170ba9 100644 --- a/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py +++ b/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py @@ -13,7 +13,7 @@ ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator -from megatron.core.pipeline_parallel.utils import set_streams +from megatron.core.pipeline_parallel.utils import reset_streams, set_streams from megatron.core.tensor_parallel.random import HAVE_TE, model_parallel_cuda_manual_seed from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.module import float16_to_fp32 @@ -68,6 +68,7 @@ def teardown_method(self, method): os.environ.pop(key, 
None) else: os.environ[key] = value + reset_streams() Utils.destroy_model_parallel() destroy_global_vars() destroy_num_microbatches_calculator() diff --git a/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py b/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py index 6c59dd3f9e3..0e950946898 100644 --- a/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py +++ b/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py @@ -10,7 +10,7 @@ get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.pipeline_parallel.utils import set_streams +from megatron.core.pipeline_parallel.utils import reset_streams, set_streams from megatron.core.transformer.module import float16_to_fp32 from megatron.core.utils import is_te_min_version from tests.unit_tests.a2a_overlap.utils import ( @@ -80,6 +80,7 @@ def setup_method(self, method): set_streams() def teardown_method(self, method): + reset_streams() Utils.destroy_model_parallel() @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") diff --git a/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py b/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py index c6c4a75af99..69926dbf466 100644 --- a/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py +++ b/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py @@ -12,6 +12,12 @@ get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.pipeline_parallel.utils import ( + get_comm_stream, + get_comp_stream, + reset_streams, + set_streams, +) from megatron.core.utils import is_te_min_version from tests.unit_tests.a2a_overlap.utils import ( DummyState, @@ -68,9 +74,8 @@ def run_transformer_layer_a2a_overlap_with_capture(model, input_tensors, microba for i in range(len(input_tensors)): input_tensors[i] = input_tensors[i].clone() + set_streams() event = torch.cuda.Event() - comp_stream = torch.cuda.current_stream() - comm_stream = 
torch.cuda.Stream(device="cuda") state = DummyState() state.is_mtp = False state.model = model @@ -79,8 +84,8 @@ def run_transformer_layer_a2a_overlap_with_capture(model, input_tensors, microba transformer_layer, event, state, - comp_stream, - comm_stream, + get_comp_stream, + get_comm_stream, extra_args={"is_moe": True, "enable_deepep": False}, ) for _ in range(microbatches) @@ -183,8 +188,7 @@ def run_mtp_layer_a2a_overlap_with_capture( for i in range(len(hidden_states)): hidden_states[i] = hidden_states[i].clone() - comp_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream(device="cuda") + set_streams() layers = [] for _ in range(microbatches): state = DummyState() @@ -203,8 +207,8 @@ def run_mtp_layer_a2a_overlap_with_capture( model.mtp.layers[0], event, state, - comp_stream, - comm_stream, + get_comp_stream, + get_comm_stream, extra_args={ "is_moe": True, "enable_deepep": False, @@ -255,6 +259,7 @@ def setup_method(self, method): ) def teardown_method(self, method): + reset_streams() Utils.destroy_model_parallel() @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 558c6934a0c..d39a2ed4f0f 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -346,7 +346,7 @@ def test_fine_grained_activation_offload_with_ep_a2a_overlap_compatibility( from megatron.core.models.common.model_chunk_schedule_plan import ( TransformerModelChunkSchedulePlan, ) - from megatron.core.pipeline_parallel.utils import set_streams + from megatron.core.pipeline_parallel.utils import reset_streams, set_streams from tests.unit_tests.a2a_overlap.utils import deterministic_mode # EP overlap requires distributed initialization with EP groups. 
# NOTE(review): the lines below are the tail of the preceding patch hunk and the patch
# header for this new file, kept verbatim as comments so no patch content is lost.
#   @@ -570,4 +570,5 @@ def _run_schedule_1f1b_two_microbatches(
#   f"(rel_err={rel_err:.2f}, abs_err={abs_err:.2f})"
#   )
#   finally:
#   + reset_streams()
#   Utils.destroy_model_parallel()
#   diff --git a/tests/unit_tests/transformer/moe/test_paged_stashing.py b/tests/unit_tests/transformer/moe/test_paged_stashing.py
#   new file mode 100644
#   index 00000000000..62a22e04054
#   --- /dev/null
#   +++ b/tests/unit_tests/transformer/moe/test_paged_stashing.py
#   @@ -0,0 +1,400 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import pytest
import torch
import torch.nn.functional as F

from megatron.core import config
from megatron.core.fp8_utils import get_fp8_context
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.moe_utils import get_align_size_for_quantization
from megatron.core.transformer.moe.experts import TEGroupedMLP
from megatron.core.transformer.moe.paged_stash import (
    check_paged_stash_overflow,
    paged_stash_init_chunk_handler,
    paged_stash_reset,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.training.initialize import _set_random_seed
from tests.unit_tests.test_utilities import Utils


def _global_tokens_per_expert_from_local_routing_map(routing_map: torch.Tensor) -> torch.Tensor:
    """Per-expert token counts from a local routing map, summed across the default process group.

    ``routing_map`` is shaped [num_local_token_rows, num_experts] (as in
    ``_HybridEPManager``). Tests here assume world size equals expert-parallel size (all GPUs
    are EP ranks); ``all_reduce`` on the world group aggregates disjoint local maps.

    Args:
        routing_map: Local token-to-expert assignment map; summed over dim 0.

    Returns:
        An ``int64`` tensor of per-expert counts. When ``torch.distributed`` is not
        initialized, or the world size is 1, these are the local counts only.
    """
    counts = routing_map.sum(dim=0).to(torch.int64)
    # The reduction is in-place and only meaningful with more than one rank.
    if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
        torch.distributed.all_reduce(counts, op=torch.distributed.ReduceOp.SUM)
    return counts


def _tokens_per_expert_from_routing_map(routing_map: torch.Tensor, layer: MoELayer) -> torch.Tensor:
    """Per-local-expert assignment counts from the routing map (columns for this EP rank).

    Args:
        routing_map: Local routing map, see
            ``_global_tokens_per_expert_from_local_routing_map``.
        layer: MoE layer whose ``local_expert_indices`` select the experts to keep.

    Returns:
        A fresh ``int64`` tensor with one globally-reduced count per local expert.
    """
    counts = _global_tokens_per_expert_from_local_routing_map(routing_map)
    idx = torch.as_tensor(layer.local_expert_indices, device=counts.device, dtype=torch.long)
    return counts[idx].to(torch.int64).clone()


def _pad_token_counts_to_align_size(
    tokens_per_expert: torch.Tensor, pad_multiple: int
) -> torch.Tensor:
    """Round each count up to a multiple of ``pad_multiple`` (``n + (-n % m)`` like budget).

    Args:
        tokens_per_expert: Per-expert token counts.
        pad_multiple: Alignment size; each count is rounded up to a multiple of it.

    Returns:
        An ``int64`` tensor of padded counts (already-aligned counts are unchanged).
    """
    t = tokens_per_expert.to(torch.int64)
    # (-t) % m is the non-negative distance from t up to the next multiple of m.
    return t + ((-t) % pad_multiple)


class MoEModelTestContainer:
    """Test fixture that initializes model parallelism and builds a stack of MoE layers.

    Construction calls ``Utils.initialize_model_parallel`` with the given sizes, seeds the
    RNGs via ``_set_random_seed(seed_=123, ...)``, builds a ``TransformerConfig`` with FP8
    (``e4m3`` / ``mxfp8``) settings and creates ``num_layers`` ``MoELayer`` instances on CUDA.

    Attributes set here:
        num_local_experts: ``num_moe_experts // ep_size``.
        num_layers: Number of MoE layers created.
        test_dtype: Dtype the layers are cast to.
        config: The ``TransformerConfig`` used for every layer.
        moe_layers: List of the created layers.
        moe_layer: Convenience alias for ``moe_layers[0]``.
    """

    def __init__(
        self,
        tp_size,
        ep_size,
        pp_size,
        cp_size=1,
        moe_tp_size=None,
        data_parallel_random_init=False,
        num_moe_experts=8,
        num_layers=1,
        moe_router_topk=2,
        moe_router_load_balancing_type="aux_loss",
        moe_token_dispatcher_type="alltoall",
        moe_expert_capacity_factor=None,
        moe_pad_expert_input_to_capacity=False,
        moe_aux_loss_coeff=0.1,
        test_dtype=torch.float32,
        **kwargs,
    ):
        """Initialize parallel state, build the config and create the MoE layers.

        Named parameters are forwarded to ``Utils.initialize_model_parallel`` and/or
        ``TransformerConfig`` under the same names. ``moe_tp_size`` falls back to
        ``tp_size`` when ``None``. Optional ``TransformerConfig`` fields are read from
        ``kwargs`` with the defaults spelled out in the constructor call below.
        """
        self.num_local_experts = num_moe_experts // ep_size
        self.num_layers = num_layers
        self.test_dtype = test_dtype
        if moe_tp_size is None:
            moe_tp_size = tp_size
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=tp_size,
            pipeline_model_parallel_size=pp_size,
            expert_model_parallel_size=ep_size,
            context_parallel_size=cp_size,
            expert_tensor_parallel_size=moe_tp_size,
        )
        _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init)
        self.config = TransformerConfig(
            tensor_model_parallel_size=tp_size,
            expert_model_parallel_size=ep_size,
            pipeline_model_parallel_size=pp_size,
            context_parallel_size=cp_size,
            expert_tensor_parallel_size=moe_tp_size,
            fp8='e4m3',
            fp8_recipe='mxfp8',
            fp8_wgrad=True,
            fp8_amax_compute_algo='most_recent',
            fp8_amax_history_len=1,
            fp8_interval=1,
            fp8_margin=0,
            moe_router_topk=moe_router_topk,
            num_moe_experts=num_moe_experts,
            moe_router_load_balancing_type=moe_router_load_balancing_type,
            moe_token_dispatcher_type=moe_token_dispatcher_type,
            moe_expert_capacity_factor=moe_expert_capacity_factor,
            moe_pad_expert_input_to_capacity=moe_pad_expert_input_to_capacity,
            moe_aux_loss_coeff=moe_aux_loss_coeff,
            num_layers=num_layers,
            moe_router_dtype="fp32",
            hidden_size=kwargs.get("hidden_size", 16),
            num_attention_heads=kwargs.get("num_attention_heads", 8),
            use_cpu_initialization=kwargs.get("use_cpu_initialization", True),
            sequence_parallel=tp_size > 1,
            add_bias_linear=kwargs.get("add_bias_linear", False),
            moe_permute_fusion=kwargs.get("moe_permute_fusion", False),
            moe_flex_dispatcher_backend=kwargs.get("moe_flex_dispatcher_backend", None),
            moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False),
            moe_use_legacy_grouped_gemm=kwargs.get("moe_use_legacy_grouped_gemm", False),
            moe_paged_stash=kwargs.get("moe_paged_stash", False),
            stash_modules=kwargs.get("stash_modules", None),
            moe_expert_rank_capacity_factor=kwargs.get("moe_expert_rank_capacity_factor", None),
            moe_router_padding_for_fp8=kwargs.get("moe_router_padding_for_fp8", True),
            use_transformer_engine_op_fuser=kwargs.get("use_transformer_engine_op_fuser", False),
            moe_mlp_glu_interleave_size=kwargs.get("moe_mlp_glu_interleave_size", None),
            moe_router_padding_for_quantization=kwargs.get(
                "moe_router_padding_for_quantization", False
            ),
            gated_linear_unit=kwargs.get("gated_linear_unit", False),
            activation_func=kwargs.get("activation_func", F.gelu),
            moe_router_force_biased=kwargs.get("moe_router_force_biased", None),
            stash_buffer_size_factor_cuda=0.5,
            stash_buffer_size_factor_cpu=1.5,
        )
        self.moe_layers = [
            self._create_moe_layer(layer_number=i) for i in range(num_layers)
        ]
        self.moe_layer = self.moe_layers[0]

    def _create_moe_layer(self, layer_number=0):
        """Create one ``MoELayer`` on CUDA in ``test_dtype`` and tag it with ``layer_number``.

        The layer is constructed inside the FP8 context returned by
        ``get_fp8_context(self.config, layer_number, is_init=True)``.
        """
        # NOTE(review): the spec always uses moe_grouped_gemm=True, regardless of the
        # ``moe_grouped_gemm`` value stored in ``self.config`` — confirm this is intended.
        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
            num_experts=self.config.num_moe_experts, moe_grouped_gemm=True
        )
        quantization_context = get_fp8_context(self.config, layer_number, is_init=True)
        with quantization_context:
            moe_layer = (
                MoELayer(self.config, transformer_layer_spec.submodules.mlp.submodules)
                .cuda()
                .to(dtype=self.test_dtype)
            )
        moe_layer.set_layer_number(layer_number)
        return moe_layer

    def zero_grad(self):
        """Call ``zero_grad()`` on every MoE layer in the container."""
        for layer in self.moe_layers:
            layer.zero_grad()

    def __del__(self):
        """Synchronize ranks and the device, then destroy model-parallel state.

        A finalizer can run when no process group exists (for example during
        interpreter shutdown). ``torch.distributed.barrier()`` would raise in that
        case, so the cleanup is skipped entirely when distributed is not initialized.
        """
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.cuda.synchronize()
        Utils.destroy_model_parallel()

    def destroy(self):
        """Destroy model-parallel state explicitly via ``Utils.destroy_model_parallel``."""
        Utils.destroy_model_parallel()


def _forward_backward_all_layers(container: MoEModelTestContainer, hidden_states: torch.Tensor):
    """Forward/backward all MoE layers; returns output, input grad, last layer routing state.

    The input is moved to CUDA and marked as requiring grad, then passed through every
    layer in ``container.moe_layers`` inside the FP8 context for ``container.config``.
    Backward is driven with an all-ones upstream gradient.

    Args:
        container: Fixture holding the layers and their config.
        hidden_states: Input activations for the first layer.

    Returns:
        A 4-tuple ``(output, input_grad, routing_map, tokens_per_expert)``. ``output`` is
        detached. ``routing_map`` and ``tokens_per_expert`` are read from the last layer's
        ``token_dispatcher._comm_manager`` before backward runs, and are ``None`` when
        the manager, or the corresponding attribute/method, is absent.
    """
    initial_hidden_states = hidden_states.cuda().requires_grad_(True)
    hidden_states = initial_hidden_states
    quantization_context = get_fp8_context(container.config)
    with quantization_context:
        for layer in container.moe_layers:
            hidden_states, _ = layer(hidden_states)
    output = hidden_states
    last_layer = container.moe_layers[-1]
    comm = getattr(last_layer.token_dispatcher, "_comm_manager", None)
    routing_map = getattr(comm, "routing_map", None)
    tokens_per_expert = (
        comm.get_number_of_tokens_per_expert()
        if comm is not None and hasattr(comm, "get_number_of_tokens_per_expert")
        else None
    )
    output.backward(torch.ones_like(output))
    return (
        output.detach(),
        initial_hidden_states.grad,
        routing_map,
        tokens_per_expert,
    )


def is_hybrid_ep_available():
    """Return ``HAVE_HYBRIDEP`` from ``megatron.core.transformer.moe.fused_a2a``.

    The import is deferred to call time rather than done at module import.
    """
    from megatron.core.transformer.moe.fused_a2a import HAVE_HYBRIDEP

    return HAVE_HYBRIDEP

def _build_paged_stash_container(**overrides):
    """Build the 4-layer HybridEP paged-stash container shared by the tests below.

    Args:
        **overrides: Extra or replacement keyword arguments for
            ``MoEModelTestContainer``, applied on top of the common settings.

    Returns:
        The constructed ``MoEModelTestContainer``.

    The calling test is skipped when Hybrid EP is unavailable, or when the experts are
    not a ``TEGroupedMLP`` with fused-implementation support. In the latter case the
    container's model-parallel state is destroyed before skipping.
    """
    # Runtime guard mirroring the class-level ``skipif`` markers.
    if not is_hybrid_ep_available():
        pytest.skip("Hybrid EP is not available")

    settings = dict(
        tp_size=1,
        ep_size=4,
        pp_size=1,
        num_moe_experts=8,
        num_layers=4,
        moe_router_topk=2,
        moe_router_load_balancing_type="aux_loss",
        moe_token_dispatcher_type="flex",
        moe_permute_fusion=True,
        hidden_size=1024,
        moe_flex_dispatcher_backend="hybridep",
        test_dtype=torch.bfloat16,
        moe_grouped_gemm=True,
        moe_use_legacy_grouped_gemm=False,
        moe_paged_stash=True,
        stash_modules=["expert_fc1", "moe_act", "expert_fc2"],
        moe_expert_rank_capacity_factor=1.5,
        use_transformer_engine_op_fuser=True,
        moe_mlp_glu_interleave_size=32,
        moe_router_padding_for_quantization=True,
        gated_linear_unit=True,
        activation_func=F.silu,
    )
    settings.update(overrides)
    container = MoEModelTestContainer(**settings)

    experts = container.moe_layer.experts
    fused_ok = isinstance(experts, TEGroupedMLP) and experts._is_fused_impl_supported()
    if not fused_ok:
        container.destroy()
        pytest.skip("TEGroupedMLP fused impl not supported")
    return container


@pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available")
class TestPagedStashing:
    """Checks that two identical forward/backward runs with paged stashing agree."""

    def setup_method(self, method):
        """Enable experimental features, remembering the previous flag value."""
        # ``config.ENABLE_EXPERIMENTAL`` is process-wide; save it so teardown can undo
        # the change and later tests are not affected.
        self._saved_enable_experimental = getattr(config, "ENABLE_EXPERIMENTAL", False)
        config.ENABLE_EXPERIMENTAL = True

    def teardown_method(self, method):
        """Restore the experimental flag and destroy model-parallel state."""
        config.ENABLE_EXPERIMENTAL = self._saved_enable_experimental
        Utils.destroy_model_parallel()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.internal
    def test_forward_backward_4_layers(self):
        """Test paged stashing with 4 MoE layers: ref run vs paged run match."""
        container = _build_paged_stash_container()

        seq_length = 1024
        batch_size = 1
        hidden_size = container.config.hidden_size
        hidden_states = torch.randn(
            (seq_length, batch_size, hidden_size), dtype=torch.bfloat16
        )

        # First iteration: capture schedule, capacity, etc.
        paged_stash_reset(True, config=container.config)
        paged_stash_init_chunk_handler(1, 0)
        output_ref, hidden_states_grad_ref, routing_map_ref, tokens_per_expert_ref = (
            _forward_backward_all_layers(container, hidden_states)
        )

        container.zero_grad()

        # Second iteration: run with paged stash.
        paged_stash_reset(True, config=container.config)
        paged_stash_init_chunk_handler(1, 0)
        output, hidden_states_grad, routing_map, tokens_per_expert = _forward_backward_all_layers(
            container, hidden_states
        )

        overflow = check_paged_stash_overflow()
        assert not overflow.any().item(), "paged stash overflow flag was set"

        assert torch.allclose(output, output_ref, atol=1e-4, rtol=1e-4), (
            f"output != output_ref: max diff = {(output - output_ref).abs().max().item()}"
        )
        assert torch.allclose(hidden_states_grad, hidden_states_grad_ref, atol=1e-4, rtol=1e-4), (
            f"hidden_states_grad != ref: max diff = "
            f"{(hidden_states_grad - hidden_states_grad_ref).abs().max().item()}"
        )
        # Routing state is only compared when the dispatcher exposes it.
        if routing_map is not None and tokens_per_expert is not None:
            num_tokens_per_ep_rank = tokens_per_expert.sum().item()
            assert num_tokens_per_ep_rank > 0, (
                f"num_tokens_per_ep_rank={num_tokens_per_ep_rank} (expected > 0)"
            )
            assert routing_map_ref is not None and tokens_per_expert_ref is not None
            tpe_f = tokens_per_expert.float()
            ref_f = tokens_per_expert_ref.float()
            assert torch.allclose(tpe_f, ref_f, atol=1e-4, rtol=1e-4), (
                f"tokens_per_expert != ref: max diff = {(tpe_f - ref_f).abs().max().item()}"
            )


@pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available")
class TestPagedStashingOverBudget:
    """Checks dispatcher over-budget flags and stash overflow against map-derived load."""

    def setup_method(self, method):
        """Enable experimental features, remembering the previous flag value."""
        # ``config.ENABLE_EXPERIMENTAL`` is process-wide; save it so teardown can undo
        # the change and later tests are not affected.
        self._saved_enable_experimental = getattr(config, "ENABLE_EXPERIMENTAL", False)
        config.ENABLE_EXPERIMENTAL = True

    def teardown_method(self, method):
        """Restore the experimental flag and destroy model-parallel state."""
        config.ENABLE_EXPERIMENTAL = self._saved_enable_experimental
        Utils.destroy_model_parallel()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.internal
    def test_overload_factor_and_over_budget(self):
        """Budget matches HybridEP setup_metadata; over_budget matches map-derived load."""
        container = _build_paged_stash_container(moe_router_force_biased=1)

        seq_length = 1024
        batch_size = 1
        topk = container.config.moe_router_topk
        capacity_factor = container.config.moe_expert_rank_capacity_factor
        hidden_states = torch.randn(
            (seq_length, batch_size, container.config.hidden_size), dtype=torch.bfloat16
        )

        # Expected per-rank budget: capacity-scaled token count, rounded up to the
        # quantization alignment size.
        num_tokens = seq_length * batch_size * topk
        pad_multiple = get_align_size_for_quantization(container.config)
        budget = int(num_tokens * capacity_factor)
        budget += -budget % pad_multiple

        paged_stash_reset(True, config=container.config)
        paged_stash_init_chunk_handler(1, 0)
        _forward_backward_all_layers(container, hidden_states)

        overflow = check_paged_stash_overflow()
        num_layers = len(container.moe_layers)
        stash_cuda = container.config.stash_buffer_size_factor_cuda
        stash_cpu = container.config.stash_buffer_size_factor_cpu
        stash_buffer_size = num_tokens * num_layers * (stash_cuda + stash_cpu)

        total_tokens = 0
        for layer_idx, layer in enumerate(container.moe_layers):
            comm = getattr(layer.token_dispatcher, "_comm_manager", None)
            routing_map = getattr(comm, "routing_map", None) if comm is not None else None
            over_budget_tensor = (
                layer.token_dispatcher.check_over_budget()
                if hasattr(layer.token_dispatcher, "check_over_budget")
                else None
            )
            over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False

            assert routing_map is not None, f"layer {layer_idx}: routing_map is None"
            assert routing_map.dim() == 2, f"layer {layer_idx}: expected 2D routing_map"
            assert routing_map.shape[1] == container.config.num_moe_experts, (
                f"layer {layer_idx}: routing_map has {routing_map.shape[1]} experts, "
                f"expected {container.config.num_moe_experts}"
            )
            tokens_per_expert_from_map = _tokens_per_expert_from_routing_map(routing_map, layer)
            tokens_per_expert_from_map_padded = _pad_token_counts_to_align_size(
                tokens_per_expert_from_map, pad_multiple
            )
            tokens_per_ep_rank_from_map = tokens_per_expert_from_map_padded.sum().item()
            total_tokens += tokens_per_ep_rank_from_map

            # Padded map-derived tokens strictly over budget iff dispatcher reports over_budget
            if tokens_per_ep_rank_from_map > budget:
                assert over_budget, (
                    f"layer {layer_idx}: tokens_per_ep_rank_from_map "
                    f"({tokens_per_ep_rank_from_map}) > budget ({budget}), "
                    f"but over_budget flag was not set"
                )
            else:
                assert not over_budget, (
                    f"layer {layer_idx}: tokens_per_ep_rank_from_map "
                    f"({tokens_per_ep_rank_from_map}) <= budget ({budget}), "
                    f"but over_budget flag was set"
                )

        overflow_set = overflow.any().item()
        stash_exceeded = total_tokens > stash_buffer_size
        assert overflow_set == stash_exceeded, (
            f"overflow {overflow_set} should match total_tokens > stash_buffer_size "
            f"({total_tokens} > {stash_buffer_size})"
        )