From 1efb85159acf8a6c3994862a695334f393e8ceb9 Mon Sep 17 00:00:00 2001 From: Qi Zhang Date: Mon, 3 Nov 2025 00:06:51 -0800 Subject: [PATCH 01/57] Add --moe-use-device-initiated-grouped-gemm to allow token_per_expert tensor on GPU --- megatron/core/fp8_utils.py | 2 +- megatron/core/transformer/transformer_config.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/megatron/core/fp8_utils.py b/megatron/core/fp8_utils.py index fa6be91dfbf..0a85ea42e19 100644 --- a/megatron/core/fp8_utils.py +++ b/megatron/core/fp8_utils.py @@ -168,7 +168,7 @@ def _get_custom_recipe(quantizer_factory_python_path: str) -> Union[Fp8Recipe, F def get_fp8_align_size(fp8_recipe: Fp8Recipe) -> int: """Get the alignment size required for fp8 GEMM.""" if fp8_recipe == Fp8Recipe.mxfp8: - return 32 + return 128 else: return 16 diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 7ec5636ab87..8e514806018 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -685,6 +685,10 @@ class TransformerConfig(ModelParallelConfig): GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm). """ + moe_use_device_initiated_grouped_gemm: bool = False + """Use the cutlass grouped gemm kernel, which allows for the token_per_expert tensor on GPU. This can prevent the GPU-CPU synchronization during the grouped gemm.""" + + moe_use_legacy_grouped_gemm: bool = False """Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.""" From 64041fcc9f7d3267d546123d39d31f685f6a276d Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Sun, 16 Nov 2025 19:02:06 -0800 Subject: [PATCH 02/57] Initial change for packed offloading --- megatron/core/full_cuda_graph.py | 10 +- .../common/model_chunk_schedule_plan.py | 7 + megatron/core/models/gpt/gpt_model.py | 16 + .../pipeline_parallel/moe_packed_offload.py | 728 ++++++++++++++++++ megatron/core/pipeline_parallel/schedules.py | 24 +- megatron/core/transformer/moe/experts.py | 44 +- .../core/transformer/moe/token_dispatcher.py | 32 +- .../core/transformer/transformer_block.py | 7 + .../core/transformer/transformer_config.py | 15 + megatron/training/arguments.py | 8 + 10 files changed, 877 insertions(+), 14 deletions(-) create mode 100644 megatron/core/pipeline_parallel/moe_packed_offload.py diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 7c11195f33b..0c00b1a45a8 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -7,6 +7,9 @@ import torch from megatron.core.tensor_parallel.random import get_all_rng_states +from megatron.core.pipeline_parallel.moe_packed_offload import ( + packed_moe_expert_offloading_reset, +) logger = logging.getLogger(__name__) @@ -98,10 +101,11 @@ class FullCudaGraphWrapper: cuda_graph = {'training': None, 'validation': None} result = {'training': None, 'validation': None} - def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1): + def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, packed_moe_expert_offloading=False): self.forward_backward_func = forward_backward_func self.static_loader = StaticBufferLoader() self.cuda_graph_warmup_steps = cuda_graph_warmup_steps + self.packed_moe_expert_offloading = packed_moe_expert_offloading def data_read(self, data_iterator, model, training, num_microbatches): """Read all microbatch inputs from Dataloader and copy to static 
buffers.""" @@ -161,7 +165,7 @@ def __call__(self, *args, **kwargs): training_str = 'training' if training else 'validation' curr_iteration = self.curr_iter(training_str) if curr_iteration == self.cuda_graph_warmup_steps: - logger.info(f'Capture CUDA graph for {training_str}!!!') + print(f'Capture CUDA graph for {training_str}!!!') torch.distributed.barrier() assert FullCudaGraphWrapper.cuda_graph[training_str] is None FullCudaGraphWrapper.cuda_graph[training_str] = torch.cuda.CUDAGraph() @@ -184,6 +188,8 @@ def __call__(self, *args, **kwargs): if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: + if self.packed_moe_expert_offloading and training_str == 'training': + packed_moe_expert_offloading_reset() FullCudaGraphWrapper.cuda_graph[training_str].replay() self.next_iter(training_str) diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index 2e26e5fd1d3..eef232bb30b 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -8,6 +8,9 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp8_utils import get_fp8_context +from megatron.core.pipeline_parallel.moe_packed_offload import ( + packed_moe_expert_offloading_set_last_layer, +) from megatron.core.pipeline_parallel.utils import ( AbstractSchedulePlan, NoopScheduleNode, @@ -479,6 +482,8 @@ def run( f_layer = f_schedule_plan.get_layer(i) b_layer = b_schedule_plan.pop_layer() torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b") + if f_layer.layer.config.packed_moe_expert_offloading: + packed_moe_expert_offloading_set_last_layer(i == f_num_layers - 1) f_input, b_grad = TransformerLayerSchedulePlan.run( f_layer, b_layer, @@ -505,6 +510,8 @@ def run( for i in range(overlapped_layers, f_num_layers): f_layer = f_schedule_plan.get_layer(i) torch.cuda.nvtx.range_push(f"layer_{i}f") + if f_layer.layer.config.packed_moe_expert_offloading: + packed_moe_expert_offloading_set_last_layer(i == f_num_layers - 1) f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input) torch.cuda.nvtx.range_pop() diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 27b62f91c34..dfbd69f7e19 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -21,6 +21,9 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) +from megatron.core.pipeline_parallel.moe_packed_offload import ( + packed_moe_expert_offloading_init_chunk_handler, +) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.quantization.utils import get_quant_config_or_none from megatron.core.tensor_parallel import gather_from_sequence_parallel_region @@ -473,6 +476,13 @@ def preprocess_for_fine_grained_offloading(self): off_interface.mark_not_offloadable(param) self.disable_param_offloading = False + def preprocess_for_packed_moe_expert_offloading(self): + """Preprocess for packed moe expert offloading.""" + return packed_moe_expert_offloading_init_chunk_handler( + vp_size=self.config.virtual_pipeline_model_parallel_size, + vp_stage=self.vp_stage, + ) + def forward( self, input_ids: Tensor, @@ -505,6 +515,9 @@ def forward( if self.config.fine_grained_activation_offloading: 
self.preprocess_for_fine_grained_offloading() + if self.config.packed_moe_expert_offloading: + self.preprocess_for_packed_moe_expert_offloading() + inference_context = deprecate_inference_params(inference_context, inference_params) preproc_output = self._preprocess( @@ -527,6 +540,7 @@ def forward( rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None + # Run decoder. hidden_states = self.decoder( hidden_states=decoder_input, @@ -745,6 +759,8 @@ def build_schedule_plan( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() + if self.config.packed_moe_expert_offloading: + self.preprocess_for_packed_moe_expert_offloading() from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py new file mode 100644 index 00000000000..67fc4d521dc --- /dev/null +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -0,0 +1,728 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import warnings +from collections import deque +from contextlib import nullcontext +from typing import Any +import os +import torch +try: + import triton + import triton.language as tl + HAVE_TRITON = True +except ImportError: + HAVE_TRITON = False + +# Packed Moe Expert Offload implementation for pipeline parallelism +DEBUG = False +DEBUG_RANK = [0] +def debug_print(message): + """Print debug message for a specific rank when DEBUG is enabled.""" + # pylint: disable=bad-builtin + if not DEBUG: + return + assert torch.distributed.is_initialized() + if torch.distributed.get_rank() in DEBUG_RANK: + print(f'{torch.distributed.get_rank()}: {message}') + +GLOBAL_BLOCK_SIZE = 2048 +@triton.jit +def _stash_copy_kernel( + src_ptr, + dst_ptr, + size_ptr, + alloc_offset_ptr, + free_offset_ptr, + capacity_ptr, + overflow_ptr, + BLOCK_SIZE: tl.constexpr, + num_iterations: tl.constexpr, +): + """Triton kernel to copy tensor data to stash buffer. + + Each block can handle multiple chunks of data (num_iterations) to limit total blocks. + Ignores out-of-bound writes if offset + size exceeds capacity. 
+ + Args: + src_ptr: Pointer to source tensor (flattened) + dst_ptr: Pointer to destination buffer (stash_buffer) + size_ptr: Pointer to scalar tensor containing the size to copy + offset_original_ptr: Pointer to GPU tensor containing original offset (read-only) + over_capacity_ptr: Pointer to counter tensor (incremented when over capacity) + capacity: Total capacity of the buffer + BLOCK_SIZE: Block size for Triton kernel + num_iterations: Number of iterations each block should handle + """ + # Get the program ID + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # Load the size value from GPU tensor + size = tl.load(size_ptr) + + # Load original offset from GPU tensor (for position calculations) + alloc_offset = tl.load(alloc_offset_ptr) + free_offset = tl.load(free_offset_ptr) + capacity = tl.load(capacity_ptr) + + # Only the first thread checks capacity + # Do this BEFORE the loop so it always happens + overflow = False + # Check if over capacity and increment counter + avail_space = free_offset - alloc_offset + if avail_space < 0: + avail_space = -avail_space + else: + avail_space = capacity - avail_space + if avail_space < size: + overflow = True + if pid == 0 and overflow: + tl.store(overflow_ptr, 1) + + #if pid == 1: + # tl.device_print("free_offset: ", free_offset) + if overflow: + return + + # Each block handles num_iterations chunks of BLOCK_SIZE elements + # Use while loop with early exit condition in the loop test + iteration = 0 + block_start = (pid * num_iterations + iteration) * BLOCK_SIZE + while iteration < num_iterations and block_start < size: + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Create mask for valid elements within source size + src_mask = offsets < size + + # Create mask for valid destination indices (within buffer capacity) + dst_indices = free_offset + offsets + dst_mask = dst_indices >= capacity + dst_indices = tl.where(dst_mask, dst_indices - capacity, dst_indices) + + # Load from source + src_data = tl.load(src_ptr + offsets, mask=src_mask, other=0.0) + + # Store to destination (ignores out-of-bound writes) + tl.store(dst_ptr + dst_indices, src_data, mask=src_mask) + + # Move to next iteration + iteration += 1 + block_start = (pid * num_iterations + iteration) * BLOCK_SIZE + + # Check if over capacity and increment counter + size_page_aligned = tl.cdiv(size, BLOCK_SIZE) * BLOCK_SIZE + free_offset = free_offset + size_page_aligned + if free_offset > capacity: + free_offset -= capacity + if pid == 0: + tl.store(free_offset_ptr, free_offset) + +@triton.jit +def _stash_pop_kernel( + src_ptr, + dst_ptr, + size_ptr, + tensor_offset_ptr, + alloc_offset_ptr, + free_offset_ptr, + capacity_ptr, + BLOCK_SIZE: tl.constexpr, + num_iterations: tl.constexpr, +): + """Triton kernel to copy tensor data from stash buffer. + + Each block can handle multiple chunks of data (num_iterations) to limit total blocks. + Ignores out-of-bound writes if offset + size exceeds capacity. 
+ + Args: + src_ptr: Pointer to source tensor (flattened) + dst_ptr: Pointer to destination buffer (stash_buffer) + size_ptr: Pointer to scalar tensor containing the size to copy + offset_original_ptr: Pointer to GPU tensor containing original offset (read-only) + over_capacity_ptr: Pointer to counter tensor (incremented when over capacity) + capacity: Total capacity of the buffer + BLOCK_SIZE: Block size for Triton kernel + num_iterations: Number of iterations each block should handle + """ + # Get the program ID + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # Load the size value from GPU tensor + size = tl.load(size_ptr) + + # Load original offset from GPU tensor (for position calculations) + tensor_offset = tl.load(tensor_offset_ptr) + alloc_offset = tl.load(alloc_offset_ptr) + free_offset = tl.load(free_offset_ptr) + capacity = tl.load(capacity_ptr) + + # Each block handles num_iterations chunks of BLOCK_SIZE elements + # Use while loop with early exit condition in the loop test + iteration = 0 + block_start = (pid * num_iterations + iteration) * BLOCK_SIZE + while iteration < num_iterations and block_start < size: + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Create mask for valid elements within source size + dst_mask = offsets < size + + # Create mask for valid destination indices (within buffer capacity) + src_indices = tensor_offset + offsets + src_mask = src_indices >= capacity + src_indices = tl.where(src_mask, src_indices - capacity, src_indices) + + # Load from source + src_data = tl.load(src_ptr + src_indices, mask=dst_mask, other=0.0) + + # Store to destination (ignores out-of-bound writes) + tl.store(dst_ptr + offsets, src_data, mask=dst_mask) + + # Move to next iteration + iteration += 1 + block_start = (pid * num_iterations + iteration) * BLOCK_SIZE + + # Check if over capacity and increment counter + size_page_aligned = tl.cdiv(size, BLOCK_SIZE) * BLOCK_SIZE + tensor_offset = tensor_offset + size_page_aligned + if tensor_offset > capacity: + tensor_offset -= capacity + if pid == 0: + mask = tensor_offset > alloc_offset + tl.store(alloc_offset_ptr, tensor_offset, mask=mask) + +class StashBuffer: + """ + A class to represent a stash buffer. + """ + + def __init__(self, size, device, overflow, dtype): + self.buffer = torch.empty(size, dtype=dtype, device=device) + self.overflow = overflow # GPU flag + self.device = device + self.free_offset = torch.zeros(1, dtype=torch.int64, device=device) # start offset of free space + self.alloc_offset = torch.zeros(1, dtype=torch.int64, device=device) # start offset of allocations + self.capacity = torch.zeros(1, dtype=torch.int64, device=device) + self.capacity.fill_(size) + self.dtype = dtype + def reset(self): + """Reset the stash buffer.""" + #assert self.alloc_offset.item() == self.free_offset.item(), f"alloc_offset {self.alloc_offset.item()} != free_offset {self.free_offset.item()}" + #print + self.free_offset.zero_() + self.alloc_offset.zero_() + + def __repr__(self): + return f"StashBuffer(capacity={self.capacity}, device={self.device})" + + +class PackedTensor: + """ + A class to represent a packed tensor. 
+ """ + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None): + self._tensor = tensor.clone() + self._original_tensor = None + assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1, f"num_tokens_tensor {num_tokens_tensor} is not a scalar tensor" + self.num_tokens_tensor = num_tokens_tensor.clone() + self.vp_stage = vp_stage + + # Original tensor information + self.original_shape = list(tensor.shape) + self.num_elements = tensor.numel() + self.element_size = tensor.element_size() + self.hidden_size = self.num_elements // self.original_shape[0] + self.dtype = tensor.dtype + self.device = tensor.device + + self.stash_buffer_offset = None + + def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=128): + """Offload the packed tensor.""" + #self._tensor.record_stream(torch.cuda.current_stream()) + # TODO: Call offload function to offload the tensor + # After offload stream joins main stream, the tensor is no longer needed and can be freed + + #pass + + """Copy tensor content into stash_buffer starting at current offset using Triton kernel. + + Out-of-bound writes are silently ignored by the kernel. + Increments self.over_capacity counter if capacity was exceeded. + + Args: + tensor (torch.Tensor): The tensor to stash. Will be flattened before copying. + size (torch.Tensor): GPU tensor containing the number of bytes to copy. + max_blocks (int): Maximum number of blocks to launch. Defaults to 2048. + + Returns: + offset: GPU tensor indicating the offset where the tensor was stashed. + + Raises: + RuntimeError: If Triton is not available. + """ + if not HAVE_TRITON: + raise RuntimeError("Triton is required for PackedTensor.offload_to_stash(). Please install triton.") + + self._tensor = self._tensor.contiguous() + if self.num_tokens_tensor.dim() == 0: + self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) + num_elements_tensor = self.num_tokens_tensor.mul(self.hidden_size) + # Flatten the tensor to get total number of elements + flat_tensor = self._tensor.flatten() + + # Determine grid size with cap on max blocks + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + max_size = flat_tensor.numel() + total_blocks_needed = triton.cdiv(max_size, BLOCK_SIZE) + + # Cap the number of blocks and calculate iterations per block + num_blocks = min(total_blocks_needed, max_blocks) + num_iterations = triton.cdiv(total_blocks_needed, num_blocks) + if DEBUG: + debug_print (f"offload_to_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} num_elements {num_elements_tensor.item()} max_blocks {max_blocks} total_blocks_needed {total_blocks_needed} num_blocks {num_blocks} num_iterations {num_iterations} oveflow {stash_buffer.overflow.item()}") + # + grid = (num_blocks,) + self.stash_buffer_offset = stash_buffer.free_offset.clone() + + # Launch Triton kernel to copy data + # self.offload_stream.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(self.offload_stream): + + # TODO: make this async. 
Something unexpected with TE on deallocate the tensor + _stash_copy_kernel[grid]( + flat_tensor, + stash_buffer.buffer, + num_elements_tensor, + stash_buffer.alloc_offset, # Read-only: Write boundary + stash_buffer.free_offset, # Read+Write: Start offset for offload + stash_buffer.capacity, # Read-only: Capacity of the buffer + stash_buffer.overflow, # Read+Write: Over capacity flag updated by kernel + BLOCK_SIZE=BLOCK_SIZE, + num_iterations=num_iterations, + ) + if DEBUG: + self._original_tensor = self._tensor.clone() + self._tensor = None + if DEBUG: + debug_print (f"After offload_to_stash {stash_buffer.free_offset.item()} overflow {stash_buffer.overflow.item()} capacity {stash_buffer.capacity.item()}") + + + def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=128): + """Reload the packed tensor from the stash.""" + if not HAVE_TRITON: + raise RuntimeError("Triton is required for PackedTensor.reload_from_stash(). Please install triton.") + self._tensor = torch.zeros(self.original_shape, dtype=self.dtype, device=self.device) + flat_tensor = self._tensor.flatten() + + num_elements_tensor = self.num_tokens_tensor.mul(self.hidden_size) + + # Determine grid size with cap on max blocks + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + max_size = self.num_elements + total_blocks_needed = triton.cdiv(max_size, BLOCK_SIZE) + + # Cap the number of blocks and calculate iterations per block + num_blocks = min(total_blocks_needed, max_blocks) + num_iterations = triton.cdiv(total_blocks_needed, num_blocks) + + if DEBUG: + debug_print (f"reload_from_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} num_elements {num_elements_tensor.item()} max_blocks {max_blocks} total_blocks_needed {total_blocks_needed} num_blocks {num_blocks} num_iterations {num_iterations}") + # + grid = (num_blocks,) + + + # Launch Triton kernel to copy data + # self.offload_stream.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(self.offload_stream): + + # TODO: make this async. Something unexpected with TE on deallocate the tensor + _stash_pop_kernel[grid]( + stash_buffer.buffer, + flat_tensor, + num_elements_tensor, + self.stash_buffer_offset, # Read-only: Start offset for reload + stash_buffer.alloc_offset, # Read+write: Free stash buffer for model chunk + stash_buffer.free_offset, # Read: Start offset for offload + stash_buffer.capacity, # Read-only: Capacity of the buffer + BLOCK_SIZE=BLOCK_SIZE, + num_iterations=num_iterations, + ) + #torch.cuda.synchronize() + if DEBUG: + debug_print (f"After reload_from_stash alloc_offset {stash_buffer.alloc_offset.item()} free_offset {stash_buffer.free_offset.item()} capacity {stash_buffer.capacity.item()}") + def __repr__(self): + return f"PackedTensor(original_shape={self.original_shape}, num_tokens={self.num_tokens_tensor.item()}, vp_stage={self.vp_stage})" + +class PP_ScheduleFunction(torch.autograd.Function): + """ + This function is used to update the pp schedule. 
+ """ + + @staticmethod + def forward(ctx, tensor, offload_manager): # after forward + # pylint: disable=missing-function-docstring + + ctx.offload_manager = offload_manager + ctx.vp_stage = offload_manager.current_vp_stage + if ctx.vp_stage is None: + ctx.vp_stage = 0 + ctx.layer_no, ctx.microbatch_no = offload_manager.update_pp_schedule(ctx.vp_stage+1) + current_stream = torch.cuda.current_stream() + if offload_manager._pack_stream_status == 'offloading': + current_stream.wait_stream(offload_manager.pack_stream) + offload_manager._pack_stream_status = 'idle' + + if offload_manager.status == 'captured': + current_schedule_layer = (ctx.vp_stage+1)*100 + ctx.layer_no*10 + ctx.microbatch_no + next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index+1] + if current_schedule_layer != -next_schedule_layer: + # Start offload for current layer + ctx.offload_manager.offload_packed_tensors(current_schedule_layer) + if next_schedule_layer < 0: + # reload for next backward layer + ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) + else: + ctx.offload_manager.remove_packed_tensor_from_offload() + + ctx.offload_manager.current_schedule_index += 1 + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, *grad_output): # before backward + # pylint: disable=missing-function-docstring + #debug_print(f"PP_ScheduleFunction vp_stage {ctx.vp_stage} before backward") + if ctx.vp_stage is not None: + ctx.offload_manager.update_pp_schedule(-(ctx.vp_stage+1), -ctx.layer_no, -ctx.microbatch_no) + ctx.offload_manager.current_schedule_index += 1 + current_stream = torch.cuda.current_stream() + if ctx.offload_manager._unpack_stream_status == 'reloading': + current_stream.wait_stream(ctx.offload_manager.unpack_stream) + ctx.offload_manager._unpack_stream_status = 'idle' + + if ctx.offload_manager.status == 'captured' and ctx.offload_manager.current_schedule_index < len(ctx.offload_manager._pp_schedule): + next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index] + if next_schedule_layer < 0: + ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) + + return grad_output + (None, None) + +class PackedOffloadManager: + """ + Singleton manager for coordinating activation offloading across pipeline stages. + Manages chunk handlers, synchronizes GPU-GPU transfers, + and handles virtual pipeline parallelism. 
+ """ + + OFFLOAD_MGR = None + + @classmethod + def get_instance(cls): + """Get the singleton instance of PipelineOffloadManager.""" + if cls.OFFLOAD_MGR is None: + cls.OFFLOAD_MGR = PackedOffloadManager() + return cls.OFFLOAD_MGR + + def __init__(self): + """Initialize the manager with queues and dedicated CUDA streams.""" + # allocate streams and events for synchronization + self._pack_stream = torch.cuda.Stream() + self._unpack_stream = torch.cuda.Stream() + self._pack_stream_status = 'idle' # idle, offloading + self._unpack_stream_status = 'idle' # idle, reloading + self.packed_tensors_to_offload = [] + self.packed_tensors_to_reload = {} + + self.iteration = 0 + self._current_layer_name = None + self.vp_size = None + self.current_vp_stage = None + self._last_layer = False + self.status = 'begin' # begin, capture, captured + self._pp_schedule = None # If element is +ve, it denotes forward pass of vp stage, if -ve, it denotes backward pass of vp stage + self.current_layer = None + self.current_microbatch = None + self.current_schedule_index = None + + self.page_size = GLOBAL_BLOCK_SIZE + self.max_pages_per_vp_stage = None + self.temp_pages_per_vp_stage = None + self.num_tokens_tensor = None + self.max_num_tokens = None + self.stash_buffers = None + self.overflow = None + self.device = None + + @property + def pack_stream(self): + """Get the pack stream.""" + return self._pack_stream + + @property + def unpack_stream(self): + """Get the unpack stream.""" + return self._unpack_stream + + def set_current_layer_name(self, name): + """Set the current layer name.""" + self._current_layer_name = name + + def add_packed_tensor_to_offload(self, packed_tensor): + """Add a packed tensor to the offload list.""" + if self.status == 'captured': + self.packed_tensors_to_offload.append(packed_tensor) + else: + pass + + def remove_packed_tensor_from_offload(self): + """Remove all packed tensors from the offload list.""" + if self.status == 'captured': + while len(self.packed_tensors_to_offload) > 0: + packed_tensor = self.packed_tensors_to_offload.pop(0) + assert len(self.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {self.packed_tensors_to_offload}" + else: + pass + + def offload_packed_tensors(self, pp_schedule_layer): + """Offload the packed tensors.""" + current_stream = torch.cuda.current_stream() + self.pack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.pack_stream): + if self.status == 'captured': + self._pack_stream_status = 'offloading' + #assert self.packed_tensors_to_reload + #for packed_tensor in self.packed_tensors_to_offload: + # packed_tensor.offload_to_stash(self.stash_buffers[packed_tensor.vp_stage]) + debug_print(f"offload_packed_tensors {len(self.packed_tensors_to_offload)}") + if pp_schedule_layer not in self.packed_tensors_to_reload: + self.packed_tensors_to_reload[pp_schedule_layer] = [] + assert len(self.packed_tensors_to_reload[pp_schedule_layer]) == 0, f"packed_tensors_to_reload {pp_schedule_layer} is not empty {self.packed_tensors_to_reload[pp_schedule_layer]}" + + while len(self.packed_tensors_to_offload) > 0: + packed_tensor = self.packed_tensors_to_offload.pop(0) + packed_tensor.offload_to_stash(self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype]) + self.packed_tensors_to_reload[pp_schedule_layer].append(packed_tensor) + else: + pass + assert len(self.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {self.packed_tensors_to_offload}" + + def reload_packed_tensors(self, pp_schedule_layer): + 
"""Reload the packed tensors.""" + current_stream = torch.cuda.current_stream() + self.unpack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.unpack_stream): + if self.status == 'captured': + self._unpack_stream_status = 'reloading' + count = 0 + for item in self.packed_tensors_to_reload: + if len(self.packed_tensors_to_reload[item]) > 0: + count += 1 + + debug_print(f"reload_packed_tensors {count}") + while len(self.packed_tensors_to_reload[pp_schedule_layer]) > 0: + packed_tensor = self.packed_tensors_to_reload[pp_schedule_layer].pop(0) + packed_tensor.reload_from_stash(self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype]) + else: + pass + assert len(self.packed_tensors_to_reload[pp_schedule_layer]) == 0, f"packed_tensors_to_reload {pp_schedule_layer} is not empty {self.packed_tensors_to_reload[pp_schedule_layer]}" + + + def allocate_offload_pages(self, stash_buffer_size_factor=1.10): + """Allocate offload pages for each vp stage.""" + self.stash_buffers = [] + self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) + for vp_stage in range(self.vp_size): + self.stash_buffers.append({}) + for dtype in self.max_pages_per_vp_stage[vp_stage]: + self.max_pages_per_vp_stage[vp_stage][dtype] = int(self.max_pages_per_vp_stage[vp_stage][dtype] * stash_buffer_size_factor) + self.stash_buffers[vp_stage][dtype] = StashBuffer(self.max_pages_per_vp_stage[vp_stage][dtype]*GLOBAL_BLOCK_SIZE, self.device, self.overflow, dtype) + debug_print(f'allocated stash buffer {vp_stage} {dtype} {self.stash_buffers[vp_stage][dtype]}') + + def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): + """Update the pp schedule.""" + if self._pp_schedule is None: + self._pp_schedule = [] + # current layer and microbatch for each vp stage for forward pass + self.current_layer = [1 for _ in range(self.vp_size)] + self.current_microbatch = [1 for _ in range(self.vp_size)] + + assert self.vp_size is not None + if layer_no is None: + # forward pass + layer_no = self.current_layer[vp_stage-1] + self.current_layer[vp_stage-1] += 1 + microbatch_no = self.current_microbatch[vp_stage-1] + if self._last_layer: + self.current_layer[vp_stage-1] = 1 + self.current_microbatch[vp_stage-1] += 1 + + if self.status == 'capture': + self._pp_schedule.append(vp_stage*100 + layer_no*10 + microbatch_no) + num_tokens = self.num_tokens_tensor.item() + + #debug_print(f"------{self.current_schedule_index} len PP_Schedule {len(self._pp_schedule)}") + #debug_print(f" {self.status} {self.current_schedule_index} {self._pp_schedule[self.current_schedule_index]} {vp_stage*100 + layer_no*10 + microbatch_no}") + assert self._pp_schedule[self.current_schedule_index] == vp_stage*100 + layer_no*10 + microbatch_no, f"schedule {self._pp_schedule[self.current_schedule_index]} != {vp_stage*100 + layer_no*10 + microbatch_no}" + + + return layer_no, microbatch_no + #self._pp_schedule.append(vp_size) + #self._pp_schedule.append(vp_stage) + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """ + Hook called when autograd saves a tensor for backward pass. + Returns a tag to identify the tensor later. 
+ """ + + + if self.max_num_tokens is None or tensor.size(0) != self.max_num_tokens: + return tensor.detach() + if tensor.size(1) == 7168 and DEBUG: + return tensor.detach() + + if self.status == 'capture': + + self.num_tokens = self.num_tokens_tensor.item() + num_elements = tensor.numel() * self.num_tokens // self.max_num_tokens + num_pages = (num_elements + self.page_size - 1) // self.page_size + + if tensor.dtype not in self.temp_pages_per_vp_stage[self.current_vp_stage]: + self.temp_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] = 0 + self.max_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] = 0 + self.temp_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] += num_pages + self.max_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] = max(self.max_pages_per_vp_stage[self.current_vp_stage][tensor.dtype], self.temp_pages_per_vp_stage[self.current_vp_stage][tensor.dtype]) + + packed_tensor = PackedTensor(tensor.detach(), num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage) + if self.status == 'captured': + if DEBUG: + debug_print(f"on_save_for_backward {packed_tensor._tensor.shape} {packed_tensor.num_tokens_tensor.item()}") + self.add_packed_tensor_to_offload(packed_tensor) + #debug_print(f"------{self._pp_schedule[self.current_schedule_index] if self.status == 'captured' else ""} {self._current_layer_name} on_save_for_backward {hex(tensor.data_ptr())}-{tensor.shape}-{tensor.dtype}") + return packed_tensor + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """ + Hook called when autograd retrieves a saved tensor during backward pass. + Returns the actual tensor (potentially reloading from CPU). + """ + # debug_print(f"------{self._pp_schedule[self.current_schedule_index-1] if self.status == 'captured' else ""} {self._current_layer_name} on_get_saved_tensor {saved_state.shape}-{saved_state.dtype}") + if isinstance(saved_state, PackedTensor): + #debug_print (f'on_get_saved_tensor {type(saved_state)}') + if self.status == 'capture': + num_tokens = saved_state.num_tokens_tensor.item() + num_elements = saved_state.num_elements * num_tokens // self.max_num_tokens + num_pages = (num_elements + self.page_size - 1) // self.page_size + self.temp_pages_per_vp_stage[saved_state.vp_stage][saved_state.dtype] -= num_pages + + + if saved_state._original_tensor is not None: + if self.status == 'captured': + debug_print(f"on_get_saved_tensor {saved_state._original_tensor.shape} {saved_state.num_tokens_tensor.item()}") + if saved_state._tensor is not None: + original_flat = saved_state._original_tensor.flatten() + tensor_flat = saved_state._tensor.flatten() + num_elements = saved_state.num_tokens_tensor.item() * saved_state.hidden_size + original_flat_sub = original_flat[:num_elements] + tensor_flat_sub = tensor_flat[:num_elements] + equal = torch.equal(original_flat_sub, tensor_flat_sub) + num_not_equal = (original_flat_sub != tensor_flat_sub).sum() + idx_not_equal = (original_flat_sub != tensor_flat_sub).nonzero() + debug_print(f"on_get_saved_tensor original: {saved_state._original_tensor.shape} tensor: {saved_state._tensor.shape} equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}") + #debug_print(f"on_get_saved_tensor equal tensors {torch.equal(saved_state._original_tensor, saved_state._tensor)} original_tensor {original_flat[-100:]} tensor {tensor_flat[-100:]}") + return saved_state._original_tensor + 
else: + if self.status == 'captured' and DEBUG: + debug_print(f"on_get_saved_tensor {saved_state._tensor.shape} {saved_state.num_tokens_tensor.item()}") + return saved_state._tensor + + return saved_state + +def packed_moe_expert_offloading_group_start(tensor, name=None): + """Mark the start of a layer group and prepare for offload/reload.""" + rank = torch.distributed.get_rank() + #debug_print(f'{rank}: packed_moe_expert_offloading_group_start tensor {tensor.shape}-{tensor.dtype} name {name}') + return tensor + #return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) + +def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num_tokens_tensor=None): + """Get the fine-grained offload context""" + #debug_print(f'get_packed_moe_expert_offloading_context name {name}') + offload_manager = PackedOffloadManager.get_instance() + offload_manager.max_num_tokens = max_num_tokens + assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) + + offload_manager.num_tokens_tensor = num_tokens_tensor + offload_manager.set_current_layer_name(name) if name is not None else None + pack_unpack_context = torch.autograd.graph.saved_tensors_hooks(offload_manager.on_save_for_backward, offload_manager.on_get_saved_tensor) + return pack_unpack_context + +def packed_moe_expert_offloading_group_commit(tensor, name=None): + """Mark the end of a layer group and prepare for offload/reload.""" + rank = torch.distributed.get_rank() + #debug_print(f'{rank}: packed_moe_expert_offloading_group_commit tensor {tensor.shape}-{tensor.dtype} name {name}') + offload_manager = PackedOffloadManager.get_instance() + offload_manager.device = tensor.device + + return PP_ScheduleFunction.apply(tensor, offload_manager) + +def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): + """Initialize the chunk handler, called at the start of a microbatch forward pass.""" + #debug_print(f'packed_moe_expert_offloading_init_chunk_handler vp_size {vp_size} vp_stage {vp_stage}') + offload_manager = PackedOffloadManager.get_instance() + offload_manager.current_vp_stage = vp_stage if vp_stage is not None else 0 + if vp_size is not None: + offload_manager.vp_size = vp_size + else: + offload_manager.vp_size = 1 + if offload_manager.max_pages_per_vp_stage is None: + offload_manager.max_pages_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] + offload_manager.temp_pages_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] + +def packed_moe_expert_offloading_set_last_layer(is_last_layer=False): + """Set the last layer flag.""" + #PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) + #debug_print(f'packed_moe_expert_offloading_set_last_layer is_last_layer {is_last_layer}') + offload_manager = PackedOffloadManager.get_instance() + offload_manager._last_layer = is_last_layer + +def packed_moe_expert_offloading_reset(): + """Reset the chunk handler, called at the start of a training iteration.""" + offload_manager = PackedOffloadManager.get_instance() + offload_manager.iteration += 1 + # current layer and microbatch for each vp stage for forward pass + offload_manager.current_schedule_index = 0 + if os.getenv('MEM_PROFILE', '0') == '1': + if offload_manager.iteration == 1 and torch.distributed.get_rank() == 0: + torch.cuda.memory._record_memory_history() + print(f'packed_moe_expert_offloading_reset record_memory_history') + if offload_manager.iteration == 10 and torch.distributed.get_rank() == 0: + 
torch.cuda.memory._dump_snapshot("packed_offloading.pkl") + torch.cuda.memory._record_memory_history(enabled=None) + print(f'packed_moe_expert_offloading_reset dump_snapshot') + + if offload_manager.status == 'begin': + offload_manager.status = 'capture' + elif offload_manager.status == 'capture': + offload_manager.status = 'captured' + offload_manager.allocate_offload_pages(stash_buffer_size_factor=1.10) # 10% extra to account for overhead + debug_print(f'packed_moe_expert_offloading_reset captured schedule: {offload_manager._pp_schedule}') + debug_print(f'packed_moe_expert_offloading_reset max_pages_per_vp_stage: {offload_manager.max_pages_per_vp_stage}') + elif offload_manager.status == 'captured': + pass + else: + debug_print(f'packed_moe_expert_offloading_reset unknown status: {offload_manager.status}') + + if offload_manager.status == 'captured': + offload_manager.overflow.zero_() + for vp_buffers in offload_manager.stash_buffers: + for dtype in vp_buffers.keys(): + vp_buffers[dtype].reset() + + offload_manager.current_layer = [1 for _ in range(offload_manager.vp_size)] + offload_manager.current_microbatch = [1 for _ in range(offload_manager.vp_size)] + assert len(offload_manager.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {offload_manager.packed_tensors_to_offload}" + \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index ed3794208f0..f0632cb6504 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -12,6 +12,9 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) +from megatron.core.pipeline_parallel.moe_packed_offload import ( + packed_moe_expert_offloading_reset, +) from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, @@ -590,6 +593,9 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + if not forward_only and config.packed_moe_expert_offloading: + packed_moe_expert_offloading_reset() + no_sync_func = config.no_sync_func if no_sync_func is None: no_sync_func = contextlib.nullcontext @@ -1049,6 +1055,9 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" + if not forward_only and config.packed_moe_expert_offloading: + packed_moe_expert_offloading_reset() + if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -1545,10 +1554,12 @@ def forward_backward_helper_wrapper( send_next_wait_handle = None send_prev_wait_handle = None recv_next_wait_handles = [] - + model_chunk_ids = {0: [1,3], 1:[2,4]} + print (f'{torch.distributed.get_rank()}: forward_backward_pipelining_with_interleaving num_warmup_microbatches {num_warmup_microbatches} num_microbatches_1f1b {num_microbatches_remaining} total_num_microbatches {total_num_microbatches}') for k in range(num_warmup_microbatches): cur_model_chunk_id = get_model_chunk_id(k, forward=True) - + if torch.distributed.get_rank() in [0, 2]: + print(f'{pipeline_parallel_rank}: +++++ warmup iteration {k}, fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') if 
config.overlap_p2p_comm_warmup_flush: if ( not ( @@ -1705,6 +1716,8 @@ def forward_backward_helper_wrapper( # Forward pass. forward_k = k + num_warmup_microbatches + + # Decide to checkpoint all layers' activations of the current micro-batch. if max_outstanding_backprops is not None: checkpoint_activations_microbatch = ( @@ -1718,6 +1731,8 @@ def forward_backward_helper_wrapper( if config.overlap_p2p_comm: backward_k = k + if torch.distributed.get_rank() in [0, 2]: + print(f'{pipeline_parallel_rank}: +++++ steady iteration forward_k {forward_k} fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}, backward_k {backward_k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][get_model_chunk_id(backward_k, forward=False)]}') # Sync forward recv def pp_pre_forward(vp_stage=None): @@ -1936,6 +1951,8 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): ) for k in range(num_microbatches_remaining, total_num_microbatches): cur_model_chunk_id = get_model_chunk_id(k, forward=False) + if torch.distributed.get_rank() in [0, 2]: + print(f'{pipeline_parallel_rank}: cooldown iteration k {k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') if ( not (_is_vp_last_stage(vp_stage=cur_model_chunk_id) and is_pp_last_stage(pp_group)) and k != 0 @@ -2232,6 +2249,9 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + if not forward_only and config.packed_moe_expert_offloading: + packed_moe_expert_offloading_reset() + # Disable async grad reductions no_sync_func = config.no_sync_func if no_sync_func is None: diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 8168c8ab611..fe40bb83f07 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -6,6 +6,7 @@ from collections.abc import Callable from copy import deepcopy from dataclasses import dataclass +from contextlib import nullcontext from functools import partial from itertools import chain from math import ceil @@ -32,6 +33,12 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) +from megatron.core.pipeline_parallel.moe_packed_offload import ( + packed_moe_expert_offloading_group_start, + get_packed_moe_expert_offloading_context, + packed_moe_expert_offloading_reset, + packed_moe_expert_offloading_group_commit, +) from megatron.core.tensor_parallel.layers import ( _initialize_affine_weight_cpu, _initialize_affine_weight_gpu, @@ -698,6 +705,12 @@ def __init__( and "moe_act" in self.config.offload_modules ) + self.packed_offload_expert_fc1 = self.config.packed_moe_expert_offloading and "expert_fc1" in self.config.offload_modules + self.packed_offload_moe_act = self.config.packed_moe_expert_offloading and "moe_act" in self.config.offload_modules + self.packed_offload_expert_fc2 = self.config.packed_moe_expert_offloading and "expert_fc2" in self.config.offload_modules + + if torch.distributed.get_rank() == 0: + print(f'packed_offload_expert_fc1 {self.packed_offload_expert_fc1}, packed_offload_moe_act {self.packed_offload_moe_act}, packed_offload_expert_fc2 {self.packed_offload_expert_fc2}') self.activation_recompute = ( self.config.recompute_granularity == 'selective' and "moe_act" in self.config.recompute_modules @@ -1001,9 +1014,14 @@ def forward( with off_interface( 
self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" ) as permuted_local_hidden_states: - fc1_output, bias_parallel = apply_module(self.linear_fc1)( - permuted_local_hidden_states, tokens_per_expert - ) + if self.packed_offload_expert_fc1: + offload_context = get_packed_moe_expert_offloading_context(name="expert_fc1", max_num_tokens=permuted_local_hidden_states.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + else: + offload_context = nullcontext() + with offload_context: + fc1_output, bias_parallel = apply_module(self.linear_fc1)( + permuted_local_hidden_states, tokens_per_expert + ) if self.offload_expert_fc1: fc1_output = off_interface.group_commit( fc1_output, @@ -1102,9 +1120,25 @@ def glu(x): ) else: with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: - bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) + if self.packed_offload_moe_act: + offload_context = get_packed_moe_expert_offloading_context(name="moe_act", max_num_tokens=fc1_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + else: + offload_context = nullcontext() + with offload_context: + bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) + if self.offload_moe_act: + (bias_act_output,) = fine_grained_offloading_group_commit( + bias_act_output, name="moe_act", forced_released_tensors=[fc1_output] + ) - output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert) + if self.packed_offload_expert_fc2: + offload_context = get_packed_moe_expert_offloading_context(name="expert_fc2", max_num_tokens=bias_act_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + else: + offload_context = nullcontext() + with offload_context: + output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert) + if self.config.packed_moe_expert_offloading: + output = packed_moe_expert_offloading_group_commit(output, name="expert_fc2") if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 62e7ff41b87..1c40bb7f7c2 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -78,6 +78,7 @@ def __init__( self.tp_size = utils.get_pg_size(self.tp_group) self.tp_rank = utils.get_pg_rank(self.tp_group) self.ep_size = utils.get_pg_size(self.ep_group) + self.ep_rank = utils.get_pg_rank(self.ep_group) # Attributes that need to be captured in cudagraph. These attributes are returned # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. @@ -1022,10 +1023,15 @@ def __init__( "https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep." 
) - def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): + def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor, budget_local: int = None): num_tokens = routing_map.shape[0] self.routing_map = routing_map.reshape(num_tokens, self.num_experts) self.token_probs = probs.reshape(num_tokens, self.num_experts) + #if torch.distributed.get_rank() == 0: + # print (f'setup_metadata budget_local {budget_local}') + if budget_local is not None: + self.num_dispatched_tokens = budget_local + self.num_permuted_tokens = budget_local # Compute the capacity for each expert at the drop_and_pad mode if self.drop_and_pad: num_out_tokens = num_tokens * self.config.moe_router_topk @@ -1071,11 +1077,12 @@ def dispatch( ) ) - if not self.drop_and_pad: - self.tokens_per_expert = tokens_per_expert + if self.num_permuted_tokens is None: + self.tokens_per_expert = tokens_per_expert.to(torch.int64) # self.num_permuted_tokens is necessary to allocate the output tensor for permute self.num_permuted_tokens = self.tokens_per_expert.sum() - + if self.config.moe_expert_capacity_factor_for_packed_offloading is not None: + self.tokens_per_expert = tokens_per_expert.to(torch.int64) return dispatched_hidden def combine( @@ -1399,6 +1406,8 @@ def __init__( "Please set --moe-flex-dispatcher-backend=deepep or " "--moe-flex-dispatcher-backend=hybridep" ) + self.packed_offloading_capacity_factor = self.config.moe_expert_capacity_factor_for_packed_offloading + self.budget_local_gpu = None def set_shared_experts(self, shared_experts): raise NotImplementedError( @@ -1430,6 +1439,14 @@ def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) - .expand(-1, -1, self.tp_size, -1) .reshape(num_local_tokens, world_size, self.num_local_experts) ).contiguous() + + if self.packed_offloading_capacity_factor is not None: + budget_local = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) + #if self.ep_rank == 0: + # print (f'budget_local {budget_local} = {routing_map.shape[0]} x {self.config.moe_router_topk} x {self.packed_offloading_capacity_factor}') + self.budget_local_gpu = torch.full((1,), budget_local, device='cuda') + else: + self.budget_local_gpu = None return routing_map, probs @jit_fuser @@ -1456,7 +1473,12 @@ def dispatch_preprocess( # Initialize metadata routing_map, probs = self._initialize_metadata(routing_map, probs) - self._comm_manager.setup_metadata(routing_map, probs) + budget_local = None + if self.packed_offloading_capacity_factor is not None: + budget_local = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) +# if self.ep_rank == 0: +# print (f'budget_local {budget_local} = {routing_map.shape[0]} x {self.config.moe_router_topk} x {self.packed_offloading_capacity_factor}') + self._comm_manager.setup_metadata(routing_map, probs, budget_local) return hidden_states, self._comm_manager.token_probs def token_dispatch( diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index e9bd52f34b4..9fb40446126 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -17,6 +17,9 @@ from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.moe_packed_offload import ( + 
packed_moe_expert_offloading_set_last_layer, +) from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import CheckpointManager @@ -892,6 +895,10 @@ def forward( mhc_manager.is_last_layer_in_recompute_block = ( mhc_is_last_in_recompute_block[l_no] ) + if self.config.packed_moe_expert_offloading: + packed_moe_expert_offloading_set_last_layer( + is_last_layer = (l_no == self.num_layers_per_pipeline_rank - 1) + ) with self.offload_context, inner_quantization_context: hidden_states, context = layer( diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 8e514806018..6195c905a9a 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -993,6 +993,10 @@ class TransformerConfig(ModelParallelConfig): min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" + packed_moe_expert_offloading: bool = False + """If True, enable packed moe expert offloading.""" + + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more @@ -1447,6 +1451,17 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) + if self.packed_moe_expert_offloading: + assert ( + not self.cpu_offloading and not self.fine_grained_activation_offloading + ), "packed_moe_expert_offloading cannot be enabled with cpu_offloading." + assert self.offload_modules is not None and len(self.offload_modules) > 0 + allowed_modules = {"expert_fc1", "expert_fc2", "moe_act"} + invalid_modules = set(self.offload_modules) - allowed_modules + assert not invalid_modules, ( + f'Invalid choices for offload_modules: {invalid_modules}. ' + f'Allowed modules are: {allowed_modules}' + ) if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index c1bb0f8ac0d..8ab4c7e7785 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1409,6 +1409,12 @@ def validate_args(args, defaults={}): if is_te_min_version("2.10.0"): assert os.getenv("NVTE_CPU_OFFLOAD_V1", "0") == "1", \ "For fine-grained activation offloading with TE >= 2.10.0, NVTE_CPU_OFFLOAD_V1 should be set to 1 to avoid offloading weights." + assert not args.packed_moe_expert_offloading, "Fine-grained activation offloading and packed moe expert offloading cannot be enabled at the same time" + + if args.packed_moe_expert_offloading: + assert args.transformer_impl == 'transformer_engine', \ + "Packed moe expert offloading is only supported with transformer_engine implementation" + assert not args.fine_grained_activation_offloading, "Packed moe expert offloading and fine-grained activation offloading cannot be enabled at the same time" if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." @@ -1645,6 +1651,8 @@ def _add_inference_args(parser): group.add_argument('--use-legacy-static-engine', action='store_true', default=False, help='Use legacy static engine. 
(Current static engine uses dynamic engine under the hood)', dest='use_legacy_static_engine') + group.add_argument('--moe-expert-capacity-factor-for-packed-offloading', type=float, default=None, + help='The capacity factor for each EP rank when packed offloading is enabled.') group.add_argument('--inference-max-requests', type=int, default=8, help='Maximum number of requests for inference.', dest='inference_max_requests') From bbfcef2972b9a131de9394e51c4286ff0a6dba56 Mon Sep 17 00:00:00 2001 From: a Date: Mon, 17 Nov 2025 12:35:43 -0800 Subject: [PATCH 03/57] Bug fix --- .../pipeline_parallel/moe_packed_offload.py | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 67fc4d521dc..723254b1d86 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -25,7 +25,7 @@ def debug_print(message): if torch.distributed.get_rank() in DEBUG_RANK: print(f'{torch.distributed.get_rank()}: {message}') -GLOBAL_BLOCK_SIZE = 2048 +GLOBAL_BLOCK_SIZE = 1024 @triton.jit def _stash_copy_kernel( src_ptr, @@ -111,6 +111,7 @@ def _stash_copy_kernel( # Check if over capacity and increment counter size_page_aligned = tl.cdiv(size, BLOCK_SIZE) * BLOCK_SIZE + free_offset = free_offset + size_page_aligned if free_offset > capacity: free_offset -= capacity @@ -221,7 +222,7 @@ class PackedTensor: A class to represent a packed tensor. """ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None): - self._tensor = tensor.clone() + self._tensor = tensor self._original_tensor = None assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1, f"num_tokens_tensor {num_tokens_tensor} is not a scalar tensor" self.num_tokens_tensor = num_tokens_tensor.clone() @@ -237,7 +238,7 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None): self.stash_buffer_offset = None - def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=128): + def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): """Offload the packed tensor.""" #self._tensor.record_stream(torch.cuda.current_stream()) # TODO: Call offload function to offload the tensor @@ -295,20 +296,22 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=128): stash_buffer.buffer, num_elements_tensor, stash_buffer.alloc_offset, # Read-only: Write boundary - stash_buffer.free_offset, # Read+Write: Start offset for offload + stash_buffer.free_offset, # Read+Write: Start offset for next offload stash_buffer.capacity, # Read-only: Capacity of the buffer stash_buffer.overflow, # Read+Write: Over capacity flag updated by kernel BLOCK_SIZE=BLOCK_SIZE, num_iterations=num_iterations, ) - if DEBUG: - self._original_tensor = self._tensor.clone() + + # save reference to original tensor to avoid deallocation before offload is complete + self._original_tensor = self._tensor + # set tensor to None. This will be replaced by reload_from_stash. 
self._tensor = None if DEBUG: - debug_print (f"After offload_to_stash {stash_buffer.free_offset.item()} overflow {stash_buffer.overflow.item()} capacity {stash_buffer.capacity.item()}") + debug_print (f"After offload_to_stash offset {self.stash_buffer_offset.item()} free_offset {stash_buffer.free_offset.item()} overflow {stash_buffer.overflow.item()} capacity {stash_buffer.capacity.item()}") - def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=128): + def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): """Reload the packed tensor from the stash.""" if not HAVE_TRITON: raise RuntimeError("Triton is required for PackedTensor.reload_from_stash(). Please install triton.") @@ -350,7 +353,7 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=128): ) #torch.cuda.synchronize() if DEBUG: - debug_print (f"After reload_from_stash alloc_offset {stash_buffer.alloc_offset.item()} free_offset {stash_buffer.free_offset.item()} capacity {stash_buffer.capacity.item()}") + debug_print (f"After reload_from_stash reload_offset {self.stash_buffer_offset.item()} alloc_offset {stash_buffer.alloc_offset.item()} free_offset {stash_buffer.free_offset.item()} capacity {stash_buffer.capacity.item()}") def __repr__(self): return f"PackedTensor(original_shape={self.original_shape}, num_tokens={self.num_tokens_tensor.item()}, vp_stage={self.vp_stage})" @@ -581,8 +584,8 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if self.max_num_tokens is None or tensor.size(0) != self.max_num_tokens: return tensor.detach() - if tensor.size(1) == 7168 and DEBUG: - return tensor.detach() + #if tensor.size(1) in [7168, 4096, 1] and DEBUG: + # return tensor.detach() if self.status == 'capture': @@ -598,10 +601,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: packed_tensor = PackedTensor(tensor.detach(), num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage) if self.status == 'captured': - if DEBUG: - debug_print(f"on_save_for_backward {packed_tensor._tensor.shape} {packed_tensor.num_tokens_tensor.item()}") self.add_packed_tensor_to_offload(packed_tensor) - #debug_print(f"------{self._pp_schedule[self.current_schedule_index] if self.status == 'captured' else ""} {self._current_layer_name} on_save_for_backward {hex(tensor.data_ptr())}-{tensor.shape}-{tensor.dtype}") return packed_tensor def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: @@ -609,9 +609,7 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: Hook called when autograd retrieves a saved tensor during backward pass. Returns the actual tensor (potentially reloading from CPU). 
""" - # debug_print(f"------{self._pp_schedule[self.current_schedule_index-1] if self.status == 'captured' else ""} {self._current_layer_name} on_get_saved_tensor {saved_state.shape}-{saved_state.dtype}") if isinstance(saved_state, PackedTensor): - #debug_print (f'on_get_saved_tensor {type(saved_state)}') if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() num_elements = saved_state.num_elements * num_tokens // self.max_num_tokens @@ -619,10 +617,10 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: self.temp_pages_per_vp_stage[saved_state.vp_stage][saved_state.dtype] -= num_pages - if saved_state._original_tensor is not None: - if self.status == 'captured': - debug_print(f"on_get_saved_tensor {saved_state._original_tensor.shape} {saved_state.num_tokens_tensor.item()}") - if saved_state._tensor is not None: + if saved_state._tensor is not None: + if self.status == 'captured' and DEBUG: + #debug_print(f"on_get_saved_tensor {saved_state._original_tensor.shape} {saved_state.num_tokens_tensor.item()}") + if saved_state._original_tensor is not None: original_flat = saved_state._original_tensor.flatten() tensor_flat = saved_state._tensor.flatten() num_elements = saved_state.num_tokens_tensor.item() * saved_state.hidden_size @@ -633,20 +631,18 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: idx_not_equal = (original_flat_sub != tensor_flat_sub).nonzero() debug_print(f"on_get_saved_tensor original: {saved_state._original_tensor.shape} tensor: {saved_state._tensor.shape} equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}") #debug_print(f"on_get_saved_tensor equal tensors {torch.equal(saved_state._original_tensor, saved_state._tensor)} original_tensor {original_flat[-100:]} tensor {tensor_flat[-100:]}") - return saved_state._original_tensor - else: - if self.status == 'captured' and DEBUG: - debug_print(f"on_get_saved_tensor {saved_state._tensor.shape} {saved_state.num_tokens_tensor.item()}") + debug_print(f"on_get_saved_tensor return _tensor") return saved_state._tensor + else: + debug_print(f"on_get_saved_tensor return _original_tensor") + return saved_state._original_tensor return saved_state def packed_moe_expert_offloading_group_start(tensor, name=None): """Mark the start of a layer group and prepare for offload/reload.""" rank = torch.distributed.get_rank() - #debug_print(f'{rank}: packed_moe_expert_offloading_group_start tensor {tensor.shape}-{tensor.dtype} name {name}') return tensor - #return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num_tokens_tensor=None): """Get the fine-grained offload context""" @@ -700,7 +696,7 @@ def packed_moe_expert_offloading_reset(): torch.cuda.memory._record_memory_history() print(f'packed_moe_expert_offloading_reset record_memory_history') if offload_manager.iteration == 10 and torch.distributed.get_rank() == 0: - torch.cuda.memory._dump_snapshot("packed_offloading.pkl") + torch.cuda.memory._dump_snapshot("packed_offloading_cg.pkl") torch.cuda.memory._record_memory_history(enabled=None) print(f'packed_moe_expert_offloading_reset dump_snapshot') From 169f9a5e5b876133544fda170680528bd9d65ecc Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Mon, 17 Nov 2025 15:23:49 -0800 Subject: [PATCH 04/57] Mem Opt --- 
megatron/core/pipeline_parallel/moe_packed_offload.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 723254b1d86..85ca112105a 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -376,6 +376,12 @@ def forward(ctx, tensor, offload_manager): # after forward current_stream.wait_stream(offload_manager.pack_stream) offload_manager._pack_stream_status = 'idle' + # Deallocate original tensor after offload is complete + while len(offload_manager.packed_tensors_offload_in_progress) > 0: + packed_tensor = offload_manager.packed_tensors_offload_in_progress.pop(0) + if not DEBUG: + packed_tensor._original_tensor = None + if offload_manager.status == 'captured': current_schedule_layer = (ctx.vp_stage+1)*100 + ctx.layer_no*10 + ctx.microbatch_no next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index+1] @@ -435,6 +441,7 @@ def __init__(self): self._pack_stream_status = 'idle' # idle, offloading self._unpack_stream_status = 'idle' # idle, reloading self.packed_tensors_to_offload = [] + self.packed_tensors_offload_in_progress = [] self.packed_tensors_to_reload = {} self.iteration = 0 @@ -507,6 +514,7 @@ def offload_packed_tensors(self, pp_schedule_layer): packed_tensor = self.packed_tensors_to_offload.pop(0) packed_tensor.offload_to_stash(self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype]) self.packed_tensors_to_reload[pp_schedule_layer].append(packed_tensor) + self.packed_tensors_offload_in_progress.append(packed_tensor) else: pass assert len(self.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {self.packed_tensors_to_offload}" @@ -721,4 +729,5 @@ def packed_moe_expert_offloading_reset(): offload_manager.current_layer = [1 for _ in range(offload_manager.vp_size)] offload_manager.current_microbatch = [1 for _ in range(offload_manager.vp_size)] assert len(offload_manager.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {offload_manager.packed_tensors_to_offload}" - \ No newline at end of file + assert len(offload_manager.packed_tensors_offload_in_progress) == 0, f"packed_tensors_offload_in_progress is not empty {offload_manager.packed_tensors_offload_in_progress}" + From ac9dd934c88f53ea54c3628ad795915693a222b1 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 20 Nov 2025 13:01:46 -0800 Subject: [PATCH 05/57] Handle MXFP8Tensor offload --- megatron/core/full_cuda_graph.py | 4 +- .../pipeline_parallel/moe_packed_offload.py | 68 +++++++++++++------ megatron/core/pipeline_parallel/schedules.py | 4 +- 3 files changed, 50 insertions(+), 26 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 0c00b1a45a8..eccac3388b7 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -188,8 +188,8 @@ def __call__(self, *args, **kwargs): if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: - if self.packed_moe_expert_offloading and training_str == 'training': - packed_moe_expert_offloading_reset() + if training_str == 'training': + packed_moe_expert_offloading_reset(enabled=self.packed_moe_expert_offloading) FullCudaGraphWrapper.cuda_graph[training_str].replay() self.next_iter(training_str) diff --git 
a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 85ca112105a..2e961e9a11e 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -6,6 +6,7 @@ from typing import Any import os import torch +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor try: import triton import triton.language as tl @@ -221,19 +222,19 @@ class PackedTensor: """ A class to represent a packed tensor. """ - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None): + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None): self._tensor = tensor self._original_tensor = None assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1, f"num_tokens_tensor {num_tokens_tensor} is not a scalar tensor" self.num_tokens_tensor = num_tokens_tensor.clone() self.vp_stage = vp_stage - + self.layer_name = layer_name # Original tensor information self.original_shape = list(tensor.shape) self.num_elements = tensor.numel() self.element_size = tensor.element_size() self.hidden_size = self.num_elements // self.original_shape[0] - self.dtype = tensor.dtype + self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype self.device = tensor.device self.stash_buffer_offset = None @@ -270,7 +271,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) num_elements_tensor = self.num_tokens_tensor.mul(self.hidden_size) # Flatten the tensor to get total number of elements - flat_tensor = self._tensor.flatten() + flat_tensor = self._tensor.flatten() if not isinstance(self._tensor, MXFP8Tensor) else self._tensor._columnwise_data.flatten() # Determine grid size with cap on max blocks BLOCK_SIZE = GLOBAL_BLOCK_SIZE @@ -281,7 +282,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): num_blocks = min(total_blocks_needed, max_blocks) num_iterations = triton.cdiv(total_blocks_needed, num_blocks) if DEBUG: - debug_print (f"offload_to_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} num_elements {num_elements_tensor.item()} max_blocks {max_blocks} total_blocks_needed {total_blocks_needed} num_blocks {num_blocks} num_iterations {num_iterations} oveflow {stash_buffer.overflow.item()}") + debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} num_elements {num_elements_tensor.item()} max_blocks {max_blocks} total_blocks_needed {total_blocks_needed} num_blocks {num_blocks} num_iterations {num_iterations} oveflow {stash_buffer.overflow.item()}") # grid = (num_blocks,) self.stash_buffer_offset = stash_buffer.free_offset.clone() @@ -315,8 +316,22 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): """Reload the packed tensor from the stash.""" if not HAVE_TRITON: raise RuntimeError("Triton is required for PackedTensor.reload_from_stash(). 
Please install triton.") - self._tensor = torch.zeros(self.original_shape, dtype=self.dtype, device=self.device) - flat_tensor = self._tensor.flatten() + if isinstance(self._original_tensor, MXFP8Tensor): + columnwise_data = torch.zeros(self.original_shape, dtype=self.dtype, device=self.device) + self._tensor = MXFP8Tensor( + shape=self._original_tensor.shape, + dtype=self._original_tensor.dtype, + fp8_dtype=self._original_tensor._fp8_dtype, + rowwise_data=self._original_tensor._rowwise_data, + rowwise_scale_inv=self._original_tensor._rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=self._original_tensor._columnwise_scale_inv, + quantizer=self._original_tensor._quantizer, + ) + flat_tensor = self._tensor._columnwise_data.flatten() + else: + self._tensor = torch.zeros(self.original_shape, dtype=self.dtype, device=self.device) + flat_tensor = self._tensor.flatten() num_elements_tensor = self.num_tokens_tensor.mul(self.hidden_size) @@ -380,7 +395,10 @@ def forward(ctx, tensor, offload_manager): # after forward while len(offload_manager.packed_tensors_offload_in_progress) > 0: packed_tensor = offload_manager.packed_tensors_offload_in_progress.pop(0) if not DEBUG: - packed_tensor._original_tensor = None + if isinstance(packed_tensor._original_tensor, MXFP8Tensor): + packed_tensor._original_tensor._columnwise_data = None + else: + packed_tensor._original_tensor = None if offload_manager.status == 'captured': current_schedule_layer = (ctx.vp_stage+1)*100 + ctx.layer_no*10 + ctx.microbatch_no @@ -550,7 +568,8 @@ def allocate_offload_pages(self, stash_buffer_size_factor=1.10): for dtype in self.max_pages_per_vp_stage[vp_stage]: self.max_pages_per_vp_stage[vp_stage][dtype] = int(self.max_pages_per_vp_stage[vp_stage][dtype] * stash_buffer_size_factor) self.stash_buffers[vp_stage][dtype] = StashBuffer(self.max_pages_per_vp_stage[vp_stage][dtype]*GLOBAL_BLOCK_SIZE, self.device, self.overflow, dtype) - debug_print(f'allocated stash buffer {vp_stage} {dtype} {self.stash_buffers[vp_stage][dtype]}') + if torch.distributed.get_rank() == 0: + print(f'allocated stash buffer {vp_stage} {dtype} {self.stash_buffers[vp_stage][dtype]}') def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): """Update the pp schedule.""" @@ -592,22 +611,26 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if self.max_num_tokens is None or tensor.size(0) != self.max_num_tokens: return tensor.detach() + if isinstance(tensor, MXFP8Tensor): + debug_print(f'on_save_for_backward MXFP8Tensor ({self._current_layer_name}) {tensor.shape} {tensor.dtype} rowwise {tensor._rowwise_data is not None} columnwise {(tensor._columnwise_data.shape, tensor._columnwise_data.dtype) if tensor._columnwise_data is not None else None}-scale_inv {tensor._columnwise_scale_inv.shape} {tensor._columnwise_scale_inv.dtype}') + assert tensor._rowwise_data is None, f"rowwise_data is not None; Only columnwise data is supported for packed offloading" + #if tensor.size(1) in [7168, 4096, 1] and DEBUG: # return tensor.detach() - if self.status == 'capture': self.num_tokens = self.num_tokens_tensor.item() num_elements = tensor.numel() * self.num_tokens // self.max_num_tokens num_pages = (num_elements + self.page_size - 1) // self.page_size - if tensor.dtype not in self.temp_pages_per_vp_stage[self.current_vp_stage]: - self.temp_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] = 0 - self.max_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] = 0 - 
self.temp_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] += num_pages - self.max_pages_per_vp_stage[self.current_vp_stage][tensor.dtype] = max(self.max_pages_per_vp_stage[self.current_vp_stage][tensor.dtype], self.temp_pages_per_vp_stage[self.current_vp_stage][tensor.dtype]) + dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype + if dtype not in self.temp_pages_per_vp_stage[self.current_vp_stage]: + self.temp_pages_per_vp_stage[self.current_vp_stage][dtype] = 0 + self.max_pages_per_vp_stage[self.current_vp_stage][dtype] = 0 + self.temp_pages_per_vp_stage[self.current_vp_stage][dtype] += num_pages + self.max_pages_per_vp_stage[self.current_vp_stage][dtype] = max(self.max_pages_per_vp_stage[self.current_vp_stage][dtype], self.temp_pages_per_vp_stage[self.current_vp_stage][dtype]) - packed_tensor = PackedTensor(tensor.detach(), num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage) + packed_tensor = PackedTensor(tensor, num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage, layer_name=self._current_layer_name) if self.status == 'captured': self.add_packed_tensor_to_offload(packed_tensor) return packed_tensor @@ -628,9 +651,10 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: if saved_state._tensor is not None: if self.status == 'captured' and DEBUG: #debug_print(f"on_get_saved_tensor {saved_state._original_tensor.shape} {saved_state.num_tokens_tensor.item()}") - if saved_state._original_tensor is not None: - original_flat = saved_state._original_tensor.flatten() - tensor_flat = saved_state._tensor.flatten() + original_tensor = saved_state._original_tensor if not isinstance(saved_state._original_tensor, MXFP8Tensor) else saved_state._original_tensor._columnwise_data + if original_tensor is not None: + original_flat = original_tensor.flatten() if not isinstance(original_tensor, MXFP8Tensor) else original_tensor._columnwise_data.flatten() + tensor_flat = saved_state._tensor.flatten() if not isinstance(saved_state._tensor, MXFP8Tensor) else saved_state._tensor._columnwise_data.flatten() num_elements = saved_state.num_tokens_tensor.item() * saved_state.hidden_size original_flat_sub = original_flat[:num_elements] tensor_flat_sub = tensor_flat[:num_elements] @@ -639,10 +663,8 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: idx_not_equal = (original_flat_sub != tensor_flat_sub).nonzero() debug_print(f"on_get_saved_tensor original: {saved_state._original_tensor.shape} tensor: {saved_state._tensor.shape} equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}") #debug_print(f"on_get_saved_tensor equal tensors {torch.equal(saved_state._original_tensor, saved_state._tensor)} original_tensor {original_flat[-100:]} tensor {tensor_flat[-100:]}") - debug_print(f"on_get_saved_tensor return _tensor") return saved_state._tensor else: - debug_print(f"on_get_saved_tensor return _original_tensor") return saved_state._original_tensor return saved_state @@ -693,7 +715,7 @@ def packed_moe_expert_offloading_set_last_layer(is_last_layer=False): offload_manager = PackedOffloadManager.get_instance() offload_manager._last_layer = is_last_layer -def packed_moe_expert_offloading_reset(): +def packed_moe_expert_offloading_reset(enabled=True): """Reset the chunk handler, called at the start of a training iteration.""" offload_manager = PackedOffloadManager.get_instance() 
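# Aside, not part of the patch: on_save_for_backward / on_get_saved_tensor are the pack and
# unpack halves of PyTorch's saved-tensor hooks. A minimal sketch of how a manager like this
# is typically wired around a forward pass; run_expert_forward, manager and expert_mlp are
# illustrative names, not APIs defined by this series.
import torch

def run_expert_forward(manager, expert_mlp, hidden_states):
    # every activation autograd saves inside this context is routed through the hooks:
    # pack may hand back a PackedTensor handle, unpack turns it back into a usable tensor
    with torch.autograd.graph.saved_tensors_hooks(
        manager.on_save_for_backward, manager.on_get_saved_tensor
    ):
        return expert_mlp(hidden_states)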
offload_manager.iteration += 1 @@ -708,6 +730,8 @@ def packed_moe_expert_offloading_reset(): torch.cuda.memory._record_memory_history(enabled=None) print(f'packed_moe_expert_offloading_reset dump_snapshot') + if not enabled: + return if offload_manager.status == 'begin': offload_manager.status = 'capture' elif offload_manager.status == 'capture': diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index f0632cb6504..b099be93280 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -593,8 +593,8 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - if not forward_only and config.packed_moe_expert_offloading: - packed_moe_expert_offloading_reset() + if not forward_only: + packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading) no_sync_func = config.no_sync_func if no_sync_func is None: From d6dbc99931a12fb399d2e16367d369d4928c23cd Mon Sep 17 00:00:00 2001 From: a Date: Thu, 20 Nov 2025 14:33:18 -0800 Subject: [PATCH 06/57] Enable Packed offloading to CPU pinned memory with PACKED_OFFLOAD_CPU=1 --- .../pipeline_parallel/moe_packed_offload.py | 45 +++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 2e961e9a11e..6096866cb5c 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -26,6 +26,38 @@ def debug_print(message): if torch.distributed.get_rank() in DEBUG_RANK: print(f'{torch.distributed.get_rank()}: {message}') +def set_ideal_affinity_for_current_gpu(): + """Set CPU affinity for the current GPU to optimize host-device transfers.""" + import uuid + + try: + import cuda.bindings.driver as cuda_driver + import cuda.bindings.runtime as cuda_runtime + except ImportError: + try: + import cuda.cuda as cuda_driver + import cuda.cudart as cuda_runtime + except ImportError: + # print("cuda-python may not be installed, skipping GPU affinity setting") + warnings.warn("cuda-python may not be installed, skipping GPU affinity setting") + return + try: + import pynvml + except ImportError: + warnings.warn("pynvml is not installed, skipping GPU affinity setting") + return + + # Get current CUDA device ID + err, device_id = cuda_runtime.cudaGetDevice() + assert err == cuda_runtime.cudaError_t.cudaSuccess + # Get device UUID + err, device_uuid = cuda_driver.cuDeviceGetUuid(device_id) + assert err == cuda_driver.CUresult.CUDA_SUCCESS + # Set CPU affinity based on GPU's NUMA node + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes))) + pynvml.nvmlDeviceSetCpuAffinity(handle) + GLOBAL_BLOCK_SIZE = 1024 @triton.jit def _stash_copy_kernel( @@ -197,9 +229,14 @@ class StashBuffer: """ A class to represent a stash buffer. 
""" - + def __init__(self, size, device, overflow, dtype): - self.buffer = torch.empty(size, dtype=dtype, device=device) + + self.buffer = None + if os.getenv('PACKED_OFFLOAD_CPU', '0') == '1': + self.buffer = torch.empty(size, dtype=dtype, device='cpu', pin_memory=True) + else: + self.buffer = torch.empty(size, dtype=dtype, device=device) self.overflow = overflow # GPU flag self.device = device self.free_offset = torch.zeros(1, dtype=torch.int64, device=device) # start offset of free space @@ -726,12 +763,14 @@ def packed_moe_expert_offloading_reset(enabled=True): torch.cuda.memory._record_memory_history() print(f'packed_moe_expert_offloading_reset record_memory_history') if offload_manager.iteration == 10 and torch.distributed.get_rank() == 0: - torch.cuda.memory._dump_snapshot("packed_offloading_cg.pkl") + torch.cuda.memory._dump_snapshot("packed_cpu_offloading_cg.pkl") torch.cuda.memory._record_memory_history(enabled=None) print(f'packed_moe_expert_offloading_reset dump_snapshot') if not enabled: return + + set_ideal_affinity_for_current_gpu() # Set the ideal affinity for the current GPU if offload_manager.status == 'begin': offload_manager.status = 'capture' elif offload_manager.status == 'capture': From 35d3c0631cfed762556e4b3bfdb54f9d2ef8285c Mon Sep 17 00:00:00 2001 From: a Date: Fri, 21 Nov 2025 12:37:19 -0800 Subject: [PATCH 07/57] Enable activation truncation for first step --- .../pipeline_parallel/moe_packed_offload.py | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 6096866cb5c..3c4384ddbb8 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -259,13 +259,14 @@ class PackedTensor: """ A class to represent a packed tensor. 
""" - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None): + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None, max_tokens=None): self._tensor = tensor self._original_tensor = None assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1, f"num_tokens_tensor {num_tokens_tensor} is not a scalar tensor" self.num_tokens_tensor = num_tokens_tensor.clone() self.vp_stage = vp_stage self.layer_name = layer_name + self.max_tokens = max_tokens # Original tensor information self.original_shape = list(tensor.shape) self.num_elements = tensor.numel() @@ -318,6 +319,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # Cap the number of blocks and calculate iterations per block num_blocks = min(total_blocks_needed, max_blocks) num_iterations = triton.cdiv(total_blocks_needed, num_blocks) + if DEBUG: debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} num_elements {num_elements_tensor.item()} max_blocks {max_blocks} total_blocks_needed {total_blocks_needed} num_blocks {num_blocks} num_iterations {num_iterations} oveflow {stash_buffer.overflow.item()}") # @@ -339,6 +341,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): stash_buffer.overflow, # Read+Write: Over capacity flag updated by kernel BLOCK_SIZE=BLOCK_SIZE, num_iterations=num_iterations, +# max_tokens=self.max_tokens, ) # save reference to original tensor to avoid deallocation before offload is complete @@ -346,7 +349,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # set tensor to None. This will be replaced by reload_from_stash. 
self._tensor = None if DEBUG: - debug_print (f"After offload_to_stash offset {self.stash_buffer_offset.item()} free_offset {stash_buffer.free_offset.item()} overflow {stash_buffer.overflow.item()} capacity {stash_buffer.capacity.item()}") + debug_print (f"After offload_to_stash offset {self.stash_buffer_offset.item()} free_offset {stash_buffer.free_offset.item()} overflow {stash_buffer.overflow.item()} capacity {stash_buffer.capacity.item()} max_tokens {self.max_tokens}") def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): @@ -649,7 +652,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if self.max_num_tokens is None or tensor.size(0) != self.max_num_tokens: return tensor.detach() if isinstance(tensor, MXFP8Tensor): - debug_print(f'on_save_for_backward MXFP8Tensor ({self._current_layer_name}) {tensor.shape} {tensor.dtype} rowwise {tensor._rowwise_data is not None} columnwise {(tensor._columnwise_data.shape, tensor._columnwise_data.dtype) if tensor._columnwise_data is not None else None}-scale_inv {tensor._columnwise_scale_inv.shape} {tensor._columnwise_scale_inv.dtype}') + debug_print(f'on_save_for_backward MXFP8Tensor ({self._current_layer_name}) ndim {tensor.ndim} shape {tensor.shape} {tensor.dtype} rowwise {tensor._rowwise_data is not None} columnwise {(tensor._columnwise_data.shape, tensor._columnwise_data.dtype) if tensor._columnwise_data is not None else None}-scale_inv {tensor._columnwise_scale_inv.shape} {tensor._columnwise_scale_inv.dtype}') assert tensor._rowwise_data is None, f"rowwise_data is not None; Only columnwise data is supported for packed offloading" #if tensor.size(1) in [7168, 4096, 1] and DEBUG: @@ -666,8 +669,21 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: self.max_pages_per_vp_stage[self.current_vp_stage][dtype] = 0 self.temp_pages_per_vp_stage[self.current_vp_stage][dtype] += num_pages self.max_pages_per_vp_stage[self.current_vp_stage][dtype] = max(self.max_pages_per_vp_stage[self.current_vp_stage][dtype], self.temp_pages_per_vp_stage[self.current_vp_stage][dtype]) - - packed_tensor = PackedTensor(tensor, num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage, layer_name=self._current_layer_name) + + # Since capture stage does not use CUDA graph, we can truncate the saved tensor to actual num_tokens + # Truncate the tensor to the actual number of tokens + new_size = (self.num_tokens, *tensor.shape[1:]) + + if isinstance(tensor, MXFP8Tensor): + tensor_truncated = torch.empty(new_size, dtype=tensor._columnwise_data.dtype, device=tensor.device) + tensor_truncated.copy_(tensor._columnwise_data[:self.num_tokens, ...]) + tensor._columnwise_data = tensor_truncated + else: + tensor_truncated = torch.empty(new_size, dtype=tensor.dtype, device=tensor.device) + tensor_truncated.copy_(tensor[:self.num_tokens, ...]) + tensor = tensor_truncated + + packed_tensor = PackedTensor(tensor, num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage, layer_name=self._current_layer_name, max_tokens=self.max_num_tokens) if self.status == 'captured': self.add_packed_tensor_to_offload(packed_tensor) return packed_tensor @@ -684,7 +700,19 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: num_pages = (num_elements + self.page_size - 1) // self.page_size self.temp_pages_per_vp_stage[saved_state.vp_stage][saved_state.dtype] -= num_pages - + # Pad the tensor to the max number of tokens + npad = self.max_num_tokens - num_tokens + pad = () + for _ in range(saved_state._tensor.ndim-1): + pad 
= pad + (0, 0) + pad = pad + (0, npad) + if isinstance(saved_state._tensor, MXFP8Tensor): + saved_state._tensor._columnwise_data = torch.nn.functional.pad(saved_state._tensor._columnwise_data, pad) + else: + saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad) + + if not DEBUG: + assert saved_state._tensor is not None, f"saved_state._tensor is None {saved_state._tensor}" if saved_state._tensor is not None: if self.status == 'captured' and DEBUG: #debug_print(f"on_get_saved_tensor {saved_state._original_tensor.shape} {saved_state.num_tokens_tensor.item()}") @@ -763,7 +791,7 @@ def packed_moe_expert_offloading_reset(enabled=True): torch.cuda.memory._record_memory_history() print(f'packed_moe_expert_offloading_reset record_memory_history') if offload_manager.iteration == 10 and torch.distributed.get_rank() == 0: - torch.cuda.memory._dump_snapshot("packed_cpu_offloading_cg.pkl") + torch.cuda.memory._dump_snapshot("packed_offloading_cg.pkl") torch.cuda.memory._record_memory_history(enabled=None) print(f'packed_moe_expert_offloading_reset dump_snapshot') From 8e5857ccebf8a7fc2c1f874196e998d53213dad6 Mon Sep 17 00:00:00 2001 From: a Date: Fri, 21 Nov 2025 18:50:43 -0800 Subject: [PATCH 08/57] Overflow check and assert --- .../core/pipeline_parallel/moe_packed_offload.py | 16 ++++++++++------ megatron/core/transformer/moe/experts.py | 2 -- .../core/transformer/moe/token_dispatcher.py | 2 -- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 3c4384ddbb8..cbf36b7db05 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -70,6 +70,7 @@ def _stash_copy_kernel( overflow_ptr, BLOCK_SIZE: tl.constexpr, num_iterations: tl.constexpr, + max_tokens: tl.constexpr, ): """Triton kernel to copy tensor data to stash buffer. 
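# Aside, not part of the patch: the truncate-on-save / pad-on-restore idea used above, shown
# on a plain tensor. max_tokens is the static (CUDA-graph) token dimension and num_tokens the
# number of valid rows; both names are illustrative.
import torch
import torch.nn.functional as F

def truncate_to_valid(t: torch.Tensor, num_tokens: int) -> torch.Tensor:
    # keep only the rows that carry real tokens, so less data is stashed
    return t[:num_tokens].clone()

def pad_back_to_static(t: torch.Tensor, max_tokens: int) -> torch.Tensor:
    # pad the token dimension back to the static shape the graphed backward expects;
    # F.pad consumes pairs from the last dimension backwards, so (0, npad) comes last
    npad = max_tokens - t.size(0)
    pad = (0, 0) * (t.dim() - 1) + (0, npad)
    return F.pad(t, pad)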
@@ -97,7 +98,6 @@ def _stash_copy_kernel( alloc_offset = tl.load(alloc_offset_ptr) free_offset = tl.load(free_offset_ptr) capacity = tl.load(capacity_ptr) - # Only the first thread checks capacity # Do this BEFORE the loop so it always happens overflow = False @@ -107,7 +107,7 @@ def _stash_copy_kernel( avail_space = -avail_space else: avail_space = capacity - avail_space - if avail_space < size: + if avail_space < size or max_tokens < size: overflow = True if pid == 0 and overflow: tl.store(overflow_ptr, 1) @@ -341,7 +341,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): stash_buffer.overflow, # Read+Write: Over capacity flag updated by kernel BLOCK_SIZE=BLOCK_SIZE, num_iterations=num_iterations, -# max_tokens=self.max_tokens, + max_tokens=self.max_tokens*self.hidden_size, ) # save reference to original tensor to avoid deallocation before offload is complete @@ -803,7 +803,8 @@ def packed_moe_expert_offloading_reset(enabled=True): offload_manager.status = 'capture' elif offload_manager.status == 'capture': offload_manager.status = 'captured' - offload_manager.allocate_offload_pages(stash_buffer_size_factor=1.10) # 10% extra to account for overhead + stash_buffer_size_factor = float(os.getenv('STASH_BUFFER_SIZE_FACTOR', '1.10')) + offload_manager.allocate_offload_pages(stash_buffer_size_factor=stash_buffer_size_factor) debug_print(f'packed_moe_expert_offloading_reset captured schedule: {offload_manager._pp_schedule}') debug_print(f'packed_moe_expert_offloading_reset max_pages_per_vp_stage: {offload_manager.max_pages_per_vp_stage}') elif offload_manager.status == 'captured': @@ -812,11 +813,14 @@ def packed_moe_expert_offloading_reset(enabled=True): debug_print(f'packed_moe_expert_offloading_reset unknown status: {offload_manager.status}') if offload_manager.status == 'captured': - offload_manager.overflow.zero_() + if not torch.cuda.is_current_stream_capturing(): + overflow = offload_manager.overflow.item() + assert overflow == 0, f"PackedOffloadManager overflow!!!" 
+ for vp_buffers in offload_manager.stash_buffers: for dtype in vp_buffers.keys(): vp_buffers[dtype].reset() - + offload_manager.overflow.zero_() offload_manager.current_layer = [1 for _ in range(offload_manager.vp_size)] offload_manager.current_microbatch = [1 for _ in range(offload_manager.vp_size)] assert len(offload_manager.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {offload_manager.packed_tensors_to_offload}" diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index fe40bb83f07..ab6f42f98a6 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -709,8 +709,6 @@ def __init__( self.packed_offload_moe_act = self.config.packed_moe_expert_offloading and "moe_act" in self.config.offload_modules self.packed_offload_expert_fc2 = self.config.packed_moe_expert_offloading and "expert_fc2" in self.config.offload_modules - if torch.distributed.get_rank() == 0: - print(f'packed_offload_expert_fc1 {self.packed_offload_expert_fc1}, packed_offload_moe_act {self.packed_offload_moe_act}, packed_offload_expert_fc2 {self.packed_offload_expert_fc2}') self.activation_recompute = ( self.config.recompute_granularity == 'selective' and "moe_act" in self.config.recompute_modules diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 1c40bb7f7c2..f480de5b632 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1027,8 +1027,6 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor, budget_ num_tokens = routing_map.shape[0] self.routing_map = routing_map.reshape(num_tokens, self.num_experts) self.token_probs = probs.reshape(num_tokens, self.num_experts) - #if torch.distributed.get_rank() == 0: - # print (f'setup_metadata budget_local {budget_local}') if budget_local is not None: self.num_dispatched_tokens = budget_local self.num_permuted_tokens = budget_local From b2e77eb8518fd2c11d66d030c865a8ac2a92b04e Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Sat, 22 Nov 2025 15:39:03 +0800 Subject: [PATCH 09/57] Check in temporary solution for detecting overflow in receiving buffer Optimize packed stashing by not setting popped tensor with 0's --- megatron/core/full_cuda_graph.py | 8 ++ .../pipeline_parallel/moe_packed_offload.py | 8 +- megatron/core/transformer/moe/moe_utils.py | 96 +++++++++++++++++++ .../core/transformer/moe/token_dispatcher.py | 28 +++++- 4 files changed, 134 insertions(+), 6 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index eccac3388b7..2d7686d860a 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -192,6 +192,14 @@ def __call__(self, *args, **kwargs): packed_moe_expert_offloading_reset(enabled=self.packed_moe_expert_offloading) FullCudaGraphWrapper.cuda_graph[training_str].replay() + # Check if there is any overflow in the receiving buffer + # TODO: Hacky for now. 
Will improve this after moving the budget check logic into HybridEP + for model_chunk in model: + for layer in model_chunk.module.module.decoder.layers: + mlp = layer.mlp + if hasattr(mlp, 'token_dispatcher'): + if not mlp.token_dispatcher.under_budget.item(): + raise Exception(f"Rank {torch.distributed.get_rank()} overbudget") self.next_iter(training_str) return FullCudaGraphWrapper.result[training_str] diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index cbf36b7db05..33849fc8321 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -271,7 +271,7 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non self.original_shape = list(tensor.shape) self.num_elements = tensor.numel() self.element_size = tensor.element_size() - self.hidden_size = self.num_elements // self.original_shape[0] + self.hidden_size = self.original_shape[1] self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype self.device = tensor.device @@ -329,7 +329,6 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # Launch Triton kernel to copy data # self.offload_stream.wait_stream(torch.cuda.current_stream()) # with torch.cuda.stream(self.offload_stream): - # TODO: make this async. Something unexpected with TE on deallocate the tensor _stash_copy_kernel[grid]( flat_tensor, @@ -343,7 +342,6 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): num_iterations=num_iterations, max_tokens=self.max_tokens*self.hidden_size, ) - # save reference to original tensor to avoid deallocation before offload is complete self._original_tensor = self._tensor # set tensor to None. This will be replaced by reload_from_stash. @@ -357,7 +355,7 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): if not HAVE_TRITON: raise RuntimeError("Triton is required for PackedTensor.reload_from_stash(). 
Please install triton.") if isinstance(self._original_tensor, MXFP8Tensor): - columnwise_data = torch.zeros(self.original_shape, dtype=self.dtype, device=self.device) + columnwise_data = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) self._tensor = MXFP8Tensor( shape=self._original_tensor.shape, dtype=self._original_tensor.dtype, @@ -370,7 +368,7 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): ) flat_tensor = self._tensor._columnwise_data.flatten() else: - self._tensor = torch.zeros(self.original_shape, dtype=self.dtype, device=self.device) + self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) flat_tensor = self._tensor.flatten() num_elements_tensor = self.num_tokens_tensor.mul(self.hidden_size) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index dbcc25a905c..dc4181df43c 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -44,6 +44,10 @@ HAVE_TE = False +import triton +import triton.language as tl + + def switch_load_balancing_loss_func( probs: torch.Tensor, tokens_per_expert: torch.Tensor, @@ -1540,3 +1544,95 @@ def wrapped_func(moe_layer, *args, **kwargs): return wrapped_func return decorator + +@triton.jit +def _drop_routing_map_kernel( + routing_map_ptr, + under_budget_ptr, + routing_map_dropped_ptr, + num_elements: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel to drop routing map based on budget constraints. + + Args: + routing_map_ptr: Pointer to the input routing_map tensor + under_budget_ptr: Pointer to the boolean tensor indicating if all EP ranks are under budget + routing_map_dropped_ptr: Pointer to the output routing_map tensor + num_elements: Total number of elements to process + BLOCK_SIZE: Block size for Triton kernel + """ + # Get the program ID + pid = tl.program_id(axis=0) + + # Read the under_budget value (scalar tensor with single element) + under_budget_val = tl.load(under_budget_ptr) + + # Calculate the offset for this program + offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + # Load the routing_map values + mask = offset < num_elements + routing_map_val = tl.load(routing_map_ptr + offset, mask=mask, other=0.0) + + # Multiply routing_map by under_budget: if under_budget is 0 (False), output is 0; if 1 (True), output is routing_map_val + output_val = routing_map_val * under_budget_val + + # Store the result + tl.store(routing_map_dropped_ptr + offset, output_val, mask=mask) + + +def drop_routing_map_triton( + routing_map: torch.Tensor, + budget: torch.Tensor, + num_tokens_per_ep_rank: torch.Tensor +) -> torch.Tensor: + """Drop tokens from routing_map that exceed the budget per EP rank using Triton. + + Args: + routing_map: Tensor indicating which tokens are assigned to each expert. + budget: Integer tensor with the maximum number of tokens per EP rank. + num_tokens_per_ep_rank: Tensor with actual number of tokens per EP rank. + + Returns: + Modified routing_map with tokens exceeding budget zeroed out if any EP rank + exceeds budget, otherwise returns the original routing_map. 
+ """ + + # Calculate boolean tensor: under_budget is True only if ALL EP ranks are under budget + under_budget = (num_tokens_per_ep_rank <= budget).all() + + # Convert boolean to int8 + under_budget_int = under_budget.to(torch.int8) + + # Convert routing_map to numeric type if it's boolean + if routing_map.dtype == torch.bool: + routing_map_numeric = routing_map.to(torch.int8) + else: + routing_map_numeric = routing_map + + # Create output tensor with same dtype as input + routing_map_dropped = torch.empty_like(routing_map_numeric) + + # Flatten tensors for kernel processing + routing_map_flat = routing_map_numeric.flatten() + num_elements = routing_map_flat.numel() + + # Determine grid size + BLOCK_SIZE = 1024 + grid = (triton.cdiv(num_elements, BLOCK_SIZE),) + + # Launch kernel with under_budget tensor pointer (as int8) + _drop_routing_map_kernel[grid]( + routing_map_flat, + under_budget_int, + routing_map_dropped.flatten(), + num_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Convert back to boolean if original was boolean + if routing_map.dtype == torch.bool: + routing_map_dropped = routing_map_dropped.to(torch.bool) + + return routing_map_dropped, under_budget.to(torch.bool) diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index f480de5b632..e46faf8eb0c 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -33,6 +33,7 @@ permute, sort_chunks_by_idxs, unpermute, + drop_routing_map_triton, ) from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_config import TransformerConfig @@ -1406,6 +1407,7 @@ def __init__( ) self.packed_offloading_capacity_factor = self.config.moe_expert_capacity_factor_for_packed_offloading self.budget_local_gpu = None + self.under_budget = torch.ones(1, dtype=torch.bool, device='cuda') def set_shared_experts(self, shared_experts): raise NotImplementedError( @@ -1447,7 +1449,6 @@ def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) - self.budget_local_gpu = None return routing_map, probs - @jit_fuser def dispatch_preprocess( self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor ): @@ -1474,9 +1475,12 @@ def dispatch_preprocess( budget_local = None if self.packed_offloading_capacity_factor is not None: budget_local = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) + routing_map, under_budget = self.budget_check(routing_map, budget_local) + self.under_budget &= under_budget # if self.ep_rank == 0: # print (f'budget_local {budget_local} = {routing_map.shape[0]} x {self.config.moe_router_topk} x {self.packed_offloading_capacity_factor}') self._comm_manager.setup_metadata(routing_map, probs, budget_local) + return hidden_states, self._comm_manager.token_probs def token_dispatch( @@ -1570,3 +1574,25 @@ def combine_postprocess(self, hidden_states: torch.Tensor): The final MoE layer output reshaped to its original dimensions. 
""" return hidden_states.view(self.hidden_shape) + + def budget_check(self, routing_map, budget): + # TODO: the check should be done in hybridep + num_experts = self.config.num_moe_experts + num_expert_per_ep_rank = num_experts // self.ep_size + num_tokens_per_expert = routing_map.sum(dim=0) + num_tokens_per_expert = ( + gather_from_sequence_parallel_region( + num_tokens_per_expert, group=self.tp_ep_group + ) + .reshape(self.ep_size, self.tp_size, num_experts) + .transpose(0, 1) + ) + + num_global_tokens_per_expert =num_tokens_per_expert.sum(dim=1) + if self.config.fp8: + pad_multiple = get_fp8_align_size(self.config.fp8_recipe) + num_global_tokens_per_expert += -num_global_tokens_per_expert % pad_multiple + num_tokens_per_ep_rank = num_global_tokens_per_expert.view(num_global_tokens_per_expert.shape[0], self.ep_size, -1).sum(dim=-1) + + routing_map_maybe_dropped, under_budget = drop_routing_map_triton(routing_map, budget, num_tokens_per_ep_rank) + return routing_map_maybe_dropped, under_budget From 2947a029c773216e2cb34dd162c6bcefc9d4a534 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Sun, 23 Nov 2025 10:42:10 +0800 Subject: [PATCH 10/57] Reconstruct the stash buffer into a 2D structure Fix race condition in triton kernels --- .../pipeline_parallel/moe_packed_offload.py | 371 ++++++++++++++---- 1 file changed, 292 insertions(+), 79 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 33849fc8321..84c14ac44f9 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -59,6 +59,172 @@ def set_ideal_affinity_for_current_gpu(): pynvml.nvmlDeviceSetCpuAffinity(handle) GLOBAL_BLOCK_SIZE = 1024 + +@triton.jit +def _stash_copy_kernel_2d( + src_ptr, + dst_ptr, + num_tokens_ptr, # Number of tokens to copy + alloc_offset_ptr, # In tokens (read-only) + free_offset_ptr, # In tokens (read-only) + capacity_ptr, # In tokens (read-only) + overflow_ptr, + new_free_offset_ptr, # Output: new free_offset value (written by kernel) + HIDDEN_SIZE: tl.constexpr, # Hidden dimension (compile-time constant) + BLOCK_SIZE: tl.constexpr, # Threads per block (for hidden dimension) + tokens_per_block: tl.constexpr, # Number of tokens each block handles +): + """2D Triton kernel to copy tensor data to stash buffer. + + Grid: (num_blocks,) - fixed number of blocks + Each block handles multiple tokens (tokens_per_block) using a while loop. + Works directly with contiguous 2D tensors [tokens, hidden_size]. + Offsets are tracked in tokens, not elements. 
+ """ + pid = tl.program_id(axis=0) + + # Load parameters (in tokens, not elements) + num_tokens = tl.load(num_tokens_ptr) + alloc_offset = tl.load(alloc_offset_ptr) + free_offset = tl.load(free_offset_ptr) + capacity = tl.load(capacity_ptr) + + # All blocks check for overflow (same computation, avoids race condition) + if free_offset >= alloc_offset: + # No wraparound: available space is from free_offset to capacity, then 0 to alloc_offset + avail_space = capacity - (free_offset - alloc_offset) + else: + # Wraparound: available space is from free_offset to alloc_offset + avail_space = alloc_offset - free_offset + overflow_detected = avail_space < num_tokens + + # Only block 0 writes the overflow flag + if pid == 0 and overflow_detected: + tl.store(overflow_ptr, 1) + + # All blocks return early if overflow detected + if overflow_detected: + return + + # Each block handles multiple tokens + token_start = pid * tokens_per_block + token_end = min(token_start + tokens_per_block, num_tokens) + + # Process tokens assigned to this block + token_idx = token_start + while token_idx < token_end: + # Calculate destination token index with wraparound + dst_token_idx = free_offset + token_idx + if dst_token_idx >= capacity: + dst_token_idx = dst_token_idx - capacity + + # Each thread handles elements of the hidden dimension + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + + # Check if we need masking (only if HIDDEN_SIZE not divisible by BLOCK_SIZE) + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + + # 2D indexing: base + token_idx * HIDDEN_SIZE + hidden_offsets + src_base = src_ptr + token_idx * HIDDEN_SIZE + dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE + + if need_mask: + # Use mask for all iterations when HIDDEN_SIZE not divisible by BLOCK_SIZE + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + # No mask needed - HIDDEN_SIZE is multiple of BLOCK_SIZE + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + + token_idx += 1 + + # Update new_free_offset (only first block writes it) + if pid == 0: + new_free_offset = free_offset + num_tokens + if new_free_offset >= capacity: + new_free_offset = new_free_offset - capacity + tl.store(new_free_offset_ptr, new_free_offset) + +@triton.jit +def _stash_pop_kernel_2d( + src_ptr, + dst_ptr, + num_tokens_ptr, # Number of tokens to reload + tensor_offset_ptr, # In tokens - where data was stashed (read-only) + alloc_offset_ptr, # In tokens (read-only, not used in pop) + free_offset_ptr, # In tokens (write: updated directly by kernel) + capacity_ptr, # In tokens (read-only) + HIDDEN_SIZE: tl.constexpr, # Hidden dimension (compile-time constant) + BLOCK_SIZE: tl.constexpr, # Threads per block (for hidden dimension) + tokens_per_block: tl.constexpr, # Number of tokens each block handles +): + """2D Triton kernel to reload tensor data from stash buffer. + + Grid: (num_blocks,) - fixed number of blocks + Each block handles multiple tokens (tokens_per_block) using a while loop. + Works directly with contiguous 2D tensors [tokens, hidden_size]. + Offsets are tracked in tokens, not elements. 
+ Uses LIFO (stack) semantics - moves free_offset backward after popping. + """ + pid = tl.program_id(axis=0) + + # Load parameters (in tokens, not elements) + num_tokens = tl.load(num_tokens_ptr) + tensor_offset = tl.load(tensor_offset_ptr) # Where data was stashed + capacity = tl.load(capacity_ptr) + + # Each block handles multiple tokens + token_start = pid * tokens_per_block + token_end = min(token_start + tokens_per_block, num_tokens) + + # Process tokens assigned to this block + token_idx = token_start + while token_idx < token_end: + # Calculate source token index with wraparound + src_token_idx = tensor_offset + token_idx + if src_token_idx >= capacity: + src_token_idx = src_token_idx - capacity + + # Each thread handles elements of the hidden dimension + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + + # Check if we need masking + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + + # 2D indexing + src_base = src_ptr + src_token_idx * HIDDEN_SIZE + dst_base = dst_ptr + token_idx * HIDDEN_SIZE + + if need_mask: + # Use mask for all iterations when HIDDEN_SIZE not divisible by BLOCK_SIZE + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + # No mask needed + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + + token_idx += 1 + + # For LIFO (stack) behavior: move free_offset backward + # After popping, free_offset should be at tensor_offset (freeing the space we just read) + if pid == 0: + # The data was stashed at tensor_offset, so after popping, free_offset moves back to tensor_offset + tl.store(free_offset_ptr, tensor_offset) + @triton.jit def _stash_copy_kernel( src_ptr, @@ -227,32 +393,48 @@ def _stash_pop_kernel( class StashBuffer: """ - A class to represent a stash buffer. + A class to represent a 2D stash buffer. + + The buffer is organized as [num_tokens, hidden_size]. + Offsets (free_offset, alloc_offset) are tracked in tokens, not elements. 
""" - def __init__(self, size, device, overflow, dtype): - + def __init__(self, num_tokens, hidden_size, device, overflow, dtype): + """ + Args: + num_tokens: Maximum number of tokens the buffer can hold + hidden_size: Hidden dimension size + device: Device for the buffer + overflow: Overflow flag tensor (shared across all buffers) + dtype: Data type + """ self.buffer = None + self.hidden_size = hidden_size + self.num_tokens_capacity = num_tokens + + # Create 2D buffer [num_tokens, hidden_size] if os.getenv('PACKED_OFFLOAD_CPU', '0') == '1': - self.buffer = torch.empty(size, dtype=dtype, device='cpu', pin_memory=True) + self.buffer = torch.empty((num_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True) else: - self.buffer = torch.empty(size, dtype=dtype, device=device) - self.overflow = overflow # GPU flag + self.buffer = torch.empty((num_tokens, hidden_size), dtype=dtype, device=device) + + self.overflow = overflow # GPU flag (shared) self.device = device - self.free_offset = torch.zeros(1, dtype=torch.int64, device=device) # start offset of free space - self.alloc_offset = torch.zeros(1, dtype=torch.int64, device=device) # start offset of allocations + + # Offsets are in TOKENS + self.free_offset = torch.zeros(1, dtype=torch.int64, device=device) # tail (write pointer) + self.alloc_offset = torch.zeros(1, dtype=torch.int64, device=device) # head (read pointer) self.capacity = torch.zeros(1, dtype=torch.int64, device=device) - self.capacity.fill_(size) + self.capacity.fill_(num_tokens) # Capacity in tokens self.dtype = dtype + def reset(self): - """Reset the stash buffer.""" - #assert self.alloc_offset.item() == self.free_offset.item(), f"alloc_offset {self.alloc_offset.item()} != free_offset {self.free_offset.item()}" - #print + """Reset the stash buffer offsets.""" self.free_offset.zero_() self.alloc_offset.zero_() def __repr__(self): - return f"StashBuffer(capacity={self.capacity}, device={self.device})" + return f"StashBuffer(capacity={self.num_tokens_capacity} tokens, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" class PackedTensor: @@ -269,7 +451,7 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non self.max_tokens = max_tokens # Original tensor information self.original_shape = list(tensor.shape) - self.num_elements = tensor.numel() + self.max_num_tokens = self.original_shape[0] self.element_size = tensor.element_size() self.hidden_size = self.original_shape[1] self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype @@ -307,41 +489,51 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): self._tensor = self._tensor.contiguous() if self.num_tokens_tensor.dim() == 0: self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) - num_elements_tensor = self.num_tokens_tensor.mul(self.hidden_size) - # Flatten the tensor to get total number of elements - flat_tensor = self._tensor.flatten() if not isinstance(self._tensor, MXFP8Tensor) else self._tensor._columnwise_data.flatten() + + # Get 2D tensor (no flattening) + if isinstance(self._tensor, MXFP8Tensor): + tensor_to_copy = self._tensor._columnwise_data + else: + tensor_to_copy = self._tensor # Determine grid size with cap on max blocks BLOCK_SIZE = GLOBAL_BLOCK_SIZE - max_size = flat_tensor.numel() - total_blocks_needed = triton.cdiv(max_size, BLOCK_SIZE) + total_blocks_needed = self.max_num_tokens # Ideally 1 block per token - # Cap the number of blocks and calculate iterations per block + # Cap the 
number of blocks and calculate tokens per block num_blocks = min(total_blocks_needed, max_blocks) - num_iterations = triton.cdiv(total_blocks_needed, num_blocks) + tokens_per_block = triton.cdiv(self.max_num_tokens, num_blocks) if DEBUG: - debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} num_elements {num_elements_tensor.item()} max_blocks {max_blocks} total_blocks_needed {total_blocks_needed} num_blocks {num_blocks} num_iterations {num_iterations} oveflow {stash_buffer.overflow.item()}") + debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} tokens_per_block {tokens_per_block} overflow {stash_buffer.overflow.item()}") # grid = (num_blocks,) self.stash_buffer_offset = stash_buffer.free_offset.clone() - # Launch Triton kernel to copy data + # Create temporary tensor for new offset (kernel will write to this) + new_free_offset_tensor = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch Triton kernel to copy data (2D version) # self.offload_stream.wait_stream(torch.cuda.current_stream()) # with torch.cuda.stream(self.offload_stream): # TODO: make this async. Something unexpected with TE on deallocate the tensor - _stash_copy_kernel[grid]( - flat_tensor, + _stash_copy_kernel_2d[grid]( + tensor_to_copy, stash_buffer.buffer, - num_elements_tensor, - stash_buffer.alloc_offset, # Read-only: Write boundary - stash_buffer.free_offset, # Read+Write: Start offset for next offload - stash_buffer.capacity, # Read-only: Capacity of the buffer - stash_buffer.overflow, # Read+Write: Over capacity flag updated by kernel + self.num_tokens_tensor, # Use stored num_tokens (not from shape) + stash_buffer.alloc_offset, # Read-only: Write boundary (in tokens) + stash_buffer.free_offset, # Read-only: Current offset + stash_buffer.capacity, # Read-only: Capacity of the buffer (in tokens) + stash_buffer.overflow, # Read+Write: Over capacity flag + new_free_offset_tensor, # Write: New free_offset computed by kernel + HIDDEN_SIZE=self.hidden_size, BLOCK_SIZE=BLOCK_SIZE, - num_iterations=num_iterations, - max_tokens=self.max_tokens*self.hidden_size, + tokens_per_block=tokens_per_block, ) + + # Copy new offset value after kernel completes (stream-ordered) + stash_buffer.free_offset.copy_(new_free_offset_tensor) + # save reference to original tensor to avoid deallocation before offload is complete self._original_tensor = self._tensor # set tensor to None. This will be replaced by reload_from_stash. 
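For reference, the ring-buffer behaviour that the rewritten offload path drives through _stash_copy_kernel_2d can be summarised as host-side PyTorch. This is an illustrative sketch only; the helper name stash_push_reference is hypothetical, and the real kernel keeps num_tokens and the offsets in one-element GPU tensors so the push never synchronizes with the host.

import torch

def stash_push_reference(src, num_tokens, buf, free_offset, alloc_offset, capacity, overflow):
    # Sketch only: num_tokens / free_offset / alloc_offset / capacity are plain Python
    # ints here; the Triton kernel reads them from one-element GPU tensors instead.
    # src: [max_num_tokens, hidden_size] activations, buf: [capacity, hidden_size] ring buffer.
    used = (free_offset - alloc_offset) % capacity
    if capacity - used < num_tokens:
        overflow.fill_(1)          # shared overflow flag, inspected by the caller later
        return free_offset, free_offset
    rows = (free_offset + torch.arange(num_tokens, device=buf.device)) % capacity
    buf[rows] = src[:num_tokens]   # token-granular copy with wraparound
    return free_offset, (free_offset + num_tokens) % capacity

As in the hunk above, the advanced write pointer is not updated in place by the kernel; it is written to new_free_offset_tensor and then applied with free_offset.copy_(...), which keeps the pointer update ordered on the current stream.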
@@ -366,44 +558,45 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): columnwise_scale_inv=self._original_tensor._columnwise_scale_inv, quantizer=self._original_tensor._quantizer, ) - flat_tensor = self._tensor._columnwise_data.flatten() + tensor_to_reload = self._tensor._columnwise_data else: self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) - flat_tensor = self._tensor.flatten() - - num_elements_tensor = self.num_tokens_tensor.mul(self.hidden_size) + tensor_to_reload = self._tensor + # Determine grid size with cap on max blocks BLOCK_SIZE = GLOBAL_BLOCK_SIZE - max_size = self.num_elements - total_blocks_needed = triton.cdiv(max_size, BLOCK_SIZE) + total_blocks_needed = self.max_num_tokens # Ideally 1 block per token - # Cap the number of blocks and calculate iterations per block + # Cap the number of blocks and calculate tokens per block num_blocks = min(total_blocks_needed, max_blocks) - num_iterations = triton.cdiv(total_blocks_needed, num_blocks) + tokens_per_block = triton.cdiv(self.max_num_tokens, num_blocks) if DEBUG: - debug_print (f"reload_from_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} num_elements {num_elements_tensor.item()} max_blocks {max_blocks} total_blocks_needed {total_blocks_needed} num_blocks {num_blocks} num_iterations {num_iterations}") + debug_print (f"reload_from_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} tokens_per_block {tokens_per_block}") # grid = (num_blocks,) - # Launch Triton kernel to copy data + # Launch Triton kernel to copy data (2D version) # self.offload_stream.wait_stream(torch.cuda.current_stream()) # with torch.cuda.stream(self.offload_stream): # TODO: make this async. 
Something unexpected with TE on deallocate the tensor - _stash_pop_kernel[grid]( + # Note: free_offset is directly updated by the kernel (LIFO stack behavior) + _stash_pop_kernel_2d[grid]( stash_buffer.buffer, - flat_tensor, - num_elements_tensor, - self.stash_buffer_offset, # Read-only: Start offset for reload - stash_buffer.alloc_offset, # Read+write: Free stash buffer for model chunk - stash_buffer.free_offset, # Read: Start offset for offload - stash_buffer.capacity, # Read-only: Capacity of the buffer + tensor_to_reload, + self.num_tokens_tensor, # Use stored num_tokens (not from shape) + self.stash_buffer_offset, # Read-only: Start offset for reload (in tokens) + stash_buffer.alloc_offset, # Read-only: Not used in pop kernel + stash_buffer.free_offset, # Write: Moved backward by kernel (LIFO) + stash_buffer.capacity, # Read-only: Capacity of the buffer (in tokens) + HIDDEN_SIZE=self.hidden_size, BLOCK_SIZE=BLOCK_SIZE, - num_iterations=num_iterations, + tokens_per_block=tokens_per_block, ) + #torch.cuda.synchronize() if DEBUG: debug_print (f"After reload_from_stash reload_offset {self.stash_buffer_offset.item()} alloc_offset {stash_buffer.alloc_offset.item()} free_offset {stash_buffer.free_offset.item()} capacity {stash_buffer.capacity.item()}") @@ -511,9 +704,9 @@ def __init__(self): self.current_microbatch = None self.current_schedule_index = None - self.page_size = GLOBAL_BLOCK_SIZE - self.max_pages_per_vp_stage = None - self.temp_pages_per_vp_stage = None + # Track max tokens needed per vp_stage, dtype, and hidden_size + self.max_tokens_per_vp_stage = None + self.temp_tokens_per_vp_stage = None self.num_tokens_tensor = None self.max_num_tokens = None self.stash_buffers = None @@ -568,7 +761,8 @@ def offload_packed_tensors(self, pp_schedule_layer): while len(self.packed_tensors_to_offload) > 0: packed_tensor = self.packed_tensors_to_offload.pop(0) - packed_tensor.offload_to_stash(self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype]) + stash_buffer = self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype][packed_tensor.hidden_size] + packed_tensor.offload_to_stash(stash_buffer) self.packed_tensors_to_reload[pp_schedule_layer].append(packed_tensor) self.packed_tensors_offload_in_progress.append(packed_tensor) else: @@ -591,23 +785,33 @@ def reload_packed_tensors(self, pp_schedule_layer): debug_print(f"reload_packed_tensors {count}") while len(self.packed_tensors_to_reload[pp_schedule_layer]) > 0: packed_tensor = self.packed_tensors_to_reload[pp_schedule_layer].pop(0) - packed_tensor.reload_from_stash(self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype]) + stash_buffer = self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype][packed_tensor.hidden_size] + packed_tensor.reload_from_stash(stash_buffer) else: pass assert len(self.packed_tensors_to_reload[pp_schedule_layer]) == 0, f"packed_tensors_to_reload {pp_schedule_layer} is not empty {self.packed_tensors_to_reload[pp_schedule_layer]}" - def allocate_offload_pages(self, stash_buffer_size_factor=1.10): - """Allocate offload pages for each vp stage.""" + def allocate_offload_buffers(self, stash_buffer_size_factor=1.10): + """Allocate offload buffers for each vp stage, organized by [vp_stage][dtype][hidden_size].""" self.stash_buffers = [] self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) + for vp_stage in range(self.vp_size): self.stash_buffers.append({}) - for dtype in self.max_pages_per_vp_stage[vp_stage]: - self.max_pages_per_vp_stage[vp_stage][dtype] = 
int(self.max_pages_per_vp_stage[vp_stage][dtype] * stash_buffer_size_factor) - self.stash_buffers[vp_stage][dtype] = StashBuffer(self.max_pages_per_vp_stage[vp_stage][dtype]*GLOBAL_BLOCK_SIZE, self.device, self.overflow, dtype) - if torch.distributed.get_rank() == 0: - print(f'allocated stash buffer {vp_stage} {dtype} {self.stash_buffers[vp_stage][dtype]}') + for dtype in self.max_tokens_per_vp_stage[vp_stage]: + self.stash_buffers[vp_stage][dtype] = {} + for hidden_size in self.max_tokens_per_vp_stage[vp_stage][dtype]: + # Calculate number of tokens we can store (with safety factor) + num_tokens = int(self.max_tokens_per_vp_stage[vp_stage][dtype][hidden_size] * stash_buffer_size_factor) + + # Create 2D buffer + self.stash_buffers[vp_stage][dtype][hidden_size] = StashBuffer( + num_tokens, hidden_size, self.device, self.overflow, dtype + ) + + if torch.distributed.get_rank() == 0: + print(f'allocated stash buffer vp_stage={vp_stage} dtype={dtype} hidden_size={hidden_size}: {self.stash_buffers[vp_stage][dtype][hidden_size]}') def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): """Update the pp schedule.""" @@ -658,15 +862,26 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if self.status == 'capture': self.num_tokens = self.num_tokens_tensor.item() - num_elements = tensor.numel() * self.num_tokens // self.max_num_tokens - num_pages = (num_elements + self.page_size - 1) // self.page_size dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype - if dtype not in self.temp_pages_per_vp_stage[self.current_vp_stage]: - self.temp_pages_per_vp_stage[self.current_vp_stage][dtype] = 0 - self.max_pages_per_vp_stage[self.current_vp_stage][dtype] = 0 - self.temp_pages_per_vp_stage[self.current_vp_stage][dtype] += num_pages - self.max_pages_per_vp_stage[self.current_vp_stage][dtype] = max(self.max_pages_per_vp_stage[self.current_vp_stage][dtype], self.temp_pages_per_vp_stage[self.current_vp_stage][dtype]) + # Get hidden_size from tensor shape + if isinstance(tensor, MXFP8Tensor): + hidden_size = tensor._columnwise_data.shape[1] if tensor._columnwise_data.ndim > 1 else tensor._columnwise_data.numel() + else: + hidden_size = tensor.shape[1] if tensor.ndim > 1 else tensor.numel() + + if dtype not in self.temp_tokens_per_vp_stage[self.current_vp_stage]: + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} + if hidden_size not in self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype]: + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 + + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] += self.num_tokens + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = max( + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] + ) # Since capture stage does not use CUDA graph, we can truncate the saved tensor to actual num_tokens # Truncate the tensor to the actual number of tokens @@ -694,9 +909,7 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: if isinstance(saved_state, PackedTensor): if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() - num_elements = saved_state.num_elements * num_tokens // self.max_num_tokens - num_pages = (num_elements + self.page_size - 1) // 
self.page_size - self.temp_pages_per_vp_stage[saved_state.vp_stage][saved_state.dtype] -= num_pages + self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][saved_state.hidden_size] -= num_tokens # Pad the tensor to the max number of tokens npad = self.max_num_tokens - num_tokens @@ -743,7 +956,6 @@ def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num offload_manager = PackedOffloadManager.get_instance() offload_manager.max_num_tokens = max_num_tokens assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) - offload_manager.num_tokens_tensor = num_tokens_tensor offload_manager.set_current_layer_name(name) if name is not None else None pack_unpack_context = torch.autograd.graph.saved_tensors_hooks(offload_manager.on_save_for_backward, offload_manager.on_get_saved_tensor) @@ -767,9 +979,9 @@ def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): offload_manager.vp_size = vp_size else: offload_manager.vp_size = 1 - if offload_manager.max_pages_per_vp_stage is None: - offload_manager.max_pages_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] - offload_manager.temp_pages_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] + if offload_manager.max_tokens_per_vp_stage is None: + offload_manager.max_tokens_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] + offload_manager.temp_tokens_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] def packed_moe_expert_offloading_set_last_layer(is_last_layer=False): """Set the last layer flag.""" @@ -802,9 +1014,9 @@ def packed_moe_expert_offloading_reset(enabled=True): elif offload_manager.status == 'capture': offload_manager.status = 'captured' stash_buffer_size_factor = float(os.getenv('STASH_BUFFER_SIZE_FACTOR', '1.10')) - offload_manager.allocate_offload_pages(stash_buffer_size_factor=stash_buffer_size_factor) + offload_manager.allocate_offload_buffers(stash_buffer_size_factor=stash_buffer_size_factor) debug_print(f'packed_moe_expert_offloading_reset captured schedule: {offload_manager._pp_schedule}') - debug_print(f'packed_moe_expert_offloading_reset max_pages_per_vp_stage: {offload_manager.max_pages_per_vp_stage}') + debug_print(f'packed_moe_expert_offloading_reset max_tokens_per_vp_stage: {offload_manager.max_tokens_per_vp_stage}') elif offload_manager.status == 'captured': pass else: @@ -817,7 +1029,8 @@ def packed_moe_expert_offloading_reset(enabled=True): for vp_buffers in offload_manager.stash_buffers: for dtype in vp_buffers.keys(): - vp_buffers[dtype].reset() + for hidden_size in vp_buffers[dtype].keys(): + vp_buffers[dtype][hidden_size].reset() offload_manager.overflow.zero_() offload_manager.current_layer = [1 for _ in range(offload_manager.vp_size)] offload_manager.current_microbatch = [1 for _ in range(offload_manager.vp_size)] From a45b7fe42419105ea2bd8ae58b9d45917bf6fe0b Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Mon, 24 Nov 2025 11:37:49 +0800 Subject: [PATCH 11/57] Refactor the code to check overflow in HybridEP receiving buffer --- megatron/core/full_cuda_graph.py | 23 +++-- megatron/core/transformer/moe/moe_utils.py | 24 +++--- .../core/transformer/moe/token_dispatcher.py | 83 +++++++++---------- 3 files changed, 66 insertions(+), 64 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 2d7686d860a..38257e8cf7c 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -191,18 +191,23 @@ def __call__(self, *args, **kwargs): if 
training_str == 'training': packed_moe_expert_offloading_reset(enabled=self.packed_moe_expert_offloading) FullCudaGraphWrapper.cuda_graph[training_str].replay() - - # Check if there is any overflow in the receiving buffer - # TODO: Hacky for now. Will improve this after moving the budget check logic into HybridEP - for model_chunk in model: - for layer in model_chunk.module.module.decoder.layers: - mlp = layer.mlp - if hasattr(mlp, 'token_dispatcher'): - if not mlp.token_dispatcher.under_budget.item(): - raise Exception(f"Rank {torch.distributed.get_rank()} overbudget") + self.speculative_cuda_graph_check(model) self.next_iter(training_str) return FullCudaGraphWrapper.result[training_str] + def speculative_cuda_graph_check(self, model): + ''' check speculative execution modules ''' + if self.packed_moe_expert_offloading: + # Check if there is any overflow in the receiving buffer + over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') + for model_chunk in model: + for layer in model_chunk.module.module.decoder.layers: + mlp = layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr(mlp.token_dispatcher, 'check_over_budget'): + over_budget |= mlp.token_dispatcher.check_over_budget() + if over_budget.item(): + raise Exception(f"Rank {torch.distributed.get_rank()} overbudget") + def curr_iter(self, stage): """Return current training/validation iteration.""" return FullCudaGraphWrapper.curr_iteration[stage] diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index dc4181df43c..472477661eb 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -1548,7 +1548,7 @@ def wrapped_func(moe_layer, *args, **kwargs): @triton.jit def _drop_routing_map_kernel( routing_map_ptr, - under_budget_ptr, + over_budget_ptr, routing_map_dropped_ptr, num_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -1557,7 +1557,7 @@ def _drop_routing_map_kernel( Args: routing_map_ptr: Pointer to the input routing_map tensor - under_budget_ptr: Pointer to the boolean tensor indicating if all EP ranks are under budget + over_budget_ptr: Pointer to the boolean tensor indicating if any EP rank is over budget routing_map_dropped_ptr: Pointer to the output routing_map tensor num_elements: Total number of elements to process BLOCK_SIZE: Block size for Triton kernel @@ -1565,8 +1565,8 @@ def _drop_routing_map_kernel( # Get the program ID pid = tl.program_id(axis=0) - # Read the under_budget value (scalar tensor with single element) - under_budget_val = tl.load(under_budget_ptr) + # Read the over_budget value (scalar tensor with single element) + over_budget_val = tl.load(over_budget_ptr) # Calculate the offset for this program offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -1575,8 +1575,8 @@ def _drop_routing_map_kernel( mask = offset < num_elements routing_map_val = tl.load(routing_map_ptr + offset, mask=mask, other=0.0) - # Multiply routing_map by under_budget: if under_budget is 0 (False), output is 0; if 1 (True), output is routing_map_val - output_val = routing_map_val * under_budget_val + # If over_budget is 1 (True), output is 0 (drop); if over_budget is 0 (False), output is routing_map_val (keep) + output_val = routing_map_val * (1 - over_budget_val) # Store the result tl.store(routing_map_dropped_ptr + offset, output_val, mask=mask) @@ -1599,11 +1599,11 @@ def drop_routing_map_triton( exceeds budget, otherwise returns the original routing_map. 
""" - # Calculate boolean tensor: under_budget is True only if ALL EP ranks are under budget - under_budget = (num_tokens_per_ep_rank <= budget).all() + # Calculate boolean tensor: over_budget is True if ANY EP rank exceeds budget + over_budget = (num_tokens_per_ep_rank > budget).any() # Convert boolean to int8 - under_budget_int = under_budget.to(torch.int8) + over_budget_int = over_budget.to(torch.int8) # Convert routing_map to numeric type if it's boolean if routing_map.dtype == torch.bool: @@ -1622,10 +1622,10 @@ def drop_routing_map_triton( BLOCK_SIZE = 1024 grid = (triton.cdiv(num_elements, BLOCK_SIZE),) - # Launch kernel with under_budget tensor pointer (as int8) + # Launch kernel with over_budget tensor pointer (as int8) _drop_routing_map_kernel[grid]( routing_map_flat, - under_budget_int, + over_budget_int, routing_map_dropped.flatten(), num_elements, BLOCK_SIZE=BLOCK_SIZE, @@ -1635,4 +1635,4 @@ def drop_routing_map_triton( if routing_map.dtype == torch.bool: routing_map_dropped = routing_map_dropped.to(torch.bool) - return routing_map_dropped, under_budget.to(torch.bool) + return routing_map_dropped, over_budget.to(torch.bool) diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index e46faf8eb0c..b93ff8d4167 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1024,13 +1024,43 @@ def __init__( "https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep." ) - def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor, budget_local: int = None): + self.packed_offloading_capacity_factor = self.config.moe_expert_capacity_factor_for_packed_offloading + self.over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') + + def budget_check(self, routing_map, budget): + # TODO: the check should be done in hybridep to avoid the AG below + # routing_map: [num_local_tokens, world_size, num_local_experts] + num_local_tokens_per_expert = routing_map.sum(dim=0).flatten() + num_tokens_per_expert = torch.empty( + self.group.size(), + self.num_experts, + device=num_local_tokens_per_expert.device, + dtype=num_local_tokens_per_expert.dtype, + ) + torch.distributed.all_gather_into_tensor( + num_tokens_per_expert, num_local_tokens_per_expert, self.group + ) + + num_global_tokens_per_expert =num_tokens_per_expert.sum(dim=0) + if self.config.fp8: + pad_multiple = get_fp8_align_size(self.config.fp8_recipe) + num_global_tokens_per_expert += -num_global_tokens_per_expert % pad_multiple + num_tokens_per_ep_rank = num_global_tokens_per_expert.view(routing_map.shape[1], routing_map.shape[2]).sum(dim=-1) + routing_map_maybe_dropped, over_budget = drop_routing_map_triton(routing_map, budget, num_tokens_per_ep_rank) + return routing_map_maybe_dropped, over_budget + + def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): num_tokens = routing_map.shape[0] self.routing_map = routing_map.reshape(num_tokens, self.num_experts) self.token_probs = probs.reshape(num_tokens, self.num_experts) - if budget_local is not None: - self.num_dispatched_tokens = budget_local - self.num_permuted_tokens = budget_local + + if self.packed_offloading_capacity_factor is not None: + budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) + routing_map_maybe_dropped, over_budget = self.budget_check(routing_map, budget) + self.over_budget |= over_budget + self.num_dispatched_tokens = budget + self.num_permuted_tokens = budget + 
self.routing_map = routing_map_maybe_dropped.reshape(num_tokens, self.num_experts) # Compute the capacity for each expert at the drop_and_pad mode if self.drop_and_pad: num_out_tokens = num_tokens * self.config.moe_router_topk @@ -1405,9 +1435,6 @@ def __init__( "Please set --moe-flex-dispatcher-backend=deepep or " "--moe-flex-dispatcher-backend=hybridep" ) - self.packed_offloading_capacity_factor = self.config.moe_expert_capacity_factor_for_packed_offloading - self.budget_local_gpu = None - self.under_budget = torch.ones(1, dtype=torch.bool, device='cuda') def set_shared_experts(self, shared_experts): raise NotImplementedError( @@ -1440,13 +1467,6 @@ def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) - .reshape(num_local_tokens, world_size, self.num_local_experts) ).contiguous() - if self.packed_offloading_capacity_factor is not None: - budget_local = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) - #if self.ep_rank == 0: - # print (f'budget_local {budget_local} = {routing_map.shape[0]} x {self.config.moe_router_topk} x {self.packed_offloading_capacity_factor}') - self.budget_local_gpu = torch.full((1,), budget_local, device='cuda') - else: - self.budget_local_gpu = None return routing_map, probs def dispatch_preprocess( @@ -1472,14 +1492,7 @@ def dispatch_preprocess( # Initialize metadata routing_map, probs = self._initialize_metadata(routing_map, probs) - budget_local = None - if self.packed_offloading_capacity_factor is not None: - budget_local = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) - routing_map, under_budget = self.budget_check(routing_map, budget_local) - self.under_budget &= under_budget -# if self.ep_rank == 0: -# print (f'budget_local {budget_local} = {routing_map.shape[0]} x {self.config.moe_router_topk} x {self.packed_offloading_capacity_factor}') - self._comm_manager.setup_metadata(routing_map, probs, budget_local) + self._comm_manager.setup_metadata(routing_map, probs) return hidden_states, self._comm_manager.token_probs @@ -1575,24 +1588,8 @@ def combine_postprocess(self, hidden_states: torch.Tensor): """ return hidden_states.view(self.hidden_shape) - def budget_check(self, routing_map, budget): - # TODO: the check should be done in hybridep - num_experts = self.config.num_moe_experts - num_expert_per_ep_rank = num_experts // self.ep_size - num_tokens_per_expert = routing_map.sum(dim=0) - num_tokens_per_expert = ( - gather_from_sequence_parallel_region( - num_tokens_per_expert, group=self.tp_ep_group - ) - .reshape(self.ep_size, self.tp_size, num_experts) - .transpose(0, 1) - ) - - num_global_tokens_per_expert =num_tokens_per_expert.sum(dim=1) - if self.config.fp8: - pad_multiple = get_fp8_align_size(self.config.fp8_recipe) - num_global_tokens_per_expert += -num_global_tokens_per_expert % pad_multiple - num_tokens_per_ep_rank = num_global_tokens_per_expert.view(num_global_tokens_per_expert.shape[0], self.ep_size, -1).sum(dim=-1) - - routing_map_maybe_dropped, under_budget = drop_routing_map_triton(routing_map, budget, num_tokens_per_ep_rank) - return routing_map_maybe_dropped, under_budget + def check_over_budget(self): + if hasattr(self._comm_manager, 'over_budget'): + return self._comm_manager.over_budget + else: + return None From fe504bdc8b89df9b77bab276198b0664b5434002 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Mon, 24 Nov 2025 19:33:43 +0800 Subject: [PATCH 12/57] Use CPU offloading context manager as a WAR for now to WAR the 
problem of overlap_grad_reduce see https://github.com/NVIDIA/TransformerEngine/blob/e1edaaec2bb1e6542e0e2dff81d5217ff5e1eb89/transformer_engine/pytorch/module/grouped_linear.py#L229-L233 --- .../pipeline_parallel/moe_packed_offload.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 84c14ac44f9..2cb33cb39c1 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -945,6 +945,36 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: return saved_state +class PackedOffloadContext: + """Wrapper context manager that adds custom enter/exit behavior around saved_tensors_hooks.""" + + def __init__(self, offload_manager): + self.offload_manager = offload_manager + self.saved_tensors_context = torch.autograd.graph.saved_tensors_hooks( + offload_manager.on_save_for_backward, + offload_manager.on_get_saved_tensor + ) + + def __enter__(self): + from megatron.core.extensions.transformer_engine import cpu_offload + + if cpu_offload is not None: + cpu_offload.CPUOffloadEnabled = True + # Call the underlying context manager's __enter__ + result = self.saved_tensors_context.__enter__() + + # Add more custom logic after entering if needed + return result + + def __exit__(self, *args: Any): + # Call the underlying context manager's __exit__ + result = self.saved_tensors_context.__exit__(*args) + from megatron.core.extensions.transformer_engine import cpu_offload + + if cpu_offload is not None: + cpu_offload.CPUOffloadEnabled = False + return result + def packed_moe_expert_offloading_group_start(tensor, name=None): """Mark the start of a layer group and prepare for offload/reload.""" rank = torch.distributed.get_rank() @@ -958,7 +988,7 @@ def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) offload_manager.num_tokens_tensor = num_tokens_tensor offload_manager.set_current_layer_name(name) if name is not None else None - pack_unpack_context = torch.autograd.graph.saved_tensors_hooks(offload_manager.on_save_for_backward, offload_manager.on_get_saved_tensor) + pack_unpack_context = PackedOffloadContext(offload_manager) return pack_unpack_context def packed_moe_expert_offloading_group_commit(tensor, name=None): From 93fb18398c46b394158ee5970352bc025cc607c4 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Tue, 25 Nov 2025 13:12:21 +0800 Subject: [PATCH 13/57] Add support for paged stashing --- .../pipeline_parallel/moe_packed_offload.py | 621 ++++++++++++------ megatron/core/pipeline_parallel/schedules.py | 12 +- 2 files changed, 441 insertions(+), 192 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 2cb33cb39c1..f9515fdda80 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -114,9 +114,7 @@ def _stash_copy_kernel_2d( token_idx = token_start while token_idx < token_end: # Calculate destination token index with wraparound - dst_token_idx = free_offset + token_idx - if dst_token_idx >= capacity: - dst_token_idx = dst_token_idx - capacity + dst_token_idx = (free_offset + token_idx) % capacity # Each thread handles elements of the hidden dimension elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE @@ -147,9 
+145,7 @@ def _stash_copy_kernel_2d( # Update new_free_offset (only first block writes it) if pid == 0: - new_free_offset = free_offset + num_tokens - if new_free_offset >= capacity: - new_free_offset = new_free_offset - capacity + new_free_offset = (free_offset + num_tokens) % capacity tl.store(new_free_offset_ptr, new_free_offset) @triton.jit @@ -188,9 +184,7 @@ def _stash_pop_kernel_2d( token_idx = token_start while token_idx < token_end: # Calculate source token index with wraparound - src_token_idx = tensor_offset + token_idx - if src_token_idx >= capacity: - src_token_idx = src_token_idx - capacity + src_token_idx = (tensor_offset + token_idx) % capacity # Each thread handles elements of the hidden dimension elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE @@ -225,171 +219,6 @@ def _stash_pop_kernel_2d( # The data was stashed at tensor_offset, so after popping, free_offset moves back to tensor_offset tl.store(free_offset_ptr, tensor_offset) -@triton.jit -def _stash_copy_kernel( - src_ptr, - dst_ptr, - size_ptr, - alloc_offset_ptr, - free_offset_ptr, - capacity_ptr, - overflow_ptr, - BLOCK_SIZE: tl.constexpr, - num_iterations: tl.constexpr, - max_tokens: tl.constexpr, -): - """Triton kernel to copy tensor data to stash buffer. - - Each block can handle multiple chunks of data (num_iterations) to limit total blocks. - Ignores out-of-bound writes if offset + size exceeds capacity. - - Args: - src_ptr: Pointer to source tensor (flattened) - dst_ptr: Pointer to destination buffer (stash_buffer) - size_ptr: Pointer to scalar tensor containing the size to copy - offset_original_ptr: Pointer to GPU tensor containing original offset (read-only) - over_capacity_ptr: Pointer to counter tensor (incremented when over capacity) - capacity: Total capacity of the buffer - BLOCK_SIZE: Block size for Triton kernel - num_iterations: Number of iterations each block should handle - """ - # Get the program ID - pid = tl.program_id(axis=0) - num_programs = tl.num_programs(axis=0) - - # Load the size value from GPU tensor - size = tl.load(size_ptr) - - # Load original offset from GPU tensor (for position calculations) - alloc_offset = tl.load(alloc_offset_ptr) - free_offset = tl.load(free_offset_ptr) - capacity = tl.load(capacity_ptr) - # Only the first thread checks capacity - # Do this BEFORE the loop so it always happens - overflow = False - # Check if over capacity and increment counter - avail_space = free_offset - alloc_offset - if avail_space < 0: - avail_space = -avail_space - else: - avail_space = capacity - avail_space - if avail_space < size or max_tokens < size: - overflow = True - if pid == 0 and overflow: - tl.store(overflow_ptr, 1) - - #if pid == 1: - # tl.device_print("free_offset: ", free_offset) - if overflow: - return - - # Each block handles num_iterations chunks of BLOCK_SIZE elements - # Use while loop with early exit condition in the loop test - iteration = 0 - block_start = (pid * num_iterations + iteration) * BLOCK_SIZE - while iteration < num_iterations and block_start < size: - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - # Create mask for valid elements within source size - src_mask = offsets < size - - # Create mask for valid destination indices (within buffer capacity) - dst_indices = free_offset + offsets - dst_mask = dst_indices >= capacity - dst_indices = tl.where(dst_mask, dst_indices - capacity, dst_indices) - - # Load from source - src_data = tl.load(src_ptr + offsets, mask=src_mask, other=0.0) - - # Store to destination (ignores out-of-bound writes) - 
tl.store(dst_ptr + dst_indices, src_data, mask=src_mask) - - # Move to next iteration - iteration += 1 - block_start = (pid * num_iterations + iteration) * BLOCK_SIZE - - # Check if over capacity and increment counter - size_page_aligned = tl.cdiv(size, BLOCK_SIZE) * BLOCK_SIZE - - free_offset = free_offset + size_page_aligned - if free_offset > capacity: - free_offset -= capacity - if pid == 0: - tl.store(free_offset_ptr, free_offset) - -@triton.jit -def _stash_pop_kernel( - src_ptr, - dst_ptr, - size_ptr, - tensor_offset_ptr, - alloc_offset_ptr, - free_offset_ptr, - capacity_ptr, - BLOCK_SIZE: tl.constexpr, - num_iterations: tl.constexpr, -): - """Triton kernel to copy tensor data from stash buffer. - - Each block can handle multiple chunks of data (num_iterations) to limit total blocks. - Ignores out-of-bound writes if offset + size exceeds capacity. - - Args: - src_ptr: Pointer to source tensor (flattened) - dst_ptr: Pointer to destination buffer (stash_buffer) - size_ptr: Pointer to scalar tensor containing the size to copy - offset_original_ptr: Pointer to GPU tensor containing original offset (read-only) - over_capacity_ptr: Pointer to counter tensor (incremented when over capacity) - capacity: Total capacity of the buffer - BLOCK_SIZE: Block size for Triton kernel - num_iterations: Number of iterations each block should handle - """ - # Get the program ID - pid = tl.program_id(axis=0) - num_programs = tl.num_programs(axis=0) - - # Load the size value from GPU tensor - size = tl.load(size_ptr) - - # Load original offset from GPU tensor (for position calculations) - tensor_offset = tl.load(tensor_offset_ptr) - alloc_offset = tl.load(alloc_offset_ptr) - free_offset = tl.load(free_offset_ptr) - capacity = tl.load(capacity_ptr) - - # Each block handles num_iterations chunks of BLOCK_SIZE elements - # Use while loop with early exit condition in the loop test - iteration = 0 - block_start = (pid * num_iterations + iteration) * BLOCK_SIZE - while iteration < num_iterations and block_start < size: - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - # Create mask for valid elements within source size - dst_mask = offsets < size - - # Create mask for valid destination indices (within buffer capacity) - src_indices = tensor_offset + offsets - src_mask = src_indices >= capacity - src_indices = tl.where(src_mask, src_indices - capacity, src_indices) - - # Load from source - src_data = tl.load(src_ptr + src_indices, mask=dst_mask, other=0.0) - - # Store to destination (ignores out-of-bound writes) - tl.store(dst_ptr + offsets, src_data, mask=dst_mask) - - # Move to next iteration - iteration += 1 - block_start = (pid * num_iterations + iteration) * BLOCK_SIZE - - # Check if over capacity and increment counter - size_page_aligned = tl.cdiv(size, BLOCK_SIZE) * BLOCK_SIZE - tensor_offset = tensor_offset + size_page_aligned - if tensor_offset > capacity: - tensor_offset -= capacity - if pid == 0: - mask = tensor_offset > alloc_offset - tl.store(alloc_offset_ptr, tensor_offset, mask=mask) class StashBuffer: """ @@ -437,6 +266,237 @@ def __repr__(self): return f"StashBuffer(capacity={self.num_tokens_capacity} tokens, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" +class PagedStashBuffer: + """ + A paged stash buffer with page-level memory management. + + The buffer is organized as [num_pages, page_size, hidden_size]. + Uses a free list (circular buffer) to track available pages. 
+ """ + + def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype): + """ + Args: + num_tokens: Maximum number of tokens the buffer can hold + hidden_size: Hidden dimension size + page_size: Number of tokens per page + device: Device for the buffer + overflow: Overflow flag tensor (shared across all buffers) + dtype: Data type + """ + self.hidden_size = hidden_size + self.page_size = page_size + self.num_pages = (num_tokens + page_size - 1) // page_size # Ceiling division + self.total_tokens = self.num_pages * page_size + + # Create 2D buffer [total_tokens, hidden_size] + # Organized as pages: [page_0_tokens, page_1_tokens, ...] + if os.getenv('PACKED_OFFLOAD_CPU', '0') == '1': + self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True) + else: + self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device=device) + + self.overflow = overflow # GPU flag (shared) + self.device = device + self.dtype = dtype + + # Free list as circular buffer: stores available page IDs + self.free_list = torch.arange(self.num_pages, dtype=torch.int64, device=device) + + # Head and tail pointers for free_list circular buffer + self.free_list_head = torch.zeros(1, dtype=torch.int64, device=device) # Read pointer (allocation) + self.free_list_tail = torch.tensor([self.num_pages], dtype=torch.int64, device=device) # Write pointer (deallocation) + + # Capacity of free list + self.free_list_capacity = torch.tensor([self.num_pages], dtype=torch.int64, device=device) + + def reset(self): + """Reset the paged buffer - reinitialize free list.""" + self.free_list = torch.arange(self.num_pages, dtype=torch.int64, device=self.device) + self.free_list_head.zero_() + self.free_list_tail.fill_(self.num_pages) + + def __repr__(self): + return f"PagedStashBuffer(num_pages={self.num_pages}, page_size={self.page_size}, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" + + +@triton.jit +def _paged_stash_copy_kernel( + src_ptr, + dst_ptr, + num_tokens_ptr, + free_list_ptr, + free_list_head_ptr, # Read-only: current head position + free_list_tail_ptr, # Read-only: current tail position (for overflow check) + free_list_capacity_ptr, + page_record_ptr, # Output: records which pages were used + overflow_ptr, + new_free_list_head_ptr, # Output: new head position (written by kernel) + PAGE_SIZE: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel to copy tokens to paged stash buffer. + + Allocates pages from free list (reads from head, advances head). + Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. + Grid: (num_blocks,) where blocks process tokens in a strided pattern. + Writes new head to temporary tensor to avoid race conditions. 
+ """ + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load parameters + num_tokens = tl.load(num_tokens_ptr) + free_list_head = tl.load(free_list_head_ptr) + free_list_tail = tl.load(free_list_tail_ptr) + free_list_capacity = tl.load(free_list_capacity_ptr) + + # Check available pages (unwrapped indices: simple subtraction, no modulo needed) + avail_pages = free_list_tail - free_list_head + + # Calculate required pages + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + overflow_detected = avail_pages < required_pages + + # Only block 0 writes overflow flag + if pid == 0 and overflow_detected: + tl.store(overflow_ptr, 1) + + # All blocks return early if overflow + if overflow_detected: + return + + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] + token_idx = pid + while token_idx < num_tokens: + # Determine which page this token belongs to + page_slot = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + + # Read page ID from free list (with wraparound) + free_list_idx = (free_list_head + page_slot) % free_list_capacity + page_id = tl.load(free_list_ptr + free_list_idx) + + # First token in page: record the page ID (only if this block handles token 0 of the page) + if token_in_page == 0: + tl.store(page_record_ptr + page_slot, page_id) + + # Calculate destination address in paged buffer + dst_token_idx = page_id * PAGE_SIZE + token_in_page + + # Copy token data (2D: hidden dimension) + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + + src_base = src_ptr + token_idx * HIDDEN_SIZE + dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE + + if need_mask: + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + + # Stride to next token for this block + token_idx += num_blocks + + # Calculate and store new free list head (only block 0) + # We consumed pages, so advance head forward (unwrapped: no modulo) + # Write to temporary tensor to avoid race conditions + if pid == 0: + new_head = free_list_head + required_pages + tl.store(new_free_list_head_ptr, new_head) + + +@triton.jit +def _paged_stash_pop_kernel( + src_ptr, + dst_ptr, + num_tokens_ptr, + page_record_ptr, # Input: which pages to read + free_list_ptr, + free_list_head_ptr, # Read-only: current head position (not used) + free_list_tail_ptr, # Read-only: current tail position + free_list_capacity_ptr, + new_free_list_tail_ptr, # Output: new tail position (written by kernel) + PAGE_SIZE: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel to reload tokens from paged stash buffer. + + Returns pages to free list (writes to tail, advances tail). + Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. + Grid: (num_blocks,) where blocks process tokens in a strided pattern. + Writes new tail to temporary tensor to avoid race conditions. 
+ """ + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load parameters + num_tokens = tl.load(num_tokens_ptr) + free_list_tail = tl.load(free_list_tail_ptr) + free_list_capacity = tl.load(free_list_capacity_ptr) + + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] + token_idx = pid + while token_idx < num_tokens: + # Determine which page this token belongs to + page_slot = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + + # Read page ID from page record + page_id = tl.load(page_record_ptr + page_slot) + + # Calculate source address in paged buffer + src_token_idx = page_id * PAGE_SIZE + token_in_page + + # Copy token data (2D: hidden dimension) + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + + src_base = src_ptr + src_token_idx * HIDDEN_SIZE + dst_base = dst_ptr + token_idx * HIDDEN_SIZE + + if need_mask: + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + + # First token in page: release page back to free list + if token_in_page == 0: + # Write page ID back to free list at tail position (with wraparound) + write_idx = (free_list_tail + page_slot) % free_list_capacity + tl.store(free_list_ptr + write_idx, page_id) + + # Stride to next token for this block + token_idx += num_blocks + + # Calculate and store new free list tail (only block 0) + # We returned pages, so advance tail forward (unwrapped: no modulo) + # Write to temporary tensor to avoid race conditions + if pid == 0: + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + new_tail = free_list_tail + required_pages + tl.store(new_free_list_tail_ptr, new_tail) + + class PackedTensor: """ A class to represent a packed tensor. 
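The free-list bookkeeping performed inside _paged_stash_copy_kernel and _paged_stash_pop_kernel reduces to the sketch below. The helper names are hypothetical and the real kernels fuse this logic with the strided per-token copies, but the pointer arithmetic is the same: head and tail are unwrapped counters, only indexing into free_list wraps modulo its capacity.

import torch

def allocate_pages_reference(free_list, head, tail, capacity, num_tokens, page_size):
    # (tail - head) free pages remain; overflow is a single subtraction and compare.
    required = -(-num_tokens // page_size)      # ceil(num_tokens / page_size)
    if tail - head < required:
        return None, head                       # overflow: caller raises the shared flag
    slots = (head + torch.arange(required, device=free_list.device)) % capacity
    page_record = free_list[slots].clone()      # physical pages now owned by this tensor
    return page_record, head + required         # head advances on allocation

def release_pages_reference(free_list, tail, capacity, page_record):
    n = page_record.numel()
    slots = (tail + torch.arange(n, device=free_list.device)) % capacity
    free_list[slots] = page_record              # pages go back to the list at the tail
    return tail + n                             # tail advances on release

Keeping head and tail unwrapped is what lets the overflow test in the copy kernel stay a plain subtraction, with the modulo applied only when reading or writing free_list entries.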
@@ -502,7 +562,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # Cap the number of blocks and calculate tokens per block num_blocks = min(total_blocks_needed, max_blocks) - tokens_per_block = triton.cdiv(self.max_num_tokens, num_blocks) + tokens_per_block = (self.max_num_tokens + num_blocks - 1) // num_blocks # Ceiling division if DEBUG: debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} tokens_per_block {tokens_per_block} overflow {stash_buffer.overflow.item()}") @@ -570,7 +630,7 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # Cap the number of blocks and calculate tokens per block num_blocks = min(total_blocks_needed, max_blocks) - tokens_per_block = triton.cdiv(self.max_num_tokens, num_blocks) + tokens_per_block = (self.max_num_tokens + num_blocks - 1) // num_blocks # Ceiling division if DEBUG: debug_print (f"reload_from_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} tokens_per_block {tokens_per_block}") @@ -603,6 +663,167 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): def __repr__(self): return f"PackedTensor(original_shape={self.original_shape}, num_tokens={self.num_tokens_tensor.item()}, vp_stage={self.vp_stage})" + +class PagedTensor: + """ + A paged tensor that stores data in pages within a paged stash buffer. + Similar to PackedTensor but uses page-level memory management. + """ + + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None, max_tokens=None, page_size=64): + """ + Args: + tensor: The tensor to store + num_tokens_tensor: Scalar tensor containing actual number of tokens + vp_stage: Virtual pipeline stage + layer_name: Name of the layer + max_tokens: Maximum number of tokens + page_size: Number of tokens per page + """ + self._tensor = tensor + self._original_tensor = None + assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1 + self.num_tokens_tensor = num_tokens_tensor.clone() + self.vp_stage = vp_stage + self.layer_name = layer_name + self.max_tokens = max_tokens + self.page_size = page_size + + # Original tensor information + self.original_shape = list(tensor.shape) + self.max_num_tokens = self.original_shape[0] + self.element_size = tensor.element_size() + self.hidden_size = self.original_shape[1] + self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype + self.device = tensor.device + + # Calculate number of pages needed + self.max_num_pages = (self.max_num_tokens + page_size - 1) // page_size # Ceiling division + + # Page record: stores which pages are being used for this tensor + self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) + + def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): + """Offload the paged tensor to paged stash buffer.""" + if not HAVE_TRITON: + raise RuntimeError("Triton is required for PagedTensor.offload_to_stash(). 
Please install triton.") + + self._tensor = self._tensor.contiguous() + if self.num_tokens_tensor.dim() == 0: + self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) + + # Get 2D tensor + if isinstance(self._tensor, MXFP8Tensor): + tensor_to_copy = self._tensor._columnwise_data + else: + tensor_to_copy = self._tensor + + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + total_blocks_needed = self.max_num_tokens + num_blocks = min(total_blocks_needed, max_blocks) + + if DEBUG: + debug_print(f"PagedTensor offload ({self.layer_name}) {self._tensor.shape}-{self.dtype} page_size={self.page_size} num_tokens={self.num_tokens_tensor.item()} num_blocks={num_blocks}") + + grid = (num_blocks,) + + # Create temporary tensor for new head (kernel will write to this) + new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash copy kernel (strided access pattern) + # Allocates pages from free list (reads from head, advances head) + _paged_stash_copy_kernel[grid]( + tensor_to_copy, + paged_stash_buffer.buffer, + self.num_tokens_tensor, + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + self.page_record, + paged_stash_buffer.overflow, + new_free_list_head, # Temporary tensor for new head + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + # if paged_stash_buffer.overflow.item() == 1 and torch.distributed.get_rank() == 0: import pdb; pdb.set_trace() + # torch.distributed.barrier() + + # Copy new head value after kernel completes (stream-ordered, avoids race condition) + paged_stash_buffer.free_list_head.copy_(new_free_list_head) + + # Save reference to original tensor + self._original_tensor = self._tensor + self._tensor = None + + if DEBUG: + debug_print(f"After PagedTensor offload page_record={self.page_record[:5]}") + + def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): + """Reload the paged tensor from paged stash buffer.""" + if not HAVE_TRITON: + raise RuntimeError("Triton is required for PagedTensor.reload_from_stash(). 
Please install triton.") + + # Allocate output tensor + if isinstance(self._original_tensor, MXFP8Tensor): + columnwise_data = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) + self._tensor = MXFP8Tensor( + shape=self._original_tensor.shape, + dtype=self._original_tensor.dtype, + fp8_dtype=self._original_tensor._fp8_dtype, + rowwise_data=self._original_tensor._rowwise_data, + rowwise_scale_inv=self._original_tensor._rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=self._original_tensor._columnwise_scale_inv, + quantizer=self._original_tensor._quantizer, + ) + tensor_to_reload = self._tensor._columnwise_data + else: + self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) + tensor_to_reload = self._tensor + + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + total_blocks_needed = self.max_num_tokens + num_blocks = min(total_blocks_needed, max_blocks) + + if DEBUG: + debug_print(f"PagedTensor reload {self._tensor.shape}-{self.dtype} page_size={self.page_size} num_blocks={num_blocks}") + + grid = (num_blocks,) + + # Create temporary tensor for new tail (kernel will write to this) + new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash pop kernel (strided access pattern) + # Returns pages to free list (writes to tail, advances tail) + _paged_stash_pop_kernel[grid]( + paged_stash_buffer.buffer, + tensor_to_reload, + self.num_tokens_tensor, + self.page_record, + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + new_free_list_tail, # Temporary tensor for new tail + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Copy new tail value after kernel completes (stream-ordered, avoids race condition) + paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) + + if DEBUG: + debug_print(f"After PagedTensor reload") + + def __repr__(self): + return f"PagedTensor(original_shape={self.original_shape}, num_tokens={self.num_tokens_tensor.item()}, page_size={self.page_size}, vp_stage={self.vp_stage})" + + class PP_ScheduleFunction(torch.autograd.Function): """ This function is used to update the pp schedule. 
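Both paged kernels, and the page_record sizing in PagedTensor, rely on the same token-to-page addressing. A minimal sketch of that mapping (the helper name is hypothetical):

def paged_row_reference(token_idx, page_record, page_size):
    # page_record[page_slot] was filled at offload time with the physical page id.
    page_slot = token_idx // page_size           # logical page this token falls in
    token_in_page = token_idx % page_size        # position inside that page
    page_id = int(page_record[page_slot])
    return page_id * page_size + token_in_page   # row into buffer[total_tokens, hidden_size]

The number of page_record entries is simply ceil(max_num_tokens / page_size), which matches how PagedStashBuffer rounds its total capacity up to whole pages.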
@@ -712,6 +933,10 @@ def __init__(self): self.stash_buffers = None self.overflow = None self.device = None + + # Page size for paged memory management + self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page + self.use_paged_stash = os.getenv('USE_PAGED_STASH', '0') == '1' # Enable via env var @property def pack_stream(self): @@ -805,13 +1030,19 @@ def allocate_offload_buffers(self, stash_buffer_size_factor=1.10): # Calculate number of tokens we can store (with safety factor) num_tokens = int(self.max_tokens_per_vp_stage[vp_stage][dtype][hidden_size] * stash_buffer_size_factor) - # Create 2D buffer - self.stash_buffers[vp_stage][dtype][hidden_size] = StashBuffer( - num_tokens, hidden_size, self.device, self.overflow, dtype - ) + # Create buffer (paged or regular based on configuration) + if self.use_paged_stash: + self.stash_buffers[vp_stage][dtype][hidden_size] = PagedStashBuffer( + num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype + ) + else: + self.stash_buffers[vp_stage][dtype][hidden_size] = StashBuffer( + num_tokens, hidden_size, self.device, self.overflow, dtype + ) if torch.distributed.get_rank() == 0: - print(f'allocated stash buffer vp_stage={vp_stage} dtype={dtype} hidden_size={hidden_size}: {self.stash_buffers[vp_stage][dtype][hidden_size]}') + buffer_type = "paged" if self.use_paged_stash else "regular" + print(f'allocated {buffer_type} stash buffer vp_stage={vp_stage} dtype={dtype} hidden_size={hidden_size}: {self.stash_buffers[vp_stage][dtype][hidden_size]}') def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): """Update the pp schedule.""" @@ -896,7 +1127,25 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: tensor_truncated.copy_(tensor[:self.num_tokens, ...]) tensor = tensor_truncated - packed_tensor = PackedTensor(tensor, num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage, layer_name=self._current_layer_name, max_tokens=self.max_num_tokens) + # Create tensor (paged or regular based on configuration) + if self.use_paged_stash: + packed_tensor = PagedTensor( + tensor, + num_tokens_tensor=self.num_tokens_tensor, + vp_stage=self.current_vp_stage, + layer_name=self._current_layer_name, + max_tokens=self.max_num_tokens, + page_size=self.page_size + ) + else: + packed_tensor = PackedTensor( + tensor, + num_tokens_tensor=self.num_tokens_tensor, + vp_stage=self.current_vp_stage, + layer_name=self._current_layer_name, + max_tokens=self.max_num_tokens + ) + if self.status == 'captured': self.add_packed_tensor_to_offload(packed_tensor) return packed_tensor @@ -906,7 +1155,7 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: Hook called when autograd retrieves a saved tensor during backward pass. Returns the actual tensor (potentially reloading from CPU). 
""" - if isinstance(saved_state, PackedTensor): + if isinstance(saved_state, (PackedTensor, PagedTensor)): if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][saved_state.hidden_size] -= num_tokens @@ -1027,10 +1276,10 @@ def packed_moe_expert_offloading_reset(enabled=True): # current layer and microbatch for each vp stage for forward pass offload_manager.current_schedule_index = 0 if os.getenv('MEM_PROFILE', '0') == '1': - if offload_manager.iteration == 1 and torch.distributed.get_rank() == 0: - torch.cuda.memory._record_memory_history() + if offload_manager.iteration == 0 and torch.distributed.get_rank() == 0: + torch.cuda.memory._record_memory_history(max_entries=1000000) print(f'packed_moe_expert_offloading_reset record_memory_history') - if offload_manager.iteration == 10 and torch.distributed.get_rank() == 0: + if offload_manager.iteration == 5 and torch.distributed.get_rank() == 0: torch.cuda.memory._dump_snapshot("packed_offloading_cg.pkl") torch.cuda.memory._record_memory_history(enabled=None) print(f'packed_moe_expert_offloading_reset dump_snapshot') diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index b099be93280..3dc823538b0 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -1558,8 +1558,8 @@ def forward_backward_helper_wrapper( print (f'{torch.distributed.get_rank()}: forward_backward_pipelining_with_interleaving num_warmup_microbatches {num_warmup_microbatches} num_microbatches_1f1b {num_microbatches_remaining} total_num_microbatches {total_num_microbatches}') for k in range(num_warmup_microbatches): cur_model_chunk_id = get_model_chunk_id(k, forward=True) - if torch.distributed.get_rank() in [0, 2]: - print(f'{pipeline_parallel_rank}: +++++ warmup iteration {k}, fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') + # if torch.distributed.get_rank() in [0, 2]: + # print(f'{pipeline_parallel_rank}: +++++ warmup iteration {k}, fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') if config.overlap_p2p_comm_warmup_flush: if ( not ( @@ -1731,8 +1731,8 @@ def forward_backward_helper_wrapper( if config.overlap_p2p_comm: backward_k = k - if torch.distributed.get_rank() in [0, 2]: - print(f'{pipeline_parallel_rank}: +++++ steady iteration forward_k {forward_k} fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}, backward_k {backward_k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][get_model_chunk_id(backward_k, forward=False)]}') + # if torch.distributed.get_rank() in [0, 2]: + # print(f'{pipeline_parallel_rank}: +++++ steady iteration forward_k {forward_k} fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}, backward_k {backward_k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][get_model_chunk_id(backward_k, forward=False)]}') # Sync forward recv def pp_pre_forward(vp_stage=None): @@ -1951,8 +1951,8 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): ) for k in range(num_microbatches_remaining, total_num_microbatches): cur_model_chunk_id = get_model_chunk_id(k, forward=False) - if torch.distributed.get_rank() in [0, 2]: - print(f'{pipeline_parallel_rank}: cooldown iteration k {k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') + # if torch.distributed.get_rank() in [0, 2]: + # 
print(f'{pipeline_parallel_rank}: cooldown iteration k {k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') if ( not (_is_vp_last_stage(vp_stage=cur_model_chunk_id) and is_pp_last_stage(pp_group)) and k != 0 From ed07de67c8ce33d7f14e1e34b05a46848dbe0840 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 26 Nov 2025 19:59:37 +0800 Subject: [PATCH 14/57] Add the feature of speculative CE stashing --- .../pipeline_parallel/moe_packed_offload.py | 195 +++++++++++------- 1 file changed, 126 insertions(+), 69 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index f9515fdda80..b8c5ed07bb2 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -312,7 +312,7 @@ def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype): def reset(self): """Reset the paged buffer - reinitialize free list.""" - self.free_list = torch.arange(self.num_pages, dtype=torch.int64, device=self.device) + self.free_list.copy_(torch.arange(self.num_pages, dtype=torch.int64, device=self.device)) self.free_list_head.zero_() self.free_list_tail.fill_(self.num_pages) @@ -670,7 +670,7 @@ class PagedTensor: Similar to PackedTensor but uses page-level memory management. """ - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None, max_tokens=None, page_size=64): + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0): """ Args: tensor: The tensor to store @@ -679,6 +679,7 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non layer_name: Name of the layer max_tokens: Maximum number of tokens page_size: Number of tokens per page + num_d2d_pages: Number of pages to copy using native PyTorch (rest uses Triton) """ self._tensor = tensor self._original_tensor = None @@ -688,6 +689,7 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non self.layer_name = layer_name self.max_tokens = max_tokens self.page_size = page_size + self.num_d2d_pages = num_d2d_pages # Original tensor information self.original_shape = list(tensor.shape) @@ -702,9 +704,21 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non # Page record: stores which pages are being used for this tensor self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) + + # Static tensor for D2D pages (allocate upfront if needed) + d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) + if d2d_tokens > 0: + self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device) + else: + self.static_tensor = None def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): - """Offload the paged tensor to paged stash buffer.""" + """Offload the paged tensor to paged stash buffer. + + Args: + paged_stash_buffer: The paged stash buffer to offload to + max_blocks: Maximum number of blocks for Triton kernel + """ if not HAVE_TRITON: raise RuntimeError("Triton is required for PagedTensor.offload_to_stash(). 
Please install triton.") @@ -718,51 +732,72 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 else: tensor_to_copy = self._tensor - # Determine grid size - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - total_blocks_needed = self.max_num_tokens - num_blocks = min(total_blocks_needed, max_blocks) + # Split tensor into two parts: D2D portion and Triton portion + # Use max_num_tokens for consistent size across iterations + d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) + triton_tokens = self.max_num_tokens - d2d_tokens if DEBUG: - debug_print(f"PagedTensor offload ({self.layer_name}) {self._tensor.shape}-{self.dtype} page_size={self.page_size} num_tokens={self.num_tokens_tensor.item()} num_blocks={num_blocks}") - - grid = (num_blocks,) - - # Create temporary tensor for new head (kernel will write to this) - new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch paged stash copy kernel (strided access pattern) - # Allocates pages from free list (reads from head, advances head) - _paged_stash_copy_kernel[grid]( - tensor_to_copy, - paged_stash_buffer.buffer, - self.num_tokens_tensor, - paged_stash_buffer.free_list, - paged_stash_buffer.free_list_head, - paged_stash_buffer.free_list_tail, - paged_stash_buffer.free_list_capacity, - self.page_record, - paged_stash_buffer.overflow, - new_free_list_head, # Temporary tensor for new head - PAGE_SIZE=self.page_size, - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - # if paged_stash_buffer.overflow.item() == 1 and torch.distributed.get_rank() == 0: import pdb; pdb.set_trace() - # torch.distributed.barrier() - - # Copy new head value after kernel completes (stream-ordered, avoids race condition) - paged_stash_buffer.free_list_head.copy_(new_free_list_head) + debug_print(f"PagedTensor offload ({self.layer_name}) {self._tensor.shape}-{self.dtype} page_size={self.page_size} max_num_tokens={self.max_num_tokens} num_d2d_pages={self.num_d2d_pages} d2d_tokens={d2d_tokens} triton_tokens={triton_tokens}") + + # Perform both D2D copy and Triton kernel together + # Part 1: Copy first d2d_tokens to static_tensor using native PyTorch + if d2d_tokens > 0: + self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens] + if DEBUG: + debug_print(f"Copied {d2d_tokens} tokens to static_tensor using D2D") + + # Part 2: Copy remaining tokens using Triton kernel + if triton_tokens > 0: + triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens] + # Use actual num_tokens for the kernel (how many tokens to actually copy) + triton_num_tokens = self.num_tokens_tensor - d2d_tokens + + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(triton_tokens, max_blocks) + grid = (num_blocks,) + + # Create temporary tensor for new head + new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash copy kernel + _paged_stash_copy_kernel[grid]( + triton_tensor, + paged_stash_buffer.buffer, + triton_num_tokens, + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + self.page_record, # Triton kernel will populate page_record + paged_stash_buffer.overflow, + new_free_list_head, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Update free list head + paged_stash_buffer.free_list_head.copy_(new_free_list_head) + + if DEBUG: + debug_print(f"Copied {triton_tokens} tokens using Triton kernel") # Save 
reference to original tensor self._original_tensor = self._tensor self._tensor = None if DEBUG: - debug_print(f"After PagedTensor offload page_record={self.page_record[:5]}") + debug_print(f"After PagedTensor offload") def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): - """Reload the paged tensor from paged stash buffer.""" + """Reload the paged tensor from paged stash buffer. + + Args: + paged_stash_buffer: The paged stash buffer to reload from + max_blocks: Maximum number of blocks for Triton kernel + """ if not HAVE_TRITON: raise RuntimeError("Triton is required for PagedTensor.reload_from_stash(). Please install triton.") @@ -784,38 +819,56 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) tensor_to_reload = self._tensor - # Determine grid size - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - total_blocks_needed = self.max_num_tokens - num_blocks = min(total_blocks_needed, max_blocks) + # Split tensor into two parts: D2D portion and Triton portion + # Use max_num_tokens for consistency with offload + d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) + triton_tokens = self.max_num_tokens - d2d_tokens if DEBUG: - debug_print(f"PagedTensor reload {self._tensor.shape}-{self.dtype} page_size={self.page_size} num_blocks={num_blocks}") - - grid = (num_blocks,) - - # Create temporary tensor for new tail (kernel will write to this) - new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch paged stash pop kernel (strided access pattern) - # Returns pages to free list (writes to tail, advances tail) - _paged_stash_pop_kernel[grid]( - paged_stash_buffer.buffer, - tensor_to_reload, - self.num_tokens_tensor, - self.page_record, - paged_stash_buffer.free_list, - paged_stash_buffer.free_list_head, - paged_stash_buffer.free_list_tail, - paged_stash_buffer.free_list_capacity, - new_free_list_tail, # Temporary tensor for new tail - PAGE_SIZE=self.page_size, - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Copy new tail value after kernel completes (stream-ordered, avoids race condition) - paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) + debug_print(f"PagedTensor reload {self._tensor.shape}-{self.dtype} page_size={self.page_size} max_num_tokens={self.max_num_tokens} num_d2d_pages={self.num_d2d_pages} d2d_tokens={d2d_tokens} triton_tokens={triton_tokens}") + + # Perform both D2D copy and Triton kernel together + # Part 1: Copy first d2d_tokens from static_tensor using native PyTorch + if d2d_tokens > 0 and self.static_tensor is not None: + tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens] + if DEBUG: + debug_print(f"Reloaded {d2d_tokens} tokens from static_tensor using D2D") + + # Part 2: Copy remaining tokens using Triton kernel + if triton_tokens > 0: + triton_tensor = tensor_to_reload[d2d_tokens:self.max_num_tokens] + # Use actual num_tokens for the kernel (how many tokens to actually copy) + triton_num_tokens = self.num_tokens_tensor - d2d_tokens + + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(triton_tokens, max_blocks) + grid = (num_blocks,) + + # Create temporary tensor for new tail + new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash pop kernel + _paged_stash_pop_kernel[grid]( + paged_stash_buffer.buffer, + triton_tensor, + triton_num_tokens, + self.page_record, # Triton kernel will 
read from page_record + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + new_free_list_tail, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Update free list tail + paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) + + if DEBUG: + debug_print(f"Reloaded {triton_tokens} tokens using Triton kernel") if DEBUG: debug_print(f"After PagedTensor reload") @@ -937,6 +990,9 @@ def __init__(self): # Page size for paged memory management self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page self.use_paged_stash = os.getenv('USE_PAGED_STASH', '0') == '1' # Enable via env var + + # Number of pages to copy using native PyTorch (D2D) + self.num_d2d_pages = int(os.getenv('NUM_D2D_PAGES', '0')) # Default 0 (all Triton) @property def pack_stream(self): @@ -1135,7 +1191,8 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: vp_stage=self.current_vp_stage, layer_name=self._current_layer_name, max_tokens=self.max_num_tokens, - page_size=self.page_size + page_size=self.page_size, + num_d2d_pages=self.num_d2d_pages ) else: packed_tensor = PackedTensor( From 7ab5f9b432cdc1806f5fa89faa611be955dbaaef Mon Sep 17 00:00:00 2001 From: a Date: Wed, 26 Nov 2025 11:51:56 -0800 Subject: [PATCH 15/57] Fix PP schedule --- .../core/pipeline_parallel/moe_packed_offload.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index b8c5ed07bb2..20044c2816e 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -15,7 +15,7 @@ HAVE_TRITON = False # Packed Moe Expert Offload implementation for pipeline parallelism -DEBUG = False +DEBUG = True DEBUG_RANK = [0] def debug_print(message): """Print debug message for a specific rank when DEBUG is enabled.""" @@ -906,7 +906,7 @@ def forward(ctx, tensor, offload_manager): # after forward packed_tensor._original_tensor = None if offload_manager.status == 'captured': - current_schedule_layer = (ctx.vp_stage+1)*100 + ctx.layer_no*10 + ctx.microbatch_no + current_schedule_layer = offload_manager.get_schedule_layer(ctx.vp_stage+1, ctx.layer_no, ctx.microbatch_no) next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index+1] if current_schedule_layer != -next_schedule_layer: # Start offload for current layer @@ -1008,6 +1008,10 @@ def set_current_layer_name(self, name): """Set the current layer name.""" self._current_layer_name = name + def get_schedule_layer(self, vp_stage, layer_no, microbatch_no): + """Get the schedule layer.""" + return vp_stage*1000000 + layer_no*1000 + microbatch_no + def add_packed_tensor_to_offload(self, packed_tensor): """Add a packed tensor to the offload list.""" if self.status == 'captured': @@ -1077,7 +1081,7 @@ def allocate_offload_buffers(self, stash_buffer_size_factor=1.10): """Allocate offload buffers for each vp stage, organized by [vp_stage][dtype][hidden_size].""" self.stash_buffers = [] self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) - + for vp_stage in range(self.vp_size): self.stash_buffers.append({}) for dtype in self.max_tokens_per_vp_stage[vp_stage]: @@ -1119,12 +1123,12 @@ def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): self.current_microbatch[vp_stage-1] += 
1 if self.status == 'capture': - self._pp_schedule.append(vp_stage*100 + layer_no*10 + microbatch_no) + self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) num_tokens = self.num_tokens_tensor.item() #debug_print(f"------{self.current_schedule_index} len PP_Schedule {len(self._pp_schedule)}") #debug_print(f" {self.status} {self.current_schedule_index} {self._pp_schedule[self.current_schedule_index]} {vp_stage*100 + layer_no*10 + microbatch_no}") - assert self._pp_schedule[self.current_schedule_index] == vp_stage*100 + layer_no*10 + microbatch_no, f"schedule {self._pp_schedule[self.current_schedule_index]} != {vp_stage*100 + layer_no*10 + microbatch_no}" + assert self._pp_schedule[self.current_schedule_index] == self.get_schedule_layer(vp_stage, layer_no, microbatch_no), f"schedule {self._pp_schedule[self.current_schedule_index]} != {self.get_schedule_layer(vp_stage, layer_no, microbatch_no)}" return layer_no, microbatch_no From 95bc80f48b7547e5df9c72484c0ec5ca1de76ac1 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Wed, 26 Nov 2025 14:47:23 -0800 Subject: [PATCH 16/57] Use common buffer across VP for paged stashing --- .../pipeline_parallel/moe_packed_offload.py | 104 ++++++++++++++---- 1 file changed, 82 insertions(+), 22 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 20044c2816e..958b0eb9b81 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -501,12 +501,13 @@ class PackedTensor: """ A class to represent a packed tensor. """ - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None, max_tokens=None): + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None): self._tensor = tensor self._original_tensor = None assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1, f"num_tokens_tensor {num_tokens_tensor} is not a scalar tensor" self.num_tokens_tensor = num_tokens_tensor.clone() self.vp_stage = vp_stage + self.schedule_layer_no = schedule_layer_no self.layer_name = layer_name self.max_tokens = max_tokens # Original tensor information @@ -519,6 +520,11 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non self.stash_buffer_offset = None + @property + def schedule_layer(self): + """Get the schedule layer.""" + return self.schedule_layer_no + def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): """Offload the packed tensor.""" #self._tensor.record_stream(torch.cuda.current_stream()) @@ -670,7 +676,7 @@ class PagedTensor: Similar to PackedTensor but uses page-level memory management. 
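    Tokens are stored in fixed-size pages (page_size tokens each) drawn from the
    paged stash buffer's free list, which allows a single paged buffer to be
    shared across VP stages.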
""" - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0): + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0): """ Args: tensor: The tensor to store @@ -686,6 +692,7 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1 self.num_tokens_tensor = num_tokens_tensor.clone() self.vp_stage = vp_stage + self.schedule_layer_no = schedule_layer_no self.layer_name = layer_name self.max_tokens = max_tokens self.page_size = page_size @@ -711,7 +718,12 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, layer_name=Non self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device) else: self.static_tensor = None - + + @property + def schedule_layer(self): + """Get the schedule layer.""" + return self.schedule_layer_no + def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): """Offload the paged tensor to paged stash buffer. @@ -936,6 +948,23 @@ def backward(ctx, *grad_output): # before backward if ctx.offload_manager.status == 'captured' and ctx.offload_manager.current_schedule_index < len(ctx.offload_manager._pp_schedule): next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index] if next_schedule_layer < 0: + debug_print(f"PP_ScheduleFunction backward reload_packed_tensors {next_schedule_layer}") + # For last layer last microbatch, wait for offload to complete before reloading + if ctx.offload_manager._pack_stream_status == 'offloading': + assert len(ctx.offload_manager.packed_tensors_offload_in_progress) > 0, f"packed_tensors_offload_in_progress is empty" + offloaded_tensor = ctx.offload_manager.packed_tensors_offload_in_progress[0] + if next_schedule_layer == -offloaded_tensor.schedule_layer: + current_stream.wait_stream(ctx.offload_manager.pack_stream) + ctx.offload_manager._pack_stream_status = 'idle' + # Deallocate original tensor after offload is complete + while len(ctx.offload_manager.packed_tensors_offload_in_progress) > 0: + packed_tensor = ctx.offload_manager.packed_tensors_offload_in_progress.pop(0) + if not DEBUG: + if isinstance(packed_tensor._original_tensor, MXFP8Tensor): + packed_tensor._original_tensor._columnwise_data = None + else: + packed_tensor._original_tensor = None + ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) return grad_output + (None, None) @@ -981,6 +1010,10 @@ def __init__(self): # Track max tokens needed per vp_stage, dtype, and hidden_size self.max_tokens_per_vp_stage = None self.temp_tokens_per_vp_stage = None + # Track max tokens needed across all vp_stages grouped by dtype and hidden_size + self.max_tokens_across_vp_stages = None + self.temp_tokens_across_vp_stages = None + self.num_tokens_tensor = None self.max_num_tokens = None self.stash_buffers = None @@ -1046,7 +1079,9 @@ def offload_packed_tensors(self, pp_schedule_layer): while len(self.packed_tensors_to_offload) > 0: packed_tensor = self.packed_tensors_to_offload.pop(0) - stash_buffer = self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype][packed_tensor.hidden_size] + #print (f'offload_packed_tensors vp_stage={packed_tensor.vp_stage} dtype={packed_tensor.dtype} hidden_size={packed_tensor.hidden_size} 
use_paged_stash={self.use_paged_stash}') + stash_buffers_vp_stage = self.stash_buffers[packed_tensor.vp_stage] if not self.use_paged_stash else self.stash_buffers[0] + stash_buffer = stash_buffers_vp_stage[packed_tensor.dtype][packed_tensor.hidden_size] packed_tensor.offload_to_stash(stash_buffer) self.packed_tensors_to_reload[pp_schedule_layer].append(packed_tensor) self.packed_tensors_offload_in_progress.append(packed_tensor) @@ -1070,7 +1105,8 @@ def reload_packed_tensors(self, pp_schedule_layer): debug_print(f"reload_packed_tensors {count}") while len(self.packed_tensors_to_reload[pp_schedule_layer]) > 0: packed_tensor = self.packed_tensors_to_reload[pp_schedule_layer].pop(0) - stash_buffer = self.stash_buffers[packed_tensor.vp_stage][packed_tensor.dtype][packed_tensor.hidden_size] + stash_buffers_vp_stage = self.stash_buffers[packed_tensor.vp_stage] if not self.use_paged_stash else self.stash_buffers[0] + stash_buffer = stash_buffers_vp_stage[packed_tensor.dtype][packed_tensor.hidden_size] packed_tensor.reload_from_stash(stash_buffer) else: pass @@ -1082,6 +1118,20 @@ def allocate_offload_buffers(self, stash_buffer_size_factor=1.10): self.stash_buffers = [] self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) + if self.use_paged_stash: + self.stash_buffers.append({}) + for dtype, hidden_size in self.max_tokens_across_vp_stages: + if dtype not in self.stash_buffers[0]: + self.stash_buffers[0][dtype] = {} + assert hidden_size not in self.stash_buffers[0][dtype] + num_tokens = int(self.max_tokens_across_vp_stages[dtype, hidden_size] * stash_buffer_size_factor) + self.stash_buffers[0][dtype][hidden_size] = PagedStashBuffer( + num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype + ) + if torch.distributed.get_rank() == 0: + print(f'allocated paged stash buffer dtype={dtype} hidden_size={hidden_size}: {self.stash_buffers[0][dtype][hidden_size]}') + return + # Regular stash buffers for vp_stage in range(self.vp_size): self.stash_buffers.append({}) for dtype in self.max_tokens_per_vp_stage[vp_stage]: @@ -1090,15 +1140,10 @@ def allocate_offload_buffers(self, stash_buffer_size_factor=1.10): # Calculate number of tokens we can store (with safety factor) num_tokens = int(self.max_tokens_per_vp_stage[vp_stage][dtype][hidden_size] * stash_buffer_size_factor) - # Create buffer (paged or regular based on configuration) - if self.use_paged_stash: - self.stash_buffers[vp_stage][dtype][hidden_size] = PagedStashBuffer( - num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype - ) - else: - self.stash_buffers[vp_stage][dtype][hidden_size] = StashBuffer( - num_tokens, hidden_size, self.device, self.overflow, dtype - ) + # Create buffer (regular) + self.stash_buffers[vp_stage][dtype][hidden_size] = StashBuffer( + num_tokens, hidden_size, self.device, self.overflow, dtype + ) if torch.distributed.get_rank() == 0: buffer_type = "paged" if self.use_paged_stash else "regular" @@ -1145,7 +1190,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if self.max_num_tokens is None or tensor.size(0) != self.max_num_tokens: return tensor.detach() if isinstance(tensor, MXFP8Tensor): - debug_print(f'on_save_for_backward MXFP8Tensor ({self._current_layer_name}) ndim {tensor.ndim} shape {tensor.shape} {tensor.dtype} rowwise {tensor._rowwise_data is not None} columnwise {(tensor._columnwise_data.shape, tensor._columnwise_data.dtype) if tensor._columnwise_data is not None else None}-scale_inv {tensor._columnwise_scale_inv.shape} 
{tensor._columnwise_scale_inv.dtype}') + #debug_print(f'on_save_for_backward MXFP8Tensor ({self._current_layer_name}) ndim {tensor.ndim} shape {tensor.shape} {tensor.dtype} rowwise {tensor._rowwise_data is not None} columnwise {(tensor._columnwise_data.shape, tensor._columnwise_data.dtype) if tensor._columnwise_data is not None else None}-scale_inv {tensor._columnwise_scale_inv.shape} {tensor._columnwise_scale_inv.dtype}') assert tensor._rowwise_data is None, f"rowwise_data is not None; Only columnwise data is supported for packed offloading" #if tensor.size(1) in [7168, 4096, 1] and DEBUG: @@ -1167,13 +1212,21 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if hidden_size not in self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype]: self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 - + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] += self.num_tokens self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = max( self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] ) - + if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: + self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 + + self.temp_tokens_across_vp_stages[dtype, hidden_size] += self.num_tokens + self.max_tokens_across_vp_stages[dtype, hidden_size] = max( + self.max_tokens_across_vp_stages[dtype, hidden_size], + self.temp_tokens_across_vp_stages[dtype, hidden_size] + ) # Since capture stage does not use CUDA graph, we can truncate the saved tensor to actual num_tokens # Truncate the tensor to the actual number of tokens new_size = (self.num_tokens, *tensor.shape[1:]) @@ -1192,7 +1245,8 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: packed_tensor = PagedTensor( tensor, num_tokens_tensor=self.num_tokens_tensor, - vp_stage=self.current_vp_stage, + vp_stage=self.current_vp_stage, + schedule_layer_no=self._pp_schedule[self.current_schedule_index] if self._pp_schedule is not None and self.current_schedule_index < len(self._pp_schedule) else None, layer_name=self._current_layer_name, max_tokens=self.max_num_tokens, page_size=self.page_size, @@ -1202,8 +1256,9 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: packed_tensor = PackedTensor( tensor, num_tokens_tensor=self.num_tokens_tensor, - vp_stage=self.current_vp_stage, - layer_name=self._current_layer_name, + vp_stage=self.current_vp_stage, + schedule_layer_no=self._pp_schedule[self.current_schedule_index] if self._pp_schedule is not None and self.current_schedule_index < len(self._pp_schedule) else None, + layer_name=self._current_layer_name, max_tokens=self.max_num_tokens ) @@ -1220,7 +1275,7 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][saved_state.hidden_size] -= num_tokens - + self.temp_tokens_across_vp_stages[saved_state.dtype, saved_state.hidden_size] -= num_tokens # Pad the tensor to the max number of tokens npad = self.max_num_tokens - num_tokens pad = () @@ -1247,7 +1302,8 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: equal = torch.equal(original_flat_sub, tensor_flat_sub) num_not_equal = (original_flat_sub != 
tensor_flat_sub).sum() idx_not_equal = (original_flat_sub != tensor_flat_sub).nonzero() - debug_print(f"on_get_saved_tensor original: {saved_state._original_tensor.shape} tensor: {saved_state._tensor.shape} equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}") + assert equal, f"on_get_saved_tensor equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}" + #debug_print(f"on_get_saved_tensor original: {saved_state._original_tensor.shape} tensor: {saved_state._tensor.shape} equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}") #debug_print(f"on_get_saved_tensor equal tensors {torch.equal(saved_state._original_tensor, saved_state._tensor)} original_tensor {original_flat[-100:]} tensor {tensor_flat[-100:]}") return saved_state._tensor else: @@ -1322,6 +1378,9 @@ def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): if offload_manager.max_tokens_per_vp_stage is None: offload_manager.max_tokens_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] offload_manager.temp_tokens_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] + if offload_manager.max_tokens_across_vp_stages is None: + offload_manager.max_tokens_across_vp_stages = {} + offload_manager.temp_tokens_across_vp_stages = {} def packed_moe_expert_offloading_set_last_layer(is_last_layer=False): """Set the last layer flag.""" @@ -1357,6 +1416,7 @@ def packed_moe_expert_offloading_reset(enabled=True): offload_manager.allocate_offload_buffers(stash_buffer_size_factor=stash_buffer_size_factor) debug_print(f'packed_moe_expert_offloading_reset captured schedule: {offload_manager._pp_schedule}') debug_print(f'packed_moe_expert_offloading_reset max_tokens_per_vp_stage: {offload_manager.max_tokens_per_vp_stage}') + debug_print(f'packed_moe_expert_offloading_reset max_tokens_across_vp_stages: {offload_manager.max_tokens_across_vp_stages}') elif offload_manager.status == 'captured': pass else: From 8a593b0777cd07b40603a676f3181529b9d529f0 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 26 Nov 2025 16:14:04 -0800 Subject: [PATCH 17/57] Disable Packed Offloading for validation --- megatron/core/full_cuda_graph.py | 3 +-- .../pipeline_parallel/moe_packed_offload.py | 13 +++++++++++-- megatron/core/pipeline_parallel/schedules.py | 18 +++--------------- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 38257e8cf7c..84efd4f0780 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -188,8 +188,7 @@ def __call__(self, *args, **kwargs): if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: - if training_str == 'training': - packed_moe_expert_offloading_reset(enabled=self.packed_moe_expert_offloading) + packed_moe_expert_offloading_reset(enabled=self.packed_moe_expert_offloading and training) FullCudaGraphWrapper.cuda_graph[training_str].replay() self.speculative_cuda_graph_check(model) self.next_iter(training_str) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py 
b/megatron/core/pipeline_parallel/moe_packed_offload.py index 958b0eb9b81..19bed77242c 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -15,7 +15,7 @@ HAVE_TRITON = False # Packed Moe Expert Offload implementation for pipeline parallelism -DEBUG = True +DEBUG = False DEBUG_RANK = [0] def debug_print(message): """Print debug message for a specific rank when DEBUG is enabled.""" @@ -988,6 +988,7 @@ def get_instance(cls): def __init__(self): """Initialize the manager with queues and dedicated CUDA streams.""" # allocate streams and events for synchronization + self.enabled = False self._pack_stream = torch.cuda.Stream() self._unpack_stream = torch.cuda.Stream() self._pack_stream_status = 'idle' # idle, offloading @@ -1350,6 +1351,8 @@ def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num """Get the fine-grained offload context""" #debug_print(f'get_packed_moe_expert_offloading_context name {name}') offload_manager = PackedOffloadManager.get_instance() + if not offload_manager.enabled: + return nullcontext() offload_manager.max_num_tokens = max_num_tokens assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) offload_manager.num_tokens_tensor = num_tokens_tensor @@ -1363,13 +1366,16 @@ def packed_moe_expert_offloading_group_commit(tensor, name=None): #debug_print(f'{rank}: packed_moe_expert_offloading_group_commit tensor {tensor.shape}-{tensor.dtype} name {name}') offload_manager = PackedOffloadManager.get_instance() offload_manager.device = tensor.device - + if not offload_manager.enabled: + return tensor return PP_ScheduleFunction.apply(tensor, offload_manager) def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" #debug_print(f'packed_moe_expert_offloading_init_chunk_handler vp_size {vp_size} vp_stage {vp_stage}') offload_manager = PackedOffloadManager.get_instance() + if not offload_manager.enabled: + return offload_manager.current_vp_stage = vp_stage if vp_stage is not None else 0 if vp_size is not None: offload_manager.vp_size = vp_size @@ -1387,11 +1393,14 @@ def packed_moe_expert_offloading_set_last_layer(is_last_layer=False): #PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) #debug_print(f'packed_moe_expert_offloading_set_last_layer is_last_layer {is_last_layer}') offload_manager = PackedOffloadManager.get_instance() + if not offload_manager.enabled: + return offload_manager._last_layer = is_last_layer def packed_moe_expert_offloading_reset(enabled=True): """Reset the chunk handler, called at the start of a training iteration.""" offload_manager = PackedOffloadManager.get_instance() + offload_manager.enabled = enabled offload_manager.iteration += 1 # current layer and microbatch for each vp stage for forward pass offload_manager.current_schedule_index = 0 diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 3dc823538b0..75824d4f3a7 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -593,8 +593,7 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - if not forward_only: - packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading) + 
packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading and not forward_only) no_sync_func = config.no_sync_func if no_sync_func is None: @@ -1055,8 +1054,7 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" - if not forward_only and config.packed_moe_expert_offloading: - packed_moe_expert_offloading_reset() + packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading and not forward_only) if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -1554,12 +1552,8 @@ def forward_backward_helper_wrapper( send_next_wait_handle = None send_prev_wait_handle = None recv_next_wait_handles = [] - model_chunk_ids = {0: [1,3], 1:[2,4]} - print (f'{torch.distributed.get_rank()}: forward_backward_pipelining_with_interleaving num_warmup_microbatches {num_warmup_microbatches} num_microbatches_1f1b {num_microbatches_remaining} total_num_microbatches {total_num_microbatches}') for k in range(num_warmup_microbatches): cur_model_chunk_id = get_model_chunk_id(k, forward=True) - # if torch.distributed.get_rank() in [0, 2]: - # print(f'{pipeline_parallel_rank}: +++++ warmup iteration {k}, fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') if config.overlap_p2p_comm_warmup_flush: if ( not ( @@ -1731,9 +1725,6 @@ def forward_backward_helper_wrapper( if config.overlap_p2p_comm: backward_k = k - # if torch.distributed.get_rank() in [0, 2]: - # print(f'{pipeline_parallel_rank}: +++++ steady iteration forward_k {forward_k} fwd_model_chunk_id {model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}, backward_k {backward_k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][get_model_chunk_id(backward_k, forward=False)]}') - # Sync forward recv def pp_pre_forward(vp_stage=None): if vp_stage is None: @@ -1951,8 +1942,6 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): ) for k in range(num_microbatches_remaining, total_num_microbatches): cur_model_chunk_id = get_model_chunk_id(k, forward=False) - # if torch.distributed.get_rank() in [0, 2]: - # print(f'{pipeline_parallel_rank}: cooldown iteration k {k} bwd_model_chunk_id {-model_chunk_ids[pipeline_parallel_rank][cur_model_chunk_id]}') if ( not (_is_vp_last_stage(vp_stage=cur_model_chunk_id) and is_pp_last_stage(pp_group)) and k != 0 @@ -2249,8 +2238,7 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - if not forward_only and config.packed_moe_expert_offloading: - packed_moe_expert_offloading_reset() + packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading and not forward_only) # Disable async grad reductions no_sync_func = config.no_sync_func From db2b1acbdac1b63a7a09c8aa5e4f5acf5be15d16 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Thu, 27 Nov 2025 15:39:11 +0800 Subject: [PATCH 18/57] Fixe perf issue in packed stash/pop kernels --- .../pipeline_parallel/moe_packed_offload.py | 58 +++++++------------ 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 19bed77242c..146df0532fe 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -72,16 
+72,16 @@ def _stash_copy_kernel_2d( new_free_offset_ptr, # Output: new free_offset value (written by kernel) HIDDEN_SIZE: tl.constexpr, # Hidden dimension (compile-time constant) BLOCK_SIZE: tl.constexpr, # Threads per block (for hidden dimension) - tokens_per_block: tl.constexpr, # Number of tokens each block handles ): """2D Triton kernel to copy tensor data to stash buffer. Grid: (num_blocks,) - fixed number of blocks - Each block handles multiple tokens (tokens_per_block) using a while loop. + Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. Works directly with contiguous 2D tensors [tokens, hidden_size]. Offsets are tracked in tokens, not elements. """ pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) # Load parameters (in tokens, not elements) num_tokens = tl.load(num_tokens_ptr) @@ -106,13 +106,9 @@ def _stash_copy_kernel_2d( if overflow_detected: return - # Each block handles multiple tokens - token_start = pid * tokens_per_block - token_end = min(token_start + tokens_per_block, num_tokens) - - # Process tokens assigned to this block - token_idx = token_start - while token_idx < token_end: + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] + token_idx = pid + while token_idx < num_tokens: # Calculate destination token index with wraparound dst_token_idx = (free_offset + token_idx) % capacity @@ -141,7 +137,8 @@ def _stash_copy_kernel_2d( data = tl.load(src_base + hidden_offsets) tl.store(dst_base + hidden_offsets, data) - token_idx += 1 + # Stride to next token for this block + token_idx += num_blocks # Update new_free_offset (only first block writes it) if pid == 0: @@ -159,30 +156,26 @@ def _stash_pop_kernel_2d( capacity_ptr, # In tokens (read-only) HIDDEN_SIZE: tl.constexpr, # Hidden dimension (compile-time constant) BLOCK_SIZE: tl.constexpr, # Threads per block (for hidden dimension) - tokens_per_block: tl.constexpr, # Number of tokens each block handles ): """2D Triton kernel to reload tensor data from stash buffer. Grid: (num_blocks,) - fixed number of blocks - Each block handles multiple tokens (tokens_per_block) using a while loop. + Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. Works directly with contiguous 2D tensors [tokens, hidden_size]. Offsets are tracked in tokens, not elements. Uses LIFO (stack) semantics - moves free_offset backward after popping. """ pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) # Load parameters (in tokens, not elements) num_tokens = tl.load(num_tokens_ptr) tensor_offset = tl.load(tensor_offset_ptr) # Where data was stashed capacity = tl.load(capacity_ptr) - # Each block handles multiple tokens - token_start = pid * tokens_per_block - token_end = min(token_start + tokens_per_block, num_tokens) - - # Process tokens assigned to this block - token_idx = token_start - while token_idx < token_end: + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] 
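+    # e.g., with num_blocks=4 and num_tokens=10: block 0 copies tokens 0, 4, 8;
+    # block 1 copies 1, 5, 9; block 2 copies 2, 6; block 3 copies 3, 7, so the
+    # per-block work stays balanced for any num_tokens.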
+ token_idx = pid + while token_idx < num_tokens: # Calculate source token index with wraparound src_token_idx = (tensor_offset + token_idx) % capacity @@ -211,7 +204,8 @@ def _stash_pop_kernel_2d( data = tl.load(src_base + hidden_offsets) tl.store(dst_base + hidden_offsets, data) - token_idx += 1 + # Stride to next token for this block + token_idx += num_blocks # For LIFO (stack) behavior: move free_offset backward # After popping, free_offset should be at tensor_offset (freeing the space we just read) @@ -564,14 +558,10 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # Determine grid size with cap on max blocks BLOCK_SIZE = GLOBAL_BLOCK_SIZE - total_blocks_needed = self.max_num_tokens # Ideally 1 block per token - - # Cap the number of blocks and calculate tokens per block - num_blocks = min(total_blocks_needed, max_blocks) - tokens_per_block = (self.max_num_tokens + num_blocks - 1) // num_blocks # Ceiling division + num_blocks = min(self.max_num_tokens, max_blocks) if DEBUG: - debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} tokens_per_block {tokens_per_block} overflow {stash_buffer.overflow.item()}") + debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} overflow {stash_buffer.overflow.item()}") # grid = (num_blocks,) self.stash_buffer_offset = stash_buffer.free_offset.clone() @@ -579,7 +569,7 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # Create temporary tensor for new offset (kernel will write to this) new_free_offset_tensor = torch.empty(1, dtype=torch.int64, device=self.device) - # Launch Triton kernel to copy data (2D version) + # Launch Triton kernel to copy data (2D version, strided access) # self.offload_stream.wait_stream(torch.cuda.current_stream()) # with torch.cuda.stream(self.offload_stream): # TODO: make this async. 
Something unexpected with TE on deallocate the tensor @@ -594,7 +584,6 @@ def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): new_free_offset_tensor, # Write: New free_offset computed by kernel HIDDEN_SIZE=self.hidden_size, BLOCK_SIZE=BLOCK_SIZE, - tokens_per_block=tokens_per_block, ) # Copy new offset value after kernel completes (stream-ordered) @@ -632,19 +621,15 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): # Determine grid size with cap on max blocks BLOCK_SIZE = GLOBAL_BLOCK_SIZE - total_blocks_needed = self.max_num_tokens # Ideally 1 block per token - - # Cap the number of blocks and calculate tokens per block - num_blocks = min(total_blocks_needed, max_blocks) - tokens_per_block = (self.max_num_tokens + num_blocks - 1) // num_blocks # Ceiling division + num_blocks = min(self.max_num_tokens, max_blocks) if DEBUG: - debug_print (f"reload_from_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} tokens_per_block {tokens_per_block}") + debug_print (f"reload_from_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks}") # grid = (num_blocks,) - # Launch Triton kernel to copy data (2D version) + # Launch Triton kernel to copy data (2D version, strided access) # self.offload_stream.wait_stream(torch.cuda.current_stream()) # with torch.cuda.stream(self.offload_stream): @@ -660,7 +645,6 @@ def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): stash_buffer.capacity, # Read-only: Capacity of the buffer (in tokens) HIDDEN_SIZE=self.hidden_size, BLOCK_SIZE=BLOCK_SIZE, - tokens_per_block=tokens_per_block, ) #torch.cuda.synchronize() @@ -973,7 +957,7 @@ class PackedOffloadManager: """ Singleton manager for coordinating activation offloading across pipeline stages. Manages chunk handlers, synchronizes GPU-GPU transfers, - and handles virtual pipeline parallelism. 
+ and handles virtual pipeline parallelism """ OFFLOAD_MGR = None From 22beaa5d7468bf4b398b8d3256a8337e7401ef7f Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Sun, 7 Dec 2025 13:14:39 +0800 Subject: [PATCH 19/57] Minor fix for tensor allocation and padding requirement on budget --- megatron/core/pipeline_parallel/moe_packed_offload.py | 8 ++++---- megatron/core/transformer/moe/token_dispatcher.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 146df0532fe..accc99f74d1 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -299,10 +299,10 @@ def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype): # Head and tail pointers for free_list circular buffer self.free_list_head = torch.zeros(1, dtype=torch.int64, device=device) # Read pointer (allocation) - self.free_list_tail = torch.tensor([self.num_pages], dtype=torch.int64, device=device) # Write pointer (deallocation) + self.free_list_tail = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) # Write pointer (deallocation) # Capacity of free list - self.free_list_capacity = torch.tensor([self.num_pages], dtype=torch.int64, device=device) + self.free_list_capacity = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) def reset(self): """Reset the paged buffer - reinitialize free list.""" @@ -1171,8 +1171,8 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: Returns a tag to identify the tensor later. """ - - if self.max_num_tokens is None or tensor.size(0) != self.max_num_tokens: + # Handle 0-dim tensors (torch.Size([])) - they have no size(0) + if self.max_num_tokens is None or tensor.dim() == 0 or tensor.size(0) != self.max_num_tokens: return tensor.detach() if isinstance(tensor, MXFP8Tensor): #debug_print(f'on_save_for_backward MXFP8Tensor ({self._current_layer_name}) ndim {tensor.ndim} shape {tensor.shape} {tensor.dtype} rowwise {tensor._rowwise_data is not None} columnwise {(tensor._columnwise_data.shape, tensor._columnwise_data.dtype) if tensor._columnwise_data is not None else None}-scale_inv {tensor._columnwise_scale_inv.shape} {tensor._columnwise_scale_inv.dtype}') diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index b93ff8d4167..822ee7cfc10 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1055,7 +1055,9 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): self.token_probs = probs.reshape(num_tokens, self.num_experts) if self.packed_offloading_capacity_factor is not None: - budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) + pad_multiple = get_fp8_align_size(self.config.fp8_recipe) + budget = int(routing_map.shape[0] * self.config.moe_router_topk * (self.packed_offloading_capacity_factor+1)) + budget += -budget % pad_multiple routing_map_maybe_dropped, over_budget = self.budget_check(routing_map, budget) self.over_budget |= over_budget self.num_dispatched_tokens = budget From 3fba43330c758940437db59746a5ec87be6eeb05 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Sun, 7 Dec 2025 14:36:22 +0800 Subject: [PATCH 20/57] Packed/paged offloading is current not stream-safe. 
Need to put stash/restore on the same stream fixed a minor issue in calcualting budget --- megatron/core/pipeline_parallel/moe_packed_offload.py | 4 +++- megatron/core/transformer/moe/token_dispatcher.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index accc99f74d1..4dca75a74fa 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -974,7 +974,9 @@ def __init__(self): # allocate streams and events for synchronization self.enabled = False self._pack_stream = torch.cuda.Stream() - self._unpack_stream = torch.cuda.Stream() + # Currently paged/packed stashing is not stream-safe, so use the same stream for packing + # and unpacking + self._unpack_stream = self._pack_stream self._pack_stream_status = 'idle' # idle, offloading self._unpack_stream_status = 'idle' # idle, reloading self.packed_tensors_to_offload = [] diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 822ee7cfc10..1bb7e2e2f5f 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1056,7 +1056,7 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): if self.packed_offloading_capacity_factor is not None: pad_multiple = get_fp8_align_size(self.config.fp8_recipe) - budget = int(routing_map.shape[0] * self.config.moe_router_topk * (self.packed_offloading_capacity_factor+1)) + budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) budget += -budget % pad_multiple routing_map_maybe_dropped, over_budget = self.budget_check(routing_map, budget) self.over_budget |= over_budget From 63a7dcae0090ed814b46c455e0971f680ef66ae1 Mon Sep 17 00:00:00 2001 From: tongliu Date: Mon, 8 Dec 2025 20:51:33 -0800 Subject: [PATCH 21/57] add new hybrid ep --- megatron/core/transformer/moe/fused_a2a.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py index 39f50a4a670..ae446cf16fa 100644 --- a/megatron/core/transformer/moe/fused_a2a.py +++ b/megatron/core/transformer/moe/fused_a2a.py @@ -3,6 +3,7 @@ # Copyright (c) 2025 DeepSeek # Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE +from megatron.core.utils import internal_api from megatron.core.utils import internal_api try: @@ -365,6 +366,9 @@ def forward( # If we provide the num_permuted_tokens, we do not need to use sync to # wait for the data in pinned memory ready non_blocking = num_permuted_tokens is not None + # If we provide the num_permuted_tokens, we do not need to use sync to + # wait for the data in pinned memory ready + non_blocking = num_permuted_tokens is not None # Process the dispatch ( dispatched_hidden, @@ -381,6 +385,7 @@ def forward( pad_multiple=pad_multiple, num_permuted_tokens=num_permuted_tokens, non_blocking=non_blocking, + non_blocking=non_blocking, ) ctx.handle = handle @@ -401,10 +406,12 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper handle = ctx.handle combined_hidden, combined_probs = _hybrid_ep_buffer.combine_with_unpermute( hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple + hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple ) return 
combined_hidden, None, combined_probs, None, None, None, None, None, None, None +@internal_api @internal_api class HybridEPCombine(torch.autograd.Function): ''' @@ -412,12 +419,14 @@ class HybridEPCombine(torch.autograd.Function): ''' @staticmethod + def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None): def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None): ''' Forward pass of fused combine of the HybridEP backend ''' combined_hidden, _ = _hybrid_ep_buffer.combine_with_unpermute( hidden=x, handle=handle, pad_multiple=pad_multiple + hidden=x, handle=handle, pad_multiple=pad_multiple ) ctx.handle = handle ctx.pad_multiple = pad_multiple @@ -442,6 +451,7 @@ def backward(ctx, grad_x): if HAVE_HYBRIDEP: + @internal_api @internal_api def hybrid_ep_dispatch( x, @@ -493,6 +503,8 @@ def hybrid_ep_dispatch( pad_multiple, ) + @internal_api + def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): @internal_api def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): ''' @@ -512,6 +524,7 @@ def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): is performed. ''' return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple) + return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple) else: hybrid_ep_dispatch = None From f9f2c7bc3c9d65d266f3514e7f5a3c99e904e6d1 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 10 Dec 2025 10:10:23 +0800 Subject: [PATCH 22/57] Remove the overflow check in framework because it is now done by hybridEP --- megatron/core/transformer/moe/fused_a2a.py | 18 ------------ .../core/transformer/moe/token_dispatcher.py | 29 ++----------------- 2 files changed, 3 insertions(+), 44 deletions(-) diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py index ae446cf16fa..ad935cea66e 100644 --- a/megatron/core/transformer/moe/fused_a2a.py +++ b/megatron/core/transformer/moe/fused_a2a.py @@ -3,9 +3,6 @@ # Copyright (c) 2025 DeepSeek # Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE -from megatron.core.utils import internal_api -from megatron.core.utils import internal_api - try: from deep_ep import Buffer from deep_ep.utils import EventHandle, EventOverlap @@ -366,9 +363,6 @@ def forward( # If we provide the num_permuted_tokens, we do not need to use sync to # wait for the data in pinned memory ready non_blocking = num_permuted_tokens is not None - # If we provide the num_permuted_tokens, we do not need to use sync to - # wait for the data in pinned memory ready - non_blocking = num_permuted_tokens is not None # Process the dispatch ( dispatched_hidden, @@ -385,7 +379,6 @@ def forward( pad_multiple=pad_multiple, num_permuted_tokens=num_permuted_tokens, non_blocking=non_blocking, - non_blocking=non_blocking, ) ctx.handle = handle @@ -406,27 +399,22 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper handle = ctx.handle combined_hidden, combined_probs = _hybrid_ep_buffer.combine_with_unpermute( hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple - hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple ) return combined_hidden, None, combined_probs, None, None, None, None, None, None, None -@internal_api -@internal_api class HybridEPCombine(torch.autograd.Function): ''' Fused combine operation for permute + combine a2a + permute using the HybridEP backend ''' @staticmethod - def forward(ctx, x, handle, 
num_permuted_tokens=None, pad_multiple=None): def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None): ''' Forward pass of fused combine of the HybridEP backend ''' combined_hidden, _ = _hybrid_ep_buffer.combine_with_unpermute( hidden=x, handle=handle, pad_multiple=pad_multiple - hidden=x, handle=handle, pad_multiple=pad_multiple ) ctx.handle = handle ctx.pad_multiple = pad_multiple @@ -451,8 +439,6 @@ def backward(ctx, grad_x): if HAVE_HYBRIDEP: - @internal_api - @internal_api def hybrid_ep_dispatch( x, routing_map, @@ -503,9 +489,6 @@ def hybrid_ep_dispatch( pad_multiple, ) - @internal_api - def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): - @internal_api def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): ''' Perform fused combine operation for unpermute + combine a2a + unpermute @@ -524,7 +507,6 @@ def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): is performed. ''' return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple) - return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple) else: hybrid_ep_dispatch = None diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 1bb7e2e2f5f..ad339e301ba 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1027,28 +1027,6 @@ def __init__( self.packed_offloading_capacity_factor = self.config.moe_expert_capacity_factor_for_packed_offloading self.over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') - def budget_check(self, routing_map, budget): - # TODO: the check should be done in hybridep to avoid the AG below - # routing_map: [num_local_tokens, world_size, num_local_experts] - num_local_tokens_per_expert = routing_map.sum(dim=0).flatten() - num_tokens_per_expert = torch.empty( - self.group.size(), - self.num_experts, - device=num_local_tokens_per_expert.device, - dtype=num_local_tokens_per_expert.dtype, - ) - torch.distributed.all_gather_into_tensor( - num_tokens_per_expert, num_local_tokens_per_expert, self.group - ) - - num_global_tokens_per_expert =num_tokens_per_expert.sum(dim=0) - if self.config.fp8: - pad_multiple = get_fp8_align_size(self.config.fp8_recipe) - num_global_tokens_per_expert += -num_global_tokens_per_expert % pad_multiple - num_tokens_per_ep_rank = num_global_tokens_per_expert.view(routing_map.shape[1], routing_map.shape[2]).sum(dim=-1) - routing_map_maybe_dropped, over_budget = drop_routing_map_triton(routing_map, budget, num_tokens_per_ep_rank) - return routing_map_maybe_dropped, over_budget - def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): num_tokens = routing_map.shape[0] self.routing_map = routing_map.reshape(num_tokens, self.num_experts) @@ -1058,11 +1036,7 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): pad_multiple = get_fp8_align_size(self.config.fp8_recipe) budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) budget += -budget % pad_multiple - routing_map_maybe_dropped, over_budget = self.budget_check(routing_map, budget) - self.over_budget |= over_budget - self.num_dispatched_tokens = budget self.num_permuted_tokens = budget - self.routing_map = routing_map_maybe_dropped.reshape(num_tokens, self.num_experts) # Compute the capacity for each expert at the drop_and_pad mode if self.drop_and_pad: num_out_tokens = num_tokens * self.config.moe_router_topk @@ -1107,6 
+1081,9 @@ def dispatch( pad_multiple=self.pad_multiple, ) ) + if self.packed_offloading_capacity_factor is not None: + over_budget = self.handle[8] != 0 # this is overflow_flag + self.over_budget |= over_budget if self.num_permuted_tokens is None: self.tokens_per_expert = tokens_per_expert.to(torch.int64) From 34438d9ccbc431000898972babbb6ed7f4f36cd3 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 10 Dec 2025 14:32:02 +0800 Subject: [PATCH 23/57] Fix one merge conflict fix one change that broke full-iter CUDA graph --- megatron/core/transformer/moe/token_dispatcher.py | 2 +- megatron/training/utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index ad339e301ba..41a2f59e449 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1033,7 +1033,7 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): self.token_probs = probs.reshape(num_tokens, self.num_experts) if self.packed_offloading_capacity_factor is not None: - pad_multiple = get_fp8_align_size(self.config.fp8_recipe) + pad_multiple = get_align_size_for_quantization(self.config) budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) budget += -budget % pad_multiple self.num_permuted_tokens = budget diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 7844b450136..d083f07e7ba 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -562,7 +562,8 @@ def _broadcast(item): def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) - n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + n_tensor = torch.ones(1, dtype=torch.int64, device=dev) * n + # n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) _broadcast(n_tensor) if n == 0: From 0d552889288b66f4af1192d3ebf22342c347e798 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Thu, 11 Dec 2025 12:21:32 -0800 Subject: [PATCH 24/57] Code cleanup --- megatron/core/full_cuda_graph.py | 2 +- .../pipeline_parallel/moe_packed_offload.py | 291 +----------------- 2 files changed, 15 insertions(+), 278 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 84efd4f0780..430a965b477 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -165,7 +165,7 @@ def __call__(self, *args, **kwargs): training_str = 'training' if training else 'validation' curr_iteration = self.curr_iter(training_str) if curr_iteration == self.cuda_graph_warmup_steps: - print(f'Capture CUDA graph for {training_str}!!!') + logger.info(f'Capture CUDA graph for {training_str}!!!') torch.distributed.barrier() assert FullCudaGraphWrapper.cuda_graph[training_str] is None FullCudaGraphWrapper.cuda_graph[training_str] = torch.cuda.CUDAGraph() diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 4dca75a74fa..4e354497c11 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -14,18 +14,6 @@ except ImportError: HAVE_TRITON = False -# Packed Moe Expert Offload implementation for pipeline parallelism -DEBUG = False -DEBUG_RANK = [0] -def debug_print(message): - """Print debug message for a specific rank when DEBUG is 
enabled.""" - # pylint: disable=bad-builtin - if not DEBUG: - return - assert torch.distributed.is_initialized() - if torch.distributed.get_rank() in DEBUG_RANK: - print(f'{torch.distributed.get_rank()}: {message}') - def set_ideal_affinity_for_current_gpu(): """Set CPU affinity for the current GPU to optimize host-device transfers.""" import uuid @@ -38,7 +26,6 @@ def set_ideal_affinity_for_current_gpu(): import cuda.cuda as cuda_driver import cuda.cudart as cuda_runtime except ImportError: - # print("cuda-python may not be installed, skipping GPU affinity setting") warnings.warn("cuda-python may not be installed, skipping GPU affinity setting") return try: @@ -491,173 +478,9 @@ def _paged_stash_pop_kernel( tl.store(new_free_list_tail_ptr, new_tail) -class PackedTensor: - """ - A class to represent a packed tensor. - """ - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None): - self._tensor = tensor - self._original_tensor = None - assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1, f"num_tokens_tensor {num_tokens_tensor} is not a scalar tensor" - self.num_tokens_tensor = num_tokens_tensor.clone() - self.vp_stage = vp_stage - self.schedule_layer_no = schedule_layer_no - self.layer_name = layer_name - self.max_tokens = max_tokens - # Original tensor information - self.original_shape = list(tensor.shape) - self.max_num_tokens = self.original_shape[0] - self.element_size = tensor.element_size() - self.hidden_size = self.original_shape[1] - self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype - self.device = tensor.device - - self.stash_buffer_offset = None - - @property - def schedule_layer(self): - """Get the schedule layer.""" - return self.schedule_layer_no - - def offload_to_stash(self, stash_buffer: StashBuffer, max_blocks=2048): - """Offload the packed tensor.""" - #self._tensor.record_stream(torch.cuda.current_stream()) - # TODO: Call offload function to offload the tensor - # After offload stream joins main stream, the tensor is no longer needed and can be freed - - #pass - - """Copy tensor content into stash_buffer starting at current offset using Triton kernel. - - Out-of-bound writes are silently ignored by the kernel. - Increments self.over_capacity counter if capacity was exceeded. - - Args: - tensor (torch.Tensor): The tensor to stash. Will be flattened before copying. - size (torch.Tensor): GPU tensor containing the number of bytes to copy. - max_blocks (int): Maximum number of blocks to launch. Defaults to 2048. - - Returns: - offset: GPU tensor indicating the offset where the tensor was stashed. - - Raises: - RuntimeError: If Triton is not available. - """ - if not HAVE_TRITON: - raise RuntimeError("Triton is required for PackedTensor.offload_to_stash(). 
Please install triton.") - - self._tensor = self._tensor.contiguous() - if self.num_tokens_tensor.dim() == 0: - self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) - - # Get 2D tensor (no flattening) - if isinstance(self._tensor, MXFP8Tensor): - tensor_to_copy = self._tensor._columnwise_data - else: - tensor_to_copy = self._tensor - - # Determine grid size with cap on max blocks - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(self.max_num_tokens, max_blocks) - - if DEBUG: - debug_print (f"offload_to_stash ({self.layer_name}) {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks} overflow {stash_buffer.overflow.item()}") - # - grid = (num_blocks,) - self.stash_buffer_offset = stash_buffer.free_offset.clone() - - # Create temporary tensor for new offset (kernel will write to this) - new_free_offset_tensor = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch Triton kernel to copy data (2D version, strided access) - # self.offload_stream.wait_stream(torch.cuda.current_stream()) - # with torch.cuda.stream(self.offload_stream): - # TODO: make this async. Something unexpected with TE on deallocate the tensor - _stash_copy_kernel_2d[grid]( - tensor_to_copy, - stash_buffer.buffer, - self.num_tokens_tensor, # Use stored num_tokens (not from shape) - stash_buffer.alloc_offset, # Read-only: Write boundary (in tokens) - stash_buffer.free_offset, # Read-only: Current offset - stash_buffer.capacity, # Read-only: Capacity of the buffer (in tokens) - stash_buffer.overflow, # Read+Write: Over capacity flag - new_free_offset_tensor, # Write: New free_offset computed by kernel - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Copy new offset value after kernel completes (stream-ordered) - stash_buffer.free_offset.copy_(new_free_offset_tensor) - - # save reference to original tensor to avoid deallocation before offload is complete - self._original_tensor = self._tensor - # set tensor to None. This will be replaced by reload_from_stash. - self._tensor = None - if DEBUG: - debug_print (f"After offload_to_stash offset {self.stash_buffer_offset.item()} free_offset {stash_buffer.free_offset.item()} overflow {stash_buffer.overflow.item()} capacity {stash_buffer.capacity.item()} max_tokens {self.max_tokens}") - - - def reload_from_stash(self, stash_buffer: StashBuffer, max_blocks=2048): - """Reload the packed tensor from the stash.""" - if not HAVE_TRITON: - raise RuntimeError("Triton is required for PackedTensor.reload_from_stash(). 
Please install triton.") - if isinstance(self._original_tensor, MXFP8Tensor): - columnwise_data = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) - self._tensor = MXFP8Tensor( - shape=self._original_tensor.shape, - dtype=self._original_tensor.dtype, - fp8_dtype=self._original_tensor._fp8_dtype, - rowwise_data=self._original_tensor._rowwise_data, - rowwise_scale_inv=self._original_tensor._rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=self._original_tensor._columnwise_scale_inv, - quantizer=self._original_tensor._quantizer, - ) - tensor_to_reload = self._tensor._columnwise_data - else: - self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) - tensor_to_reload = self._tensor - - - # Determine grid size with cap on max blocks - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(self.max_num_tokens, max_blocks) - - if DEBUG: - debug_print (f"reload_from_stash {self._tensor.shape}-{self.dtype} stash_buffer {stash_buffer.buffer.dtype} num_tokens {self.num_tokens_tensor.item()} hidden_size {self.hidden_size} max_blocks {max_blocks} num_blocks {num_blocks}") - # - grid = (num_blocks,) - - - # Launch Triton kernel to copy data (2D version, strided access) - # self.offload_stream.wait_stream(torch.cuda.current_stream()) - # with torch.cuda.stream(self.offload_stream): - - # TODO: make this async. Something unexpected with TE on deallocate the tensor - # Note: free_offset is directly updated by the kernel (LIFO stack behavior) - _stash_pop_kernel_2d[grid]( - stash_buffer.buffer, - tensor_to_reload, - self.num_tokens_tensor, # Use stored num_tokens (not from shape) - self.stash_buffer_offset, # Read-only: Start offset for reload (in tokens) - stash_buffer.alloc_offset, # Read-only: Not used in pop kernel - stash_buffer.free_offset, # Write: Moved backward by kernel (LIFO) - stash_buffer.capacity, # Read-only: Capacity of the buffer (in tokens) - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - #torch.cuda.synchronize() - if DEBUG: - debug_print (f"After reload_from_stash reload_offset {self.stash_buffer_offset.item()} alloc_offset {stash_buffer.alloc_offset.item()} free_offset {stash_buffer.free_offset.item()} capacity {stash_buffer.capacity.item()}") - def __repr__(self): - return f"PackedTensor(original_shape={self.original_shape}, num_tokens={self.num_tokens_tensor.item()}, vp_stage={self.vp_stage})" - - class PagedTensor: """ A paged tensor that stores data in pages within a paged stash buffer. - Similar to PackedTensor but uses page-level memory management. 
""" def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0): @@ -733,16 +556,10 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) triton_tokens = self.max_num_tokens - d2d_tokens - if DEBUG: - debug_print(f"PagedTensor offload ({self.layer_name}) {self._tensor.shape}-{self.dtype} page_size={self.page_size} max_num_tokens={self.max_num_tokens} num_d2d_pages={self.num_d2d_pages} d2d_tokens={d2d_tokens} triton_tokens={triton_tokens}") - # Perform both D2D copy and Triton kernel together # Part 1: Copy first d2d_tokens to static_tensor using native PyTorch if d2d_tokens > 0: self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens] - if DEBUG: - debug_print(f"Copied {d2d_tokens} tokens to static_tensor using D2D") - # Part 2: Copy remaining tokens using Triton kernel if triton_tokens > 0: triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens] @@ -777,15 +594,9 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 # Update free list head paged_stash_buffer.free_list_head.copy_(new_free_list_head) - if DEBUG: - debug_print(f"Copied {triton_tokens} tokens using Triton kernel") - # Save reference to original tensor self._original_tensor = self._tensor self._tensor = None - - if DEBUG: - debug_print(f"After PagedTensor offload") def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): """Reload the paged tensor from paged stash buffer. @@ -820,15 +631,10 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) triton_tokens = self.max_num_tokens - d2d_tokens - if DEBUG: - debug_print(f"PagedTensor reload {self._tensor.shape}-{self.dtype} page_size={self.page_size} max_num_tokens={self.max_num_tokens} num_d2d_pages={self.num_d2d_pages} d2d_tokens={d2d_tokens} triton_tokens={triton_tokens}") - # Perform both D2D copy and Triton kernel together # Part 1: Copy first d2d_tokens from static_tensor using native PyTorch if d2d_tokens > 0 and self.static_tensor is not None: tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens] - if DEBUG: - debug_print(f"Reloaded {d2d_tokens} tokens from static_tensor using D2D") # Part 2: Copy remaining tokens using Triton kernel if triton_tokens > 0: @@ -862,15 +668,6 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 # Update free list tail paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) - - if DEBUG: - debug_print(f"Reloaded {triton_tokens} tokens using Triton kernel") - - if DEBUG: - debug_print(f"After PagedTensor reload") - - def __repr__(self): - return f"PagedTensor(original_shape={self.original_shape}, num_tokens={self.num_tokens_tensor.item()}, page_size={self.page_size}, vp_stage={self.vp_stage})" class PP_ScheduleFunction(torch.autograd.Function): @@ -895,11 +692,10 @@ def forward(ctx, tensor, offload_manager): # after forward # Deallocate original tensor after offload is complete while len(offload_manager.packed_tensors_offload_in_progress) > 0: packed_tensor = offload_manager.packed_tensors_offload_in_progress.pop(0) - if not DEBUG: - if isinstance(packed_tensor._original_tensor, MXFP8Tensor): - packed_tensor._original_tensor._columnwise_data = None - else: - packed_tensor._original_tensor = None + if 
isinstance(packed_tensor._original_tensor, MXFP8Tensor): + packed_tensor._original_tensor._columnwise_data = None + else: + packed_tensor._original_tensor = None if offload_manager.status == 'captured': current_schedule_layer = offload_manager.get_schedule_layer(ctx.vp_stage+1, ctx.layer_no, ctx.microbatch_no) @@ -920,7 +716,6 @@ def forward(ctx, tensor, offload_manager): # after forward @staticmethod def backward(ctx, *grad_output): # before backward # pylint: disable=missing-function-docstring - #debug_print(f"PP_ScheduleFunction vp_stage {ctx.vp_stage} before backward") if ctx.vp_stage is not None: ctx.offload_manager.update_pp_schedule(-(ctx.vp_stage+1), -ctx.layer_no, -ctx.microbatch_no) ctx.offload_manager.current_schedule_index += 1 @@ -932,7 +727,6 @@ def backward(ctx, *grad_output): # before backward if ctx.offload_manager.status == 'captured' and ctx.offload_manager.current_schedule_index < len(ctx.offload_manager._pp_schedule): next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index] if next_schedule_layer < 0: - debug_print(f"PP_ScheduleFunction backward reload_packed_tensors {next_schedule_layer}") # For last layer last microbatch, wait for offload to complete before reloading if ctx.offload_manager._pack_stream_status == 'offloading': assert len(ctx.offload_manager.packed_tensors_offload_in_progress) > 0, f"packed_tensors_offload_in_progress is empty" @@ -943,11 +737,10 @@ def backward(ctx, *grad_output): # before backward # Deallocate original tensor after offload is complete while len(ctx.offload_manager.packed_tensors_offload_in_progress) > 0: packed_tensor = ctx.offload_manager.packed_tensors_offload_in_progress.pop(0) - if not DEBUG: - if isinstance(packed_tensor._original_tensor, MXFP8Tensor): - packed_tensor._original_tensor._columnwise_data = None - else: - packed_tensor._original_tensor = None + if isinstance(packed_tensor._original_tensor, MXFP8Tensor): + packed_tensor._original_tensor._columnwise_data = None + else: + packed_tensor._original_tensor = None ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) @@ -1056,17 +849,12 @@ def offload_packed_tensors(self, pp_schedule_layer): with torch.cuda.stream(self.pack_stream): if self.status == 'captured': self._pack_stream_status = 'offloading' - #assert self.packed_tensors_to_reload - #for packed_tensor in self.packed_tensors_to_offload: - # packed_tensor.offload_to_stash(self.stash_buffers[packed_tensor.vp_stage]) - debug_print(f"offload_packed_tensors {len(self.packed_tensors_to_offload)}") if pp_schedule_layer not in self.packed_tensors_to_reload: self.packed_tensors_to_reload[pp_schedule_layer] = [] assert len(self.packed_tensors_to_reload[pp_schedule_layer]) == 0, f"packed_tensors_to_reload {pp_schedule_layer} is not empty {self.packed_tensors_to_reload[pp_schedule_layer]}" while len(self.packed_tensors_to_offload) > 0: packed_tensor = self.packed_tensors_to_offload.pop(0) - #print (f'offload_packed_tensors vp_stage={packed_tensor.vp_stage} dtype={packed_tensor.dtype} hidden_size={packed_tensor.hidden_size} use_paged_stash={self.use_paged_stash}') stash_buffers_vp_stage = self.stash_buffers[packed_tensor.vp_stage] if not self.use_paged_stash else self.stash_buffers[0] stash_buffer = stash_buffers_vp_stage[packed_tensor.dtype][packed_tensor.hidden_size] packed_tensor.offload_to_stash(stash_buffer) @@ -1089,7 +877,6 @@ def reload_packed_tensors(self, pp_schedule_layer): if len(self.packed_tensors_to_reload[item]) > 0: count += 1 - 
debug_print(f"reload_packed_tensors {count}") while len(self.packed_tensors_to_reload[pp_schedule_layer]) > 0: packed_tensor = self.packed_tensors_to_reload[pp_schedule_layer].pop(0) stash_buffers_vp_stage = self.stash_buffers[packed_tensor.vp_stage] if not self.use_paged_stash else self.stash_buffers[0] @@ -1158,8 +945,6 @@ def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) num_tokens = self.num_tokens_tensor.item() - #debug_print(f"------{self.current_schedule_index} len PP_Schedule {len(self._pp_schedule)}") - #debug_print(f" {self.status} {self.current_schedule_index} {self._pp_schedule[self.current_schedule_index]} {vp_stage*100 + layer_no*10 + microbatch_no}") assert self._pp_schedule[self.current_schedule_index] == self.get_schedule_layer(vp_stage, layer_no, microbatch_no), f"schedule {self._pp_schedule[self.current_schedule_index]} != {self.get_schedule_layer(vp_stage, layer_no, microbatch_no)}" @@ -1177,11 +962,8 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if self.max_num_tokens is None or tensor.dim() == 0 or tensor.size(0) != self.max_num_tokens: return tensor.detach() if isinstance(tensor, MXFP8Tensor): - #debug_print(f'on_save_for_backward MXFP8Tensor ({self._current_layer_name}) ndim {tensor.ndim} shape {tensor.shape} {tensor.dtype} rowwise {tensor._rowwise_data is not None} columnwise {(tensor._columnwise_data.shape, tensor._columnwise_data.dtype) if tensor._columnwise_data is not None else None}-scale_inv {tensor._columnwise_scale_inv.shape} {tensor._columnwise_scale_inv.dtype}') assert tensor._rowwise_data is None, f"rowwise_data is not None; Only columnwise data is supported for packed offloading" - #if tensor.size(1) in [7168, 4096, 1] and DEBUG: - # return tensor.detach() if self.status == 'capture': self.num_tokens = self.num_tokens_tensor.item() @@ -1228,6 +1010,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: tensor = tensor_truncated # Create tensor (paged or regular based on configuration) + assert self.use_paged_stash, "Paged stashing must be used." if self.use_paged_stash: packed_tensor = PagedTensor( tensor, @@ -1239,16 +1022,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: page_size=self.page_size, num_d2d_pages=self.num_d2d_pages ) - else: - packed_tensor = PackedTensor( - tensor, - num_tokens_tensor=self.num_tokens_tensor, - vp_stage=self.current_vp_stage, - schedule_layer_no=self._pp_schedule[self.current_schedule_index] if self._pp_schedule is not None and self.current_schedule_index < len(self._pp_schedule) else None, - layer_name=self._current_layer_name, - max_tokens=self.max_num_tokens - ) - + if self.status == 'captured': self.add_packed_tensor_to_offload(packed_tensor) return packed_tensor @@ -1258,7 +1032,7 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: Hook called when autograd retrieves a saved tensor during backward pass. Returns the actual tensor (potentially reloading from CPU). 
""" - if isinstance(saved_state, (PackedTensor, PagedTensor)): + if isinstance(saved_state, (PagedTensor)): if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][saved_state.hidden_size] -= num_tokens @@ -1274,28 +1048,9 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: else: saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad) - if not DEBUG: - assert saved_state._tensor is not None, f"saved_state._tensor is None {saved_state._tensor}" - if saved_state._tensor is not None: - if self.status == 'captured' and DEBUG: - #debug_print(f"on_get_saved_tensor {saved_state._original_tensor.shape} {saved_state.num_tokens_tensor.item()}") - original_tensor = saved_state._original_tensor if not isinstance(saved_state._original_tensor, MXFP8Tensor) else saved_state._original_tensor._columnwise_data - if original_tensor is not None: - original_flat = original_tensor.flatten() if not isinstance(original_tensor, MXFP8Tensor) else original_tensor._columnwise_data.flatten() - tensor_flat = saved_state._tensor.flatten() if not isinstance(saved_state._tensor, MXFP8Tensor) else saved_state._tensor._columnwise_data.flatten() - num_elements = saved_state.num_tokens_tensor.item() * saved_state.hidden_size - original_flat_sub = original_flat[:num_elements] - tensor_flat_sub = tensor_flat[:num_elements] - equal = torch.equal(original_flat_sub, tensor_flat_sub) - num_not_equal = (original_flat_sub != tensor_flat_sub).sum() - idx_not_equal = (original_flat_sub != tensor_flat_sub).nonzero() - assert equal, f"on_get_saved_tensor equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}" - #debug_print(f"on_get_saved_tensor original: {saved_state._original_tensor.shape} tensor: {saved_state._tensor.shape} equal tensors {equal} num_not_equal {num_not_equal}/{num_elements} idx_not_equal {idx_not_equal} original_tensor {original_flat_sub[idx_not_equal]} tensor {tensor_flat_sub[idx_not_equal]}") - #debug_print(f"on_get_saved_tensor equal tensors {torch.equal(saved_state._original_tensor, saved_state._tensor)} original_tensor {original_flat[-100:]} tensor {tensor_flat[-100:]}") - return saved_state._tensor - else: - return saved_state._original_tensor - + assert saved_state._tensor is not None, f"saved_state._tensor is None {saved_state._tensor}" + return saved_state._tensor + return saved_state class PackedOffloadContext: @@ -1335,7 +1090,6 @@ def packed_moe_expert_offloading_group_start(tensor, name=None): def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num_tokens_tensor=None): """Get the fine-grained offload context""" - #debug_print(f'get_packed_moe_expert_offloading_context name {name}') offload_manager = PackedOffloadManager.get_instance() if not offload_manager.enabled: return nullcontext() @@ -1349,7 +1103,6 @@ def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num def packed_moe_expert_offloading_group_commit(tensor, name=None): """Mark the end of a layer group and prepare for offload/reload.""" rank = torch.distributed.get_rank() - #debug_print(f'{rank}: packed_moe_expert_offloading_group_commit tensor {tensor.shape}-{tensor.dtype} name {name}') offload_manager = PackedOffloadManager.get_instance() offload_manager.device = tensor.device if not offload_manager.enabled: @@ -1358,7 +1111,6 @@ def 
packed_moe_expert_offloading_group_commit(tensor, name=None): def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" - #debug_print(f'packed_moe_expert_offloading_init_chunk_handler vp_size {vp_size} vp_stage {vp_stage}') offload_manager = PackedOffloadManager.get_instance() if not offload_manager.enabled: return @@ -1376,8 +1128,6 @@ def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): def packed_moe_expert_offloading_set_last_layer(is_last_layer=False): """Set the last layer flag.""" - #PipelineOffloadManager.get_instance().set_last_layer(is_last_layer) - #debug_print(f'packed_moe_expert_offloading_set_last_layer is_last_layer {is_last_layer}') offload_manager = PackedOffloadManager.get_instance() if not offload_manager.enabled: return @@ -1390,14 +1140,6 @@ def packed_moe_expert_offloading_reset(enabled=True): offload_manager.iteration += 1 # current layer and microbatch for each vp stage for forward pass offload_manager.current_schedule_index = 0 - if os.getenv('MEM_PROFILE', '0') == '1': - if offload_manager.iteration == 0 and torch.distributed.get_rank() == 0: - torch.cuda.memory._record_memory_history(max_entries=1000000) - print(f'packed_moe_expert_offloading_reset record_memory_history') - if offload_manager.iteration == 5 and torch.distributed.get_rank() == 0: - torch.cuda.memory._dump_snapshot("packed_offloading_cg.pkl") - torch.cuda.memory._record_memory_history(enabled=None) - print(f'packed_moe_expert_offloading_reset dump_snapshot') if not enabled: return @@ -1409,13 +1151,8 @@ def packed_moe_expert_offloading_reset(enabled=True): offload_manager.status = 'captured' stash_buffer_size_factor = float(os.getenv('STASH_BUFFER_SIZE_FACTOR', '1.10')) offload_manager.allocate_offload_buffers(stash_buffer_size_factor=stash_buffer_size_factor) - debug_print(f'packed_moe_expert_offloading_reset captured schedule: {offload_manager._pp_schedule}') - debug_print(f'packed_moe_expert_offloading_reset max_tokens_per_vp_stage: {offload_manager.max_tokens_per_vp_stage}') - debug_print(f'packed_moe_expert_offloading_reset max_tokens_across_vp_stages: {offload_manager.max_tokens_across_vp_stages}') elif offload_manager.status == 'captured': pass - else: - debug_print(f'packed_moe_expert_offloading_reset unknown status: {offload_manager.status}') if offload_manager.status == 'captured': if not torch.cuda.is_current_stream_capturing(): From da2052376e43e63eef056939d600ae1c9b1d8071 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Thu, 11 Dec 2025 18:25:45 -0800 Subject: [PATCH 25/57] Add second autograd to avoid triple buffering --- .../pipeline_parallel/moe_packed_offload.py | 85 +++++++++++-------- megatron/core/transformer/moe/experts.py | 2 + 2 files changed, 51 insertions(+), 36 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 4e354497c11..5f1f9ebb762 100644 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -670,7 +670,31 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) -class PP_ScheduleFunction(torch.autograd.Function): +class PP_PreScheduleFunction(torch.autograd.Function): + """ + This function is used to update the pp schedule. 
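# --- Illustrative sketch, not part of the patch: the identity-autograd pattern that
# PP_PreScheduleFunction / PP_PostScheduleFunction rely on. The tensor passes through
# unchanged; the callbacks (hypothetical here) run side effects at fixed points of the
# forward and backward passes.
import torch

class _IdentityScheduleHookExample(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, on_forward, on_backward):
        ctx.on_backward = on_backward
        on_forward()          # e.g. wait for a pending offload to finish
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        ctx.on_backward()     # e.g. start reloading stashed activations for the next layer
        return grad_output, None, None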
+ """ + + @staticmethod + def forward(ctx, tensor, offload_manager): # after forward + # pylint: disable=missing-function-docstring + ctx.offload_manager = offload_manager + # Wait for offload to complete before starting the next layer + offload_manager.wait_for_offload_to_complete() + return tensor + + @staticmethod + def backward(ctx, *grad_output): # before backward + # pylint: disable=missing-function-docstring + # Initiate reload for next layer + if ctx.offload_manager.status == 'captured' and ctx.offload_manager.current_schedule_index < len(ctx.offload_manager._pp_schedule): + next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index] + if next_schedule_layer < 0: + ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) + + return grad_output + (None, None) + +class PP_PostScheduleFunction(torch.autograd.Function): """ This function is used to update the pp schedule. """ @@ -684,19 +708,8 @@ def forward(ctx, tensor, offload_manager): # after forward if ctx.vp_stage is None: ctx.vp_stage = 0 ctx.layer_no, ctx.microbatch_no = offload_manager.update_pp_schedule(ctx.vp_stage+1) - current_stream = torch.cuda.current_stream() - if offload_manager._pack_stream_status == 'offloading': - current_stream.wait_stream(offload_manager.pack_stream) - offload_manager._pack_stream_status = 'idle' - - # Deallocate original tensor after offload is complete - while len(offload_manager.packed_tensors_offload_in_progress) > 0: - packed_tensor = offload_manager.packed_tensors_offload_in_progress.pop(0) - if isinstance(packed_tensor._original_tensor, MXFP8Tensor): - packed_tensor._original_tensor._columnwise_data = None - else: - packed_tensor._original_tensor = None + # Initiate offload for current layer and reload for next layer if offload_manager.status == 'captured': current_schedule_layer = offload_manager.get_schedule_layer(ctx.vp_stage+1, ctx.layer_no, ctx.microbatch_no) next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index+1] @@ -720,29 +733,11 @@ def backward(ctx, *grad_output): # before backward ctx.offload_manager.update_pp_schedule(-(ctx.vp_stage+1), -ctx.layer_no, -ctx.microbatch_no) ctx.offload_manager.current_schedule_index += 1 current_stream = torch.cuda.current_stream() + + ctx.offload_manager.wait_for_offload_to_complete() if ctx.offload_manager._unpack_stream_status == 'reloading': current_stream.wait_stream(ctx.offload_manager.unpack_stream) ctx.offload_manager._unpack_stream_status = 'idle' - - if ctx.offload_manager.status == 'captured' and ctx.offload_manager.current_schedule_index < len(ctx.offload_manager._pp_schedule): - next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index] - if next_schedule_layer < 0: - # For last layer last microbatch, wait for offload to complete before reloading - if ctx.offload_manager._pack_stream_status == 'offloading': - assert len(ctx.offload_manager.packed_tensors_offload_in_progress) > 0, f"packed_tensors_offload_in_progress is empty" - offloaded_tensor = ctx.offload_manager.packed_tensors_offload_in_progress[0] - if next_schedule_layer == -offloaded_tensor.schedule_layer: - current_stream.wait_stream(ctx.offload_manager.pack_stream) - ctx.offload_manager._pack_stream_status = 'idle' - # Deallocate original tensor after offload is complete - while len(ctx.offload_manager.packed_tensors_offload_in_progress) > 0: - packed_tensor = ctx.offload_manager.packed_tensors_offload_in_progress.pop(0) - if 
isinstance(packed_tensor._original_tensor, MXFP8Tensor): - packed_tensor._original_tensor._columnwise_data = None - else: - packed_tensor._original_tensor = None - - ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) return grad_output + (None, None) @@ -863,7 +858,22 @@ def offload_packed_tensors(self, pp_schedule_layer): else: pass assert len(self.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {self.packed_tensors_to_offload}" - + + def wait_for_offload_to_complete(self): + """Wait for offload to complete.""" + current_stream = torch.cuda.current_stream() + if self._pack_stream_status == 'offloading': + current_stream.wait_stream(self.pack_stream) + self._pack_stream_status = 'idle' + + # Deallocate original tensor after offload is complete + while len(self.packed_tensors_offload_in_progress) > 0: + packed_tensor = self.packed_tensors_offload_in_progress.pop(0) + if isinstance(packed_tensor._original_tensor, MXFP8Tensor): + packed_tensor._original_tensor._columnwise_data = None + else: + packed_tensor._original_tensor = None + def reload_packed_tensors(self, pp_schedule_layer): """Reload the packed tensors.""" current_stream = torch.cuda.current_stream() @@ -1086,7 +1096,10 @@ def __exit__(self, *args: Any): def packed_moe_expert_offloading_group_start(tensor, name=None): """Mark the start of a layer group and prepare for offload/reload.""" rank = torch.distributed.get_rank() - return tensor + offload_manager = PackedOffloadManager.get_instance() + if not offload_manager.enabled: + return tensor + return PP_PreScheduleFunction.apply(tensor, offload_manager) def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num_tokens_tensor=None): """Get the fine-grained offload context""" @@ -1107,7 +1120,7 @@ def packed_moe_expert_offloading_group_commit(tensor, name=None): offload_manager.device = tensor.device if not offload_manager.enabled: return tensor - return PP_ScheduleFunction.apply(tensor, offload_manager) + return PP_PostScheduleFunction.apply(tensor, offload_manager) def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index ab6f42f98a6..25ced73f9d8 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -1012,6 +1012,8 @@ def forward( with off_interface( self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" ) as permuted_local_hidden_states: + if self.config.packed_moe_expert_offloading: + permuted_local_hidden_states = packed_moe_expert_offloading_group_start(permuted_local_hidden_states, name="expert_fc1") if self.packed_offload_expert_fc1: offload_context = get_packed_moe_expert_offloading_context(name="expert_fc1", max_num_tokens=permuted_local_hidden_states.shape[0], num_tokens_tensor=tokens_per_expert.sum()) else: From a219d7df94dbfd82d6752a1eea4181ed60c74971 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Thu, 11 Dec 2025 19:35:06 -0800 Subject: [PATCH 26/57] Avoid unnecessary wait_stream for reload in case of 1f1b --- megatron/core/pipeline_parallel/moe_packed_offload.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py index 5f1f9ebb762..49b692b1a16 100644 --- 
a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ b/megatron/core/pipeline_parallel/moe_packed_offload.py @@ -718,7 +718,7 @@ def forward(ctx, tensor, offload_manager): # after forward ctx.offload_manager.offload_packed_tensors(current_schedule_layer) if next_schedule_layer < 0: # reload for next backward layer - ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) + ctx.offload_manager.reload_packed_tensors(-next_schedule_layer, no_wait=True) else: ctx.offload_manager.remove_packed_tensor_from_offload() @@ -874,10 +874,12 @@ def wait_for_offload_to_complete(self): else: packed_tensor._original_tensor = None - def reload_packed_tensors(self, pp_schedule_layer): + def reload_packed_tensors(self, pp_schedule_layer, no_wait=False): """Reload the packed tensors.""" - current_stream = torch.cuda.current_stream() - self.unpack_stream.wait_stream(current_stream) + # Avoid waiting for main stream if reload is immediately after offload since offload is already waiting for main stream + if not no_wait or self.unpack_stream != self.pack_stream: + current_stream = torch.cuda.current_stream() + self.unpack_stream.wait_stream(current_stream) with torch.cuda.stream(self.unpack_stream): if self.status == 'captured': From 0bede5b7958b878563da451a7c3166a3fae6fa26 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Thu, 18 Dec 2025 14:31:36 +0800 Subject: [PATCH 27/57] Check in dynamic-shape-aware SwiGLU triton kernel --- megatron/core/fusions/fused_bias_swiglu.py | 284 ++++++++++++++++++++- megatron/core/transformer/moe/experts.py | 2 + 2 files changed, 276 insertions(+), 10 deletions(-) diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py index 632470876c9..3f6c70c75d8 100644 --- a/megatron/core/fusions/fused_bias_swiglu.py +++ b/megatron/core/fusions/fused_bias_swiglu.py @@ -5,6 +5,8 @@ import torch import torch.nn.functional as F +import triton +import triton.language as tl from megatron.core.jit import jit_fuser from megatron.core.utils import nvtx_decorator @@ -190,20 +192,51 @@ def backward(ctx, grad_output): class WeightedSwiGLUFunction(torch.autograd.Function): @staticmethod - # bias is an optional argument - def forward(ctx, input, weights, fp8_input_store): + def forward(ctx, input, weights, fp8_input_store, num_tokens_tensor=None): + """Forward pass for weighted SwiGLU. 
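# --- Illustrative sketch, not part of the patch: a plain-PyTorch reference for the
# weighted SwiGLU that the Triton kernels below compute, useful for correctness checks.
# `inp` is [tokens, 2*H] laid out as (gate | value); `weights` is [tokens, 1].
import torch
import torch.nn.functional as F

def weighted_swiglu_reference(inp: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    gate, value = inp.chunk(2, dim=-1)       # gate = inp[:, :H], value = inp[:, H:]
    return F.silu(gate) * value * weights    # SiLU(gate) * value * per-token weight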
+ + Args: + input: [total_tokens, hidden_size * 2] + weights: [total_tokens, 1] + fp8_input_store: Whether to store in FP8 + num_tokens_tensor: Optional scalar tensor with actual token count + (uses Triton if provided) + """ + # Convert input for backward pass input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input - ctx.save_for_backward(input_for_backward, weights) + + # Use Triton implementation if num_tokens_tensor provided and available + if num_tokens_tensor is not None and input.dim() == 2: + output = weighted_swiglu_triton(input, weights, num_tokens_tensor) + ctx.save_for_backward(input_for_backward, weights, num_tokens_tensor) + ctx.use_triton = True + else: + # Fallback to JIT fused implementation + output = weighted_swiglu(input, weights) + ctx.save_for_backward(input_for_backward, weights) + ctx.use_triton = False + ctx.ori_input_dtype = input.dtype ctx.fp8_input_store = fp8_input_store - return weighted_swiglu(input, weights) + return output @staticmethod def backward(ctx, grad_output): - input, weights = ctx.saved_tensors - input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input - tmp, wgrad = weighted_swiglu_back(grad_output, input, weights) - return tmp, wgrad, None + """Backward pass for weighted SwiGLU.""" + if ctx.use_triton: + # Triton backward path + input, weights, num_tokens_tensor = ctx.saved_tensors + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + grad_input, grad_weights = weighted_swiglu_triton_back( + grad_output, input, weights, num_tokens_tensor + ) + return grad_input, grad_weights, None, None + else: + # JIT fused backward path + input, weights = ctx.saved_tensors + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp, wgrad = weighted_swiglu_back(grad_output, input, weights) + return tmp, wgrad, None, None def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False): @@ -236,7 +269,7 @@ def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) -def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False): +def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, num_tokens_tensor=None): """ Token-wise-weighted bias swiglu fusion. """ @@ -246,10 +279,241 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False): if bias is not None: raise NotImplementedError("Bias is not supported for weighted swiglu fusion") else: - output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store) + output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store, num_tokens_tensor) return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) # bias_swiglu_impl = BiasSwiGLUFunction.apply # swiglu_impl = SwiGLUFunction.apply + +@triton.jit +def _weighted_swiglu_fwd_kernel( + input_ptr, + weights_ptr, + output_ptr, + num_tokens_ptr, + hidden_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel for weighted SwiGLU forward pass. + + Processes tokens in strided pattern, only operating on valid tokens. + Formula: output = SiLU(input[:, :H]) * input[:, H:] * weights + """ + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load actual number of tokens + num_tokens = tl.load(num_tokens_ptr) + + # Strided access: each block handles tokens [pid, pid+num_blocks, ...] 
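# --- Illustrative sketch, not part of the patch: the grid-stride pattern these kernels
# use, written in plain Python. Block `pid` of `num_blocks` visits tokens
# pid, pid + num_blocks, pid + 2*num_blocks, ..., so a single fixed-size launch covers
# whatever runtime token count is read from num_tokens_ptr.
def grid_stride_indices(pid: int, num_blocks: int, num_tokens: int):
    idx = pid
    while idx < num_tokens:
        yield idx
        idx += num_blocks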
+ token_idx = pid + while token_idx < num_tokens: + # Load weight for this token + weight = tl.load(weights_ptr + token_idx) + + # Process hidden dimension + for h_offset in range(0, hidden_size, BLOCK_SIZE): + h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size + + # Load input chunks (gate and value) + input_offset_1 = token_idx * (hidden_size * 2) + h_offset + input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset + + y1 = tl.load( + input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 + ) + y2 = tl.load( + input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 + ) + + # SwiGLU: SiLU(y1) * y2 * weight + # SiLU(x) = x * sigmoid(x) + # Cast to fp32 for sigmoid computation (required by Triton) + y1_fp32 = y1.to(tl.float32) + y2_fp32 = y2.to(tl.float32) + weight_fp32 = weight.to(tl.float32) + + sigmoid_y1 = tl.sigmoid(y1_fp32) + silu_y1 = y1_fp32 * sigmoid_y1 + result = silu_y1 * y2_fp32 * weight_fp32 + + # Store output (cast back to original dtype) + output_offset = token_idx * hidden_size + h_offset + tl.store( + output_ptr + output_offset + tl.arange(0, BLOCK_SIZE), + result.to(y1.dtype), + mask=h_mask, + ) + + # Stride to next token + token_idx += num_blocks + +@triton.jit +def _weighted_swiglu_bwd_kernel( + grad_output_ptr, + input_ptr, + weights_ptr, + grad_input_ptr, + grad_weights_ptr, + num_tokens_ptr, + hidden_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel for weighted SwiGLU backward pass. + + Computes gradients with respect to input and weights for valid tokens only. + """ + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load actual number of tokens + num_tokens = tl.load(num_tokens_ptr) + + # Strided access + token_idx = pid + while token_idx < num_tokens: + # Load weight for this token + weight = tl.load(weights_ptr + token_idx) + + # Accumulator for weight gradient (fp32 for precision) + weight_grad_acc = 0.0 + + # Process hidden dimension + for h_offset in range(0, hidden_size, BLOCK_SIZE): + h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size + + # Load grad_output + grad_out_offset = token_idx * hidden_size + h_offset + grad_out = tl.load( + grad_output_ptr + grad_out_offset + tl.arange(0, BLOCK_SIZE), + mask=h_mask, + other=0.0, + ) + + # Load input chunks + input_offset_1 = token_idx * (hidden_size * 2) + h_offset + input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset + + y1 = tl.load( + input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 + ) + y2 = tl.load( + input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 + ) + + # Cast to fp32 for sigmoid computation (required by Triton) + y1_fp32 = y1.to(tl.float32) + y2_fp32 = y2.to(tl.float32) + grad_out_fp32 = grad_out.to(tl.float32) + weight_fp32 = weight.to(tl.float32) + + # Forward calculations + sigmoid_y1 = tl.sigmoid(y1_fp32) + silu_y1 = y1_fp32 * sigmoid_y1 + + # Gradient for y1 (gate): d(SiLU(y1))/dy1 * y2 * weight * grad_out + # d(SiLU(y1))/dy1 = sigmoid(y1) * (1 + y1 * (1 - sigmoid(y1))) + dsilu_dy1 = sigmoid_y1 * (1.0 + y1_fp32 * (1.0 - sigmoid_y1)) + grad_y1 = grad_out_fp32 * weight_fp32 * dsilu_dy1 * y2_fp32 + + # Gradient for y2 (value): SiLU(y1) * weight * grad_out + grad_y2 = grad_out_fp32 * weight_fp32 * silu_y1 + + # Store input gradients (cast back to original dtype) + tl.store( + grad_input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), + grad_y1.to(y1.dtype), + mask=h_mask, + ) + tl.store( + 
grad_input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), + grad_y2.to(y2.dtype), + mask=h_mask, + ) + + # Accumulate weight gradient: swiglu(y) * grad_out + # swiglu(y) = silu_y1 * y2 + weight_grad_contribution = silu_y1 * y2_fp32 * grad_out_fp32 + weight_grad_acc += tl.sum(weight_grad_contribution) + + # Store weight gradient after processing all chunks + tl.store(grad_weights_ptr + token_idx, weight_grad_acc) + + # Stride to next token + token_idx += num_blocks + +def weighted_swiglu_triton(input, weights, num_tokens_tensor): + """Triton implementation of weighted SwiGLU forward pass. + + Args: + input: [total_tokens, hidden_size * 2] + weights: [total_tokens, 1] + num_tokens_tensor: Scalar tensor with actual token count + + Returns: + output: [total_tokens, hidden_size] + """ + assert input.dim() == 2, "Input must be 2D [total_tokens, hidden_size*2]" + assert weights.dim() == 2 and weights.size(1) == 1, "Weights must be [total_tokens, 1]" + + total_tokens, hidden_size_2 = input.shape + hidden_size = hidden_size_2 // 2 + + # Allocate output + output = torch.empty((total_tokens, hidden_size), dtype=input.dtype, device=input.device) + + # Launch kernel + BLOCK_SIZE = 128 + num_blocks = min(total_tokens, 4096) + grid = (num_blocks,) + + _weighted_swiglu_fwd_kernel[grid]( + input, + weights, + output, + num_tokens_tensor, + hidden_size=hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output + +def weighted_swiglu_triton_back(grad_output, input, weights, num_tokens_tensor): + """Triton implementation of weighted SwiGLU backward pass. + + Args: + grad_output: [total_tokens, hidden_size] + input: [total_tokens, hidden_size * 2] + weights: [total_tokens, 1] + num_tokens_tensor: Scalar tensor with actual token count + + Returns: + grad_input: [total_tokens, hidden_size * 2] + grad_weights: [total_tokens, 1] + """ + total_tokens, hidden_size_2 = input.shape + hidden_size = hidden_size_2 // 2 + + # Allocate gradients + grad_input = torch.empty_like(input) + grad_weights = torch.empty_like(weights) + + # Launch kernel + BLOCK_SIZE = 128 + num_blocks = min(total_tokens, 4096) + grid = (num_blocks,) + + _weighted_swiglu_bwd_kernel[grid]( + grad_output, + input, + weights, + grad_input, + grad_weights, + num_tokens_tensor, + hidden_size=hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return grad_input, grad_weights \ No newline at end of file diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 25ced73f9d8..25b43f2c3db 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -1069,6 +1069,8 @@ def remove_glu_interleaving(x: torch.Tensor) -> torch.Tensor: bias_parallel, permuted_probs, self.config.activation_func_fp8_input_store, + tokens_per_expert.sum() if self.packed_offload_moe_act else None, + ) elif self.activation_func == quick_gelu and self.config.gated_linear_unit: intermediate_parallel = weighted_bias_quick_geglu_impl( From 837503d120efcb55788960a21239bfb64df4a9de Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Thu, 18 Dec 2025 15:57:23 +0800 Subject: [PATCH 28/57] Major cleanup and refactor Get rid of legacy names like packed offloading Move the main code body of paged stash to transformer/moe/ --- megatron/core/full_cuda_graph.py | 12 +- .../common/model_chunk_schedule_plan.py | 12 +- megatron/core/models/gpt/gpt_model.py | 18 +- .../pipeline_parallel/moe_packed_offload.py | 1186 ----------------- megatron/core/pipeline_parallel/schedules.py | 10 +- megatron/core/transformer/moe/experts.py | 
38 +- .../core/transformer/transformer_block.py | 8 +- .../core/transformer/transformer_config.py | 8 +- megatron/training/arguments.py | 8 +- 9 files changed, 57 insertions(+), 1243 deletions(-) delete mode 100644 megatron/core/pipeline_parallel/moe_packed_offload.py diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 430a965b477..40fec15e67b 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -7,8 +7,8 @@ import torch from megatron.core.tensor_parallel.random import get_all_rng_states -from megatron.core.pipeline_parallel.moe_packed_offload import ( - packed_moe_expert_offloading_reset, +from megatron.core.transformer.moe.paged_stash import ( + paged_stash_reset, ) logger = logging.getLogger(__name__) @@ -101,11 +101,11 @@ class FullCudaGraphWrapper: cuda_graph = {'training': None, 'validation': None} result = {'training': None, 'validation': None} - def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, packed_moe_expert_offloading=False): + def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, moe_paged_stash=False): self.forward_backward_func = forward_backward_func self.static_loader = StaticBufferLoader() self.cuda_graph_warmup_steps = cuda_graph_warmup_steps - self.packed_moe_expert_offloading = packed_moe_expert_offloading + self.moe_paged_stash = moe_paged_stash def data_read(self, data_iterator, model, training, num_microbatches): """Read all microbatch inputs from Dataloader and copy to static buffers.""" @@ -188,7 +188,7 @@ def __call__(self, *args, **kwargs): if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: - packed_moe_expert_offloading_reset(enabled=self.packed_moe_expert_offloading and training) + paged_stash_reset(enabled=self.moe_paged_stash and training) FullCudaGraphWrapper.cuda_graph[training_str].replay() self.speculative_cuda_graph_check(model) self.next_iter(training_str) @@ -196,7 +196,7 @@ def __call__(self, *args, **kwargs): def speculative_cuda_graph_check(self, model): ''' check speculative execution modules ''' - if self.packed_moe_expert_offloading: + if self.moe_paged_stash: # Check if there is any overflow in the receiving buffer over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') for model_chunk in model: diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index eef232bb30b..a745ffe2294 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -8,8 +8,8 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp8_utils import get_fp8_context -from megatron.core.pipeline_parallel.moe_packed_offload import ( - packed_moe_expert_offloading_set_last_layer, +from megatron.core.transformer.moe.paged_stash import ( + paged_stash_set_last_layer, ) from megatron.core.pipeline_parallel.utils import ( AbstractSchedulePlan, @@ -482,8 +482,8 @@ def run( f_layer = f_schedule_plan.get_layer(i) b_layer = b_schedule_plan.pop_layer() torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b") - if f_layer.layer.config.packed_moe_expert_offloading: - packed_moe_expert_offloading_set_last_layer(i == f_num_layers - 1) + if f_layer.layer.config.moe_paged_stash: + paged_stash_set_last_layer(i == f_num_layers - 1) f_input, b_grad = TransformerLayerSchedulePlan.run( f_layer, b_layer, @@ 
-510,8 +510,8 @@ def run( for i in range(overlapped_layers, f_num_layers): f_layer = f_schedule_plan.get_layer(i) torch.cuda.nvtx.range_push(f"layer_{i}f") - if f_layer.layer.config.packed_moe_expert_offloading: - packed_moe_expert_offloading_set_last_layer(i == f_num_layers - 1) + if f_layer.layer.config.moe_paged_stash: + paged_stash_set_last_layer(i == f_num_layers - 1) f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input) torch.cuda.nvtx.range_pop() diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index dfbd69f7e19..2a829a24929 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -21,8 +21,8 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.moe_packed_offload import ( - packed_moe_expert_offloading_init_chunk_handler, +from megatron.core.transformer.moe.paged_stash import ( + paged_stash_init_chunk_handler, ) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.quantization.utils import get_quant_config_or_none @@ -476,9 +476,9 @@ def preprocess_for_fine_grained_offloading(self): off_interface.mark_not_offloadable(param) self.disable_param_offloading = False - def preprocess_for_packed_moe_expert_offloading(self): - """Preprocess for packed moe expert offloading.""" - return packed_moe_expert_offloading_init_chunk_handler( + def preprocess_for_paged_stash(self): + """Preprocess for paged stash.""" + return paged_stash_init_chunk_handler( vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage, ) @@ -515,8 +515,8 @@ def forward( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() - if self.config.packed_moe_expert_offloading: - self.preprocess_for_packed_moe_expert_offloading() + if self.config.moe_paged_stash: + self.preprocess_for_paged_stash() inference_context = deprecate_inference_params(inference_context, inference_params) @@ -759,8 +759,8 @@ def build_schedule_plan( if self.config.fine_grained_activation_offloading: self.preprocess_for_fine_grained_offloading() - if self.config.packed_moe_expert_offloading: - self.preprocess_for_packed_moe_expert_offloading() + if self.config.moe_paged_stash: + self.preprocess_for_paged_stash() from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan diff --git a/megatron/core/pipeline_parallel/moe_packed_offload.py b/megatron/core/pipeline_parallel/moe_packed_offload.py deleted file mode 100644 index 49b692b1a16..00000000000 --- a/megatron/core/pipeline_parallel/moe_packed_offload.py +++ /dev/null @@ -1,1186 +0,0 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- -import warnings -from collections import deque -from contextlib import nullcontext -from typing import Any -import os -import torch -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -try: - import triton - import triton.language as tl - HAVE_TRITON = True -except ImportError: - HAVE_TRITON = False - -def set_ideal_affinity_for_current_gpu(): - """Set CPU affinity for the current GPU to optimize host-device transfers.""" - import uuid - - try: - import cuda.bindings.driver as cuda_driver - import cuda.bindings.runtime as cuda_runtime - except ImportError: - try: - import cuda.cuda as cuda_driver - import cuda.cudart as cuda_runtime - except ImportError: - warnings.warn("cuda-python may not be installed, skipping GPU affinity setting") - return - try: - import pynvml - except ImportError: - warnings.warn("pynvml is not installed, skipping GPU affinity setting") - return - - # Get current CUDA device ID - err, device_id = cuda_runtime.cudaGetDevice() - assert err == cuda_runtime.cudaError_t.cudaSuccess - # Get device UUID - err, device_uuid = cuda_driver.cuDeviceGetUuid(device_id) - assert err == cuda_driver.CUresult.CUDA_SUCCESS - # Set CPU affinity based on GPU's NUMA node - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes))) - pynvml.nvmlDeviceSetCpuAffinity(handle) - -GLOBAL_BLOCK_SIZE = 1024 - -@triton.jit -def _stash_copy_kernel_2d( - src_ptr, - dst_ptr, - num_tokens_ptr, # Number of tokens to copy - alloc_offset_ptr, # In tokens (read-only) - free_offset_ptr, # In tokens (read-only) - capacity_ptr, # In tokens (read-only) - overflow_ptr, - new_free_offset_ptr, # Output: new free_offset value (written by kernel) - HIDDEN_SIZE: tl.constexpr, # Hidden dimension (compile-time constant) - BLOCK_SIZE: tl.constexpr, # Threads per block (for hidden dimension) -): - """2D Triton kernel to copy tensor data to stash buffer. - - Grid: (num_blocks,) - fixed number of blocks - Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. - Works directly with contiguous 2D tensors [tokens, hidden_size]. - Offsets are tracked in tokens, not elements. - """ - pid = tl.program_id(axis=0) - num_blocks = tl.num_programs(axis=0) - - # Load parameters (in tokens, not elements) - num_tokens = tl.load(num_tokens_ptr) - alloc_offset = tl.load(alloc_offset_ptr) - free_offset = tl.load(free_offset_ptr) - capacity = tl.load(capacity_ptr) - - # All blocks check for overflow (same computation, avoids race condition) - if free_offset >= alloc_offset: - # No wraparound: available space is from free_offset to capacity, then 0 to alloc_offset - avail_space = capacity - (free_offset - alloc_offset) - else: - # Wraparound: available space is from free_offset to alloc_offset - avail_space = alloc_offset - free_offset - overflow_detected = avail_space < num_tokens - - # Only block 0 writes the overflow flag - if pid == 0 and overflow_detected: - tl.store(overflow_ptr, 1) - - # All blocks return early if overflow detected - if overflow_detected: - return - - # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] 
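# --- Illustrative sketch, not part of the patch: the ring-buffer bookkeeping that the
# stash copy kernel implements. Offsets are counted in tokens; alloc_offset is the read
# head, free_offset the write head, and writes wrap modulo capacity.
def ring_buffer_can_push(alloc_offset: int, free_offset: int, capacity: int, num_tokens: int) -> bool:
    if free_offset >= alloc_offset:
        # no wraparound: free space runs from free_offset to capacity, then 0 to alloc_offset
        avail = capacity - (free_offset - alloc_offset)
    else:
        # wraparound: free space runs from free_offset up to alloc_offset
        avail = alloc_offset - free_offset
    return avail >= num_tokens   # mirrors the kernel's overflow_detected check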
- token_idx = pid - while token_idx < num_tokens: - # Calculate destination token index with wraparound - dst_token_idx = (free_offset + token_idx) % capacity - - # Each thread handles elements of the hidden dimension - elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE - - # Check if we need masking (only if HIDDEN_SIZE not divisible by BLOCK_SIZE) - need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 - num_iters = elements_per_thread + (1 if need_mask else 0) - - # 2D indexing: base + token_idx * HIDDEN_SIZE + hidden_offsets - src_base = src_ptr + token_idx * HIDDEN_SIZE - dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE - - if need_mask: - # Use mask for all iterations when HIDDEN_SIZE not divisible by BLOCK_SIZE - for iter in range(num_iters): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - hidden_mask = hidden_offsets < HIDDEN_SIZE - data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) - tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) - else: - # No mask needed - HIDDEN_SIZE is multiple of BLOCK_SIZE - for iter in range(elements_per_thread): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - data = tl.load(src_base + hidden_offsets) - tl.store(dst_base + hidden_offsets, data) - - # Stride to next token for this block - token_idx += num_blocks - - # Update new_free_offset (only first block writes it) - if pid == 0: - new_free_offset = (free_offset + num_tokens) % capacity - tl.store(new_free_offset_ptr, new_free_offset) - -@triton.jit -def _stash_pop_kernel_2d( - src_ptr, - dst_ptr, - num_tokens_ptr, # Number of tokens to reload - tensor_offset_ptr, # In tokens - where data was stashed (read-only) - alloc_offset_ptr, # In tokens (read-only, not used in pop) - free_offset_ptr, # In tokens (write: updated directly by kernel) - capacity_ptr, # In tokens (read-only) - HIDDEN_SIZE: tl.constexpr, # Hidden dimension (compile-time constant) - BLOCK_SIZE: tl.constexpr, # Threads per block (for hidden dimension) -): - """2D Triton kernel to reload tensor data from stash buffer. - - Grid: (num_blocks,) - fixed number of blocks - Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. - Works directly with contiguous 2D tensors [tokens, hidden_size]. - Offsets are tracked in tokens, not elements. - Uses LIFO (stack) semantics - moves free_offset backward after popping. - """ - pid = tl.program_id(axis=0) - num_blocks = tl.num_programs(axis=0) - - # Load parameters (in tokens, not elements) - num_tokens = tl.load(num_tokens_ptr) - tensor_offset = tl.load(tensor_offset_ptr) # Where data was stashed - capacity = tl.load(capacity_ptr) - - # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] 
- token_idx = pid - while token_idx < num_tokens: - # Calculate source token index with wraparound - src_token_idx = (tensor_offset + token_idx) % capacity - - # Each thread handles elements of the hidden dimension - elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE - - # Check if we need masking - need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 - num_iters = elements_per_thread + (1 if need_mask else 0) - - # 2D indexing - src_base = src_ptr + src_token_idx * HIDDEN_SIZE - dst_base = dst_ptr + token_idx * HIDDEN_SIZE - - if need_mask: - # Use mask for all iterations when HIDDEN_SIZE not divisible by BLOCK_SIZE - for iter in range(num_iters): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - hidden_mask = hidden_offsets < HIDDEN_SIZE - data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) - tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) - else: - # No mask needed - for iter in range(elements_per_thread): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - data = tl.load(src_base + hidden_offsets) - tl.store(dst_base + hidden_offsets, data) - - # Stride to next token for this block - token_idx += num_blocks - - # For LIFO (stack) behavior: move free_offset backward - # After popping, free_offset should be at tensor_offset (freeing the space we just read) - if pid == 0: - # The data was stashed at tensor_offset, so after popping, free_offset moves back to tensor_offset - tl.store(free_offset_ptr, tensor_offset) - - -class StashBuffer: - """ - A class to represent a 2D stash buffer. - - The buffer is organized as [num_tokens, hidden_size]. - Offsets (free_offset, alloc_offset) are tracked in tokens, not elements. - """ - - def __init__(self, num_tokens, hidden_size, device, overflow, dtype): - """ - Args: - num_tokens: Maximum number of tokens the buffer can hold - hidden_size: Hidden dimension size - device: Device for the buffer - overflow: Overflow flag tensor (shared across all buffers) - dtype: Data type - """ - self.buffer = None - self.hidden_size = hidden_size - self.num_tokens_capacity = num_tokens - - # Create 2D buffer [num_tokens, hidden_size] - if os.getenv('PACKED_OFFLOAD_CPU', '0') == '1': - self.buffer = torch.empty((num_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True) - else: - self.buffer = torch.empty((num_tokens, hidden_size), dtype=dtype, device=device) - - self.overflow = overflow # GPU flag (shared) - self.device = device - - # Offsets are in TOKENS - self.free_offset = torch.zeros(1, dtype=torch.int64, device=device) # tail (write pointer) - self.alloc_offset = torch.zeros(1, dtype=torch.int64, device=device) # head (read pointer) - self.capacity = torch.zeros(1, dtype=torch.int64, device=device) - self.capacity.fill_(num_tokens) # Capacity in tokens - self.dtype = dtype - - def reset(self): - """Reset the stash buffer offsets.""" - self.free_offset.zero_() - self.alloc_offset.zero_() - - def __repr__(self): - return f"StashBuffer(capacity={self.num_tokens_capacity} tokens, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" - - -class PagedStashBuffer: - """ - A paged stash buffer with page-level memory management. - - The buffer is organized as [num_pages, page_size, hidden_size]. - Uses a free list (circular buffer) to track available pages. 
- """ - - def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype): - """ - Args: - num_tokens: Maximum number of tokens the buffer can hold - hidden_size: Hidden dimension size - page_size: Number of tokens per page - device: Device for the buffer - overflow: Overflow flag tensor (shared across all buffers) - dtype: Data type - """ - self.hidden_size = hidden_size - self.page_size = page_size - self.num_pages = (num_tokens + page_size - 1) // page_size # Ceiling division - self.total_tokens = self.num_pages * page_size - - # Create 2D buffer [total_tokens, hidden_size] - # Organized as pages: [page_0_tokens, page_1_tokens, ...] - if os.getenv('PACKED_OFFLOAD_CPU', '0') == '1': - self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True) - else: - self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device=device) - - self.overflow = overflow # GPU flag (shared) - self.device = device - self.dtype = dtype - - # Free list as circular buffer: stores available page IDs - self.free_list = torch.arange(self.num_pages, dtype=torch.int64, device=device) - - # Head and tail pointers for free_list circular buffer - self.free_list_head = torch.zeros(1, dtype=torch.int64, device=device) # Read pointer (allocation) - self.free_list_tail = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) # Write pointer (deallocation) - - # Capacity of free list - self.free_list_capacity = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) - - def reset(self): - """Reset the paged buffer - reinitialize free list.""" - self.free_list.copy_(torch.arange(self.num_pages, dtype=torch.int64, device=self.device)) - self.free_list_head.zero_() - self.free_list_tail.fill_(self.num_pages) - - def __repr__(self): - return f"PagedStashBuffer(num_pages={self.num_pages}, page_size={self.page_size}, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" - - -@triton.jit -def _paged_stash_copy_kernel( - src_ptr, - dst_ptr, - num_tokens_ptr, - free_list_ptr, - free_list_head_ptr, # Read-only: current head position - free_list_tail_ptr, # Read-only: current tail position (for overflow check) - free_list_capacity_ptr, - page_record_ptr, # Output: records which pages were used - overflow_ptr, - new_free_list_head_ptr, # Output: new head position (written by kernel) - PAGE_SIZE: tl.constexpr, - HIDDEN_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """Triton kernel to copy tokens to paged stash buffer. - - Allocates pages from free list (reads from head, advances head). - Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. - Grid: (num_blocks,) where blocks process tokens in a strided pattern. - Writes new head to temporary tensor to avoid race conditions. 
- """ - pid = tl.program_id(axis=0) - num_blocks = tl.num_programs(axis=0) - - # Load parameters - num_tokens = tl.load(num_tokens_ptr) - free_list_head = tl.load(free_list_head_ptr) - free_list_tail = tl.load(free_list_tail_ptr) - free_list_capacity = tl.load(free_list_capacity_ptr) - - # Check available pages (unwrapped indices: simple subtraction, no modulo needed) - avail_pages = free_list_tail - free_list_head - - # Calculate required pages - required_pages = tl.cdiv(num_tokens, PAGE_SIZE) - overflow_detected = avail_pages < required_pages - - # Only block 0 writes overflow flag - if pid == 0 and overflow_detected: - tl.store(overflow_ptr, 1) - - # All blocks return early if overflow - if overflow_detected: - return - - # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] - token_idx = pid - while token_idx < num_tokens: - # Determine which page this token belongs to - page_slot = token_idx // PAGE_SIZE - token_in_page = token_idx % PAGE_SIZE - - # Read page ID from free list (with wraparound) - free_list_idx = (free_list_head + page_slot) % free_list_capacity - page_id = tl.load(free_list_ptr + free_list_idx) - - # First token in page: record the page ID (only if this block handles token 0 of the page) - if token_in_page == 0: - tl.store(page_record_ptr + page_slot, page_id) - - # Calculate destination address in paged buffer - dst_token_idx = page_id * PAGE_SIZE + token_in_page - - # Copy token data (2D: hidden dimension) - elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE - need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 - num_iters = elements_per_thread + (1 if need_mask else 0) - - src_base = src_ptr + token_idx * HIDDEN_SIZE - dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE - - if need_mask: - for iter in range(num_iters): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - hidden_mask = hidden_offsets < HIDDEN_SIZE - data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) - tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) - else: - for iter in range(elements_per_thread): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - data = tl.load(src_base + hidden_offsets) - tl.store(dst_base + hidden_offsets, data) - - # Stride to next token for this block - token_idx += num_blocks - - # Calculate and store new free list head (only block 0) - # We consumed pages, so advance head forward (unwrapped: no modulo) - # Write to temporary tensor to avoid race conditions - if pid == 0: - new_head = free_list_head + required_pages - tl.store(new_free_list_head_ptr, new_head) - - -@triton.jit -def _paged_stash_pop_kernel( - src_ptr, - dst_ptr, - num_tokens_ptr, - page_record_ptr, # Input: which pages to read - free_list_ptr, - free_list_head_ptr, # Read-only: current head position (not used) - free_list_tail_ptr, # Read-only: current tail position - free_list_capacity_ptr, - new_free_list_tail_ptr, # Output: new tail position (written by kernel) - PAGE_SIZE: tl.constexpr, - HIDDEN_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """Triton kernel to reload tokens from paged stash buffer. - - Returns pages to free list (writes to tail, advances tail). - Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. - Grid: (num_blocks,) where blocks process tokens in a strided pattern. - Writes new tail to temporary tensor to avoid race conditions. 
- """ - pid = tl.program_id(axis=0) - num_blocks = tl.num_programs(axis=0) - - # Load parameters - num_tokens = tl.load(num_tokens_ptr) - free_list_tail = tl.load(free_list_tail_ptr) - free_list_capacity = tl.load(free_list_capacity_ptr) - - # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] - token_idx = pid - while token_idx < num_tokens: - # Determine which page this token belongs to - page_slot = token_idx // PAGE_SIZE - token_in_page = token_idx % PAGE_SIZE - - # Read page ID from page record - page_id = tl.load(page_record_ptr + page_slot) - - # Calculate source address in paged buffer - src_token_idx = page_id * PAGE_SIZE + token_in_page - - # Copy token data (2D: hidden dimension) - elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE - need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 - num_iters = elements_per_thread + (1 if need_mask else 0) - - src_base = src_ptr + src_token_idx * HIDDEN_SIZE - dst_base = dst_ptr + token_idx * HIDDEN_SIZE - - if need_mask: - for iter in range(num_iters): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - hidden_mask = hidden_offsets < HIDDEN_SIZE - data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) - tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) - else: - for iter in range(elements_per_thread): - hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE - data = tl.load(src_base + hidden_offsets) - tl.store(dst_base + hidden_offsets, data) - - # First token in page: release page back to free list - if token_in_page == 0: - # Write page ID back to free list at tail position (with wraparound) - write_idx = (free_list_tail + page_slot) % free_list_capacity - tl.store(free_list_ptr + write_idx, page_id) - - # Stride to next token for this block - token_idx += num_blocks - - # Calculate and store new free list tail (only block 0) - # We returned pages, so advance tail forward (unwrapped: no modulo) - # Write to temporary tensor to avoid race conditions - if pid == 0: - required_pages = tl.cdiv(num_tokens, PAGE_SIZE) - new_tail = free_list_tail + required_pages - tl.store(new_free_list_tail_ptr, new_tail) - - -class PagedTensor: - """ - A paged tensor that stores data in pages within a paged stash buffer. 
- """ - - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0): - """ - Args: - tensor: The tensor to store - num_tokens_tensor: Scalar tensor containing actual number of tokens - vp_stage: Virtual pipeline stage - layer_name: Name of the layer - max_tokens: Maximum number of tokens - page_size: Number of tokens per page - num_d2d_pages: Number of pages to copy using native PyTorch (rest uses Triton) - """ - self._tensor = tensor - self._original_tensor = None - assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1 - self.num_tokens_tensor = num_tokens_tensor.clone() - self.vp_stage = vp_stage - self.schedule_layer_no = schedule_layer_no - self.layer_name = layer_name - self.max_tokens = max_tokens - self.page_size = page_size - self.num_d2d_pages = num_d2d_pages - - # Original tensor information - self.original_shape = list(tensor.shape) - self.max_num_tokens = self.original_shape[0] - self.element_size = tensor.element_size() - self.hidden_size = self.original_shape[1] - self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype - self.device = tensor.device - - # Calculate number of pages needed - self.max_num_pages = (self.max_num_tokens + page_size - 1) // page_size # Ceiling division - - # Page record: stores which pages are being used for this tensor - self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) - - # Static tensor for D2D pages (allocate upfront if needed) - d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) - if d2d_tokens > 0: - self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device) - else: - self.static_tensor = None - - @property - def schedule_layer(self): - """Get the schedule layer.""" - return self.schedule_layer_no - - def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): - """Offload the paged tensor to paged stash buffer. - - Args: - paged_stash_buffer: The paged stash buffer to offload to - max_blocks: Maximum number of blocks for Triton kernel - """ - if not HAVE_TRITON: - raise RuntimeError("Triton is required for PagedTensor.offload_to_stash(). 
Please install triton.") - - self._tensor = self._tensor.contiguous() - if self.num_tokens_tensor.dim() == 0: - self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) - - # Get 2D tensor - if isinstance(self._tensor, MXFP8Tensor): - tensor_to_copy = self._tensor._columnwise_data - else: - tensor_to_copy = self._tensor - - # Split tensor into two parts: D2D portion and Triton portion - # Use max_num_tokens for consistent size across iterations - d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) - triton_tokens = self.max_num_tokens - d2d_tokens - - # Perform both D2D copy and Triton kernel together - # Part 1: Copy first d2d_tokens to static_tensor using native PyTorch - if d2d_tokens > 0: - self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens] - # Part 2: Copy remaining tokens using Triton kernel - if triton_tokens > 0: - triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens] - # Use actual num_tokens for the kernel (how many tokens to actually copy) - triton_num_tokens = self.num_tokens_tensor - d2d_tokens - - # Determine grid size - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(triton_tokens, max_blocks) - grid = (num_blocks,) - - # Create temporary tensor for new head - new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch paged stash copy kernel - _paged_stash_copy_kernel[grid]( - triton_tensor, - paged_stash_buffer.buffer, - triton_num_tokens, - paged_stash_buffer.free_list, - paged_stash_buffer.free_list_head, - paged_stash_buffer.free_list_tail, - paged_stash_buffer.free_list_capacity, - self.page_record, # Triton kernel will populate page_record - paged_stash_buffer.overflow, - new_free_list_head, - PAGE_SIZE=self.page_size, - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Update free list head - paged_stash_buffer.free_list_head.copy_(new_free_list_head) - - # Save reference to original tensor - self._original_tensor = self._tensor - self._tensor = None - - def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): - """Reload the paged tensor from paged stash buffer. - - Args: - paged_stash_buffer: The paged stash buffer to reload from - max_blocks: Maximum number of blocks for Triton kernel - """ - if not HAVE_TRITON: - raise RuntimeError("Triton is required for PagedTensor.reload_from_stash(). 
Please install triton.") - - # Allocate output tensor - if isinstance(self._original_tensor, MXFP8Tensor): - columnwise_data = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) - self._tensor = MXFP8Tensor( - shape=self._original_tensor.shape, - dtype=self._original_tensor.dtype, - fp8_dtype=self._original_tensor._fp8_dtype, - rowwise_data=self._original_tensor._rowwise_data, - rowwise_scale_inv=self._original_tensor._rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=self._original_tensor._columnwise_scale_inv, - quantizer=self._original_tensor._quantizer, - ) - tensor_to_reload = self._tensor._columnwise_data - else: - self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) - tensor_to_reload = self._tensor - - # Split tensor into two parts: D2D portion and Triton portion - # Use max_num_tokens for consistency with offload - d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) - triton_tokens = self.max_num_tokens - d2d_tokens - - # Perform both D2D copy and Triton kernel together - # Part 1: Copy first d2d_tokens from static_tensor using native PyTorch - if d2d_tokens > 0 and self.static_tensor is not None: - tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens] - - # Part 2: Copy remaining tokens using Triton kernel - if triton_tokens > 0: - triton_tensor = tensor_to_reload[d2d_tokens:self.max_num_tokens] - # Use actual num_tokens for the kernel (how many tokens to actually copy) - triton_num_tokens = self.num_tokens_tensor - d2d_tokens - - # Determine grid size - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(triton_tokens, max_blocks) - grid = (num_blocks,) - - # Create temporary tensor for new tail - new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch paged stash pop kernel - _paged_stash_pop_kernel[grid]( - paged_stash_buffer.buffer, - triton_tensor, - triton_num_tokens, - self.page_record, # Triton kernel will read from page_record - paged_stash_buffer.free_list, - paged_stash_buffer.free_list_head, - paged_stash_buffer.free_list_tail, - paged_stash_buffer.free_list_capacity, - new_free_list_tail, - PAGE_SIZE=self.page_size, - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Update free list tail - paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) - - -class PP_PreScheduleFunction(torch.autograd.Function): - """ - This function is used to update the pp schedule. - """ - - @staticmethod - def forward(ctx, tensor, offload_manager): # after forward - # pylint: disable=missing-function-docstring - ctx.offload_manager = offload_manager - # Wait for offload to complete before starting the next layer - offload_manager.wait_for_offload_to_complete() - return tensor - - @staticmethod - def backward(ctx, *grad_output): # before backward - # pylint: disable=missing-function-docstring - # Initiate reload for next layer - if ctx.offload_manager.status == 'captured' and ctx.offload_manager.current_schedule_index < len(ctx.offload_manager._pp_schedule): - next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index] - if next_schedule_layer < 0: - ctx.offload_manager.reload_packed_tensors(-next_schedule_layer) - - return grad_output + (None, None) - -class PP_PostScheduleFunction(torch.autograd.Function): - """ - This function is used to update the pp schedule. 
- """ - - @staticmethod - def forward(ctx, tensor, offload_manager): # after forward - # pylint: disable=missing-function-docstring - - ctx.offload_manager = offload_manager - ctx.vp_stage = offload_manager.current_vp_stage - if ctx.vp_stage is None: - ctx.vp_stage = 0 - ctx.layer_no, ctx.microbatch_no = offload_manager.update_pp_schedule(ctx.vp_stage+1) - - # Initiate offload for current layer and reload for next layer - if offload_manager.status == 'captured': - current_schedule_layer = offload_manager.get_schedule_layer(ctx.vp_stage+1, ctx.layer_no, ctx.microbatch_no) - next_schedule_layer = ctx.offload_manager._pp_schedule[ctx.offload_manager.current_schedule_index+1] - if current_schedule_layer != -next_schedule_layer: - # Start offload for current layer - ctx.offload_manager.offload_packed_tensors(current_schedule_layer) - if next_schedule_layer < 0: - # reload for next backward layer - ctx.offload_manager.reload_packed_tensors(-next_schedule_layer, no_wait=True) - else: - ctx.offload_manager.remove_packed_tensor_from_offload() - - ctx.offload_manager.current_schedule_index += 1 - # return the identical tensor - return tensor - - @staticmethod - def backward(ctx, *grad_output): # before backward - # pylint: disable=missing-function-docstring - if ctx.vp_stage is not None: - ctx.offload_manager.update_pp_schedule(-(ctx.vp_stage+1), -ctx.layer_no, -ctx.microbatch_no) - ctx.offload_manager.current_schedule_index += 1 - current_stream = torch.cuda.current_stream() - - ctx.offload_manager.wait_for_offload_to_complete() - if ctx.offload_manager._unpack_stream_status == 'reloading': - current_stream.wait_stream(ctx.offload_manager.unpack_stream) - ctx.offload_manager._unpack_stream_status = 'idle' - - return grad_output + (None, None) - -class PackedOffloadManager: - """ - Singleton manager for coordinating activation offloading across pipeline stages. 
- Manages chunk handlers, synchronizes GPU-GPU transfers, - and handles virtual pipeline parallelism - """ - - OFFLOAD_MGR = None - - @classmethod - def get_instance(cls): - """Get the singleton instance of PipelineOffloadManager.""" - if cls.OFFLOAD_MGR is None: - cls.OFFLOAD_MGR = PackedOffloadManager() - return cls.OFFLOAD_MGR - - def __init__(self): - """Initialize the manager with queues and dedicated CUDA streams.""" - # allocate streams and events for synchronization - self.enabled = False - self._pack_stream = torch.cuda.Stream() - # Currently paged/packed stashing is not stream-safe, so use the same stream for packing - # and unpacking - self._unpack_stream = self._pack_stream - self._pack_stream_status = 'idle' # idle, offloading - self._unpack_stream_status = 'idle' # idle, reloading - self.packed_tensors_to_offload = [] - self.packed_tensors_offload_in_progress = [] - self.packed_tensors_to_reload = {} - - self.iteration = 0 - self._current_layer_name = None - self.vp_size = None - self.current_vp_stage = None - self._last_layer = False - self.status = 'begin' # begin, capture, captured - self._pp_schedule = None # If element is +ve, it denotes forward pass of vp stage, if -ve, it denotes backward pass of vp stage - self.current_layer = None - self.current_microbatch = None - self.current_schedule_index = None - - # Track max tokens needed per vp_stage, dtype, and hidden_size - self.max_tokens_per_vp_stage = None - self.temp_tokens_per_vp_stage = None - # Track max tokens needed across all vp_stages grouped by dtype and hidden_size - self.max_tokens_across_vp_stages = None - self.temp_tokens_across_vp_stages = None - - self.num_tokens_tensor = None - self.max_num_tokens = None - self.stash_buffers = None - self.overflow = None - self.device = None - - # Page size for paged memory management - self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page - self.use_paged_stash = os.getenv('USE_PAGED_STASH', '0') == '1' # Enable via env var - - # Number of pages to copy using native PyTorch (D2D) - self.num_d2d_pages = int(os.getenv('NUM_D2D_PAGES', '0')) # Default 0 (all Triton) - - @property - def pack_stream(self): - """Get the pack stream.""" - return self._pack_stream - - @property - def unpack_stream(self): - """Get the unpack stream.""" - return self._unpack_stream - - def set_current_layer_name(self, name): - """Set the current layer name.""" - self._current_layer_name = name - - def get_schedule_layer(self, vp_stage, layer_no, microbatch_no): - """Get the schedule layer.""" - return vp_stage*1000000 + layer_no*1000 + microbatch_no - - def add_packed_tensor_to_offload(self, packed_tensor): - """Add a packed tensor to the offload list.""" - if self.status == 'captured': - self.packed_tensors_to_offload.append(packed_tensor) - else: - pass - - def remove_packed_tensor_from_offload(self): - """Remove all packed tensors from the offload list.""" - if self.status == 'captured': - while len(self.packed_tensors_to_offload) > 0: - packed_tensor = self.packed_tensors_to_offload.pop(0) - assert len(self.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {self.packed_tensors_to_offload}" - else: - pass - - def offload_packed_tensors(self, pp_schedule_layer): - """Offload the packed tensors.""" - current_stream = torch.cuda.current_stream() - self.pack_stream.wait_stream(current_stream) - - with torch.cuda.stream(self.pack_stream): - if self.status == 'captured': - self._pack_stream_status = 'offloading' - if pp_schedule_layer not 
in self.packed_tensors_to_reload: - self.packed_tensors_to_reload[pp_schedule_layer] = [] - assert len(self.packed_tensors_to_reload[pp_schedule_layer]) == 0, f"packed_tensors_to_reload {pp_schedule_layer} is not empty {self.packed_tensors_to_reload[pp_schedule_layer]}" - - while len(self.packed_tensors_to_offload) > 0: - packed_tensor = self.packed_tensors_to_offload.pop(0) - stash_buffers_vp_stage = self.stash_buffers[packed_tensor.vp_stage] if not self.use_paged_stash else self.stash_buffers[0] - stash_buffer = stash_buffers_vp_stage[packed_tensor.dtype][packed_tensor.hidden_size] - packed_tensor.offload_to_stash(stash_buffer) - self.packed_tensors_to_reload[pp_schedule_layer].append(packed_tensor) - self.packed_tensors_offload_in_progress.append(packed_tensor) - else: - pass - assert len(self.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {self.packed_tensors_to_offload}" - - def wait_for_offload_to_complete(self): - """Wait for offload to complete.""" - current_stream = torch.cuda.current_stream() - if self._pack_stream_status == 'offloading': - current_stream.wait_stream(self.pack_stream) - self._pack_stream_status = 'idle' - - # Deallocate original tensor after offload is complete - while len(self.packed_tensors_offload_in_progress) > 0: - packed_tensor = self.packed_tensors_offload_in_progress.pop(0) - if isinstance(packed_tensor._original_tensor, MXFP8Tensor): - packed_tensor._original_tensor._columnwise_data = None - else: - packed_tensor._original_tensor = None - - def reload_packed_tensors(self, pp_schedule_layer, no_wait=False): - """Reload the packed tensors.""" - # Avoid waiting for main stream if reload is immediately after offload since offload is already waiting for main stream - if not no_wait or self.unpack_stream != self.pack_stream: - current_stream = torch.cuda.current_stream() - self.unpack_stream.wait_stream(current_stream) - - with torch.cuda.stream(self.unpack_stream): - if self.status == 'captured': - self._unpack_stream_status = 'reloading' - count = 0 - for item in self.packed_tensors_to_reload: - if len(self.packed_tensors_to_reload[item]) > 0: - count += 1 - - while len(self.packed_tensors_to_reload[pp_schedule_layer]) > 0: - packed_tensor = self.packed_tensors_to_reload[pp_schedule_layer].pop(0) - stash_buffers_vp_stage = self.stash_buffers[packed_tensor.vp_stage] if not self.use_paged_stash else self.stash_buffers[0] - stash_buffer = stash_buffers_vp_stage[packed_tensor.dtype][packed_tensor.hidden_size] - packed_tensor.reload_from_stash(stash_buffer) - else: - pass - assert len(self.packed_tensors_to_reload[pp_schedule_layer]) == 0, f"packed_tensors_to_reload {pp_schedule_layer} is not empty {self.packed_tensors_to_reload[pp_schedule_layer]}" - - - def allocate_offload_buffers(self, stash_buffer_size_factor=1.10): - """Allocate offload buffers for each vp stage, organized by [vp_stage][dtype][hidden_size].""" - self.stash_buffers = [] - self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) - - if self.use_paged_stash: - self.stash_buffers.append({}) - for dtype, hidden_size in self.max_tokens_across_vp_stages: - if dtype not in self.stash_buffers[0]: - self.stash_buffers[0][dtype] = {} - assert hidden_size not in self.stash_buffers[0][dtype] - num_tokens = int(self.max_tokens_across_vp_stages[dtype, hidden_size] * stash_buffer_size_factor) - self.stash_buffers[0][dtype][hidden_size] = PagedStashBuffer( - num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype - ) - if 
torch.distributed.get_rank() == 0: - print(f'allocated paged stash buffer dtype={dtype} hidden_size={hidden_size}: {self.stash_buffers[0][dtype][hidden_size]}') - return - # Regular stash buffers - for vp_stage in range(self.vp_size): - self.stash_buffers.append({}) - for dtype in self.max_tokens_per_vp_stage[vp_stage]: - self.stash_buffers[vp_stage][dtype] = {} - for hidden_size in self.max_tokens_per_vp_stage[vp_stage][dtype]: - # Calculate number of tokens we can store (with safety factor) - num_tokens = int(self.max_tokens_per_vp_stage[vp_stage][dtype][hidden_size] * stash_buffer_size_factor) - - # Create buffer (regular) - self.stash_buffers[vp_stage][dtype][hidden_size] = StashBuffer( - num_tokens, hidden_size, self.device, self.overflow, dtype - ) - - if torch.distributed.get_rank() == 0: - buffer_type = "paged" if self.use_paged_stash else "regular" - print(f'allocated {buffer_type} stash buffer vp_stage={vp_stage} dtype={dtype} hidden_size={hidden_size}: {self.stash_buffers[vp_stage][dtype][hidden_size]}') - - def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): - """Update the pp schedule.""" - if self._pp_schedule is None: - self._pp_schedule = [] - # current layer and microbatch for each vp stage for forward pass - self.current_layer = [1 for _ in range(self.vp_size)] - self.current_microbatch = [1 for _ in range(self.vp_size)] - - assert self.vp_size is not None - if layer_no is None: - # forward pass - layer_no = self.current_layer[vp_stage-1] - self.current_layer[vp_stage-1] += 1 - microbatch_no = self.current_microbatch[vp_stage-1] - if self._last_layer: - self.current_layer[vp_stage-1] = 1 - self.current_microbatch[vp_stage-1] += 1 - - if self.status == 'capture': - self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) - num_tokens = self.num_tokens_tensor.item() - - assert self._pp_schedule[self.current_schedule_index] == self.get_schedule_layer(vp_stage, layer_no, microbatch_no), f"schedule {self._pp_schedule[self.current_schedule_index]} != {self.get_schedule_layer(vp_stage, layer_no, microbatch_no)}" - - - return layer_no, microbatch_no - #self._pp_schedule.append(vp_size) - #self._pp_schedule.append(vp_stage) - - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - """ - Hook called when autograd saves a tensor for backward pass. - Returns a tag to identify the tensor later. 
- """ - - # Handle 0-dim tensors (torch.Size([])) - they have no size(0) - if self.max_num_tokens is None or tensor.dim() == 0 or tensor.size(0) != self.max_num_tokens: - return tensor.detach() - if isinstance(tensor, MXFP8Tensor): - assert tensor._rowwise_data is None, f"rowwise_data is not None; Only columnwise data is supported for packed offloading" - - if self.status == 'capture': - - self.num_tokens = self.num_tokens_tensor.item() - - dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype - # Get hidden_size from tensor shape - if isinstance(tensor, MXFP8Tensor): - hidden_size = tensor._columnwise_data.shape[1] if tensor._columnwise_data.ndim > 1 else tensor._columnwise_data.numel() - else: - hidden_size = tensor.shape[1] if tensor.ndim > 1 else tensor.numel() - - if dtype not in self.temp_tokens_per_vp_stage[self.current_vp_stage]: - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} - if hidden_size not in self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype]: - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 - - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] += self.num_tokens - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = max( - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] - ) - if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: - self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 - self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 - - self.temp_tokens_across_vp_stages[dtype, hidden_size] += self.num_tokens - self.max_tokens_across_vp_stages[dtype, hidden_size] = max( - self.max_tokens_across_vp_stages[dtype, hidden_size], - self.temp_tokens_across_vp_stages[dtype, hidden_size] - ) - # Since capture stage does not use CUDA graph, we can truncate the saved tensor to actual num_tokens - # Truncate the tensor to the actual number of tokens - new_size = (self.num_tokens, *tensor.shape[1:]) - - if isinstance(tensor, MXFP8Tensor): - tensor_truncated = torch.empty(new_size, dtype=tensor._columnwise_data.dtype, device=tensor.device) - tensor_truncated.copy_(tensor._columnwise_data[:self.num_tokens, ...]) - tensor._columnwise_data = tensor_truncated - else: - tensor_truncated = torch.empty(new_size, dtype=tensor.dtype, device=tensor.device) - tensor_truncated.copy_(tensor[:self.num_tokens, ...]) - tensor = tensor_truncated - - # Create tensor (paged or regular based on configuration) - assert self.use_paged_stash, "Paged stashing must be used." - if self.use_paged_stash: - packed_tensor = PagedTensor( - tensor, - num_tokens_tensor=self.num_tokens_tensor, - vp_stage=self.current_vp_stage, - schedule_layer_no=self._pp_schedule[self.current_schedule_index] if self._pp_schedule is not None and self.current_schedule_index < len(self._pp_schedule) else None, - layer_name=self._current_layer_name, - max_tokens=self.max_num_tokens, - page_size=self.page_size, - num_d2d_pages=self.num_d2d_pages - ) - - if self.status == 'captured': - self.add_packed_tensor_to_offload(packed_tensor) - return packed_tensor - - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - """ - Hook called when autograd retrieves a saved tensor during backward pass. 
- Returns the actual tensor (potentially reloading from CPU). - """ - if isinstance(saved_state, (PagedTensor)): - if self.status == 'capture': - num_tokens = saved_state.num_tokens_tensor.item() - self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][saved_state.hidden_size] -= num_tokens - self.temp_tokens_across_vp_stages[saved_state.dtype, saved_state.hidden_size] -= num_tokens - # Pad the tensor to the max number of tokens - npad = self.max_num_tokens - num_tokens - pad = () - for _ in range(saved_state._tensor.ndim-1): - pad = pad + (0, 0) - pad = pad + (0, npad) - if isinstance(saved_state._tensor, MXFP8Tensor): - saved_state._tensor._columnwise_data = torch.nn.functional.pad(saved_state._tensor._columnwise_data, pad) - else: - saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad) - - assert saved_state._tensor is not None, f"saved_state._tensor is None {saved_state._tensor}" - return saved_state._tensor - - return saved_state - -class PackedOffloadContext: - """Wrapper context manager that adds custom enter/exit behavior around saved_tensors_hooks.""" - - def __init__(self, offload_manager): - self.offload_manager = offload_manager - self.saved_tensors_context = torch.autograd.graph.saved_tensors_hooks( - offload_manager.on_save_for_backward, - offload_manager.on_get_saved_tensor - ) - - def __enter__(self): - from megatron.core.extensions.transformer_engine import cpu_offload - - if cpu_offload is not None: - cpu_offload.CPUOffloadEnabled = True - # Call the underlying context manager's __enter__ - result = self.saved_tensors_context.__enter__() - - # Add more custom logic after entering if needed - return result - - def __exit__(self, *args: Any): - # Call the underlying context manager's __exit__ - result = self.saved_tensors_context.__exit__(*args) - from megatron.core.extensions.transformer_engine import cpu_offload - - if cpu_offload is not None: - cpu_offload.CPUOffloadEnabled = False - return result - -def packed_moe_expert_offloading_group_start(tensor, name=None): - """Mark the start of a layer group and prepare for offload/reload.""" - rank = torch.distributed.get_rank() - offload_manager = PackedOffloadManager.get_instance() - if not offload_manager.enabled: - return tensor - return PP_PreScheduleFunction.apply(tensor, offload_manager) - -def get_packed_moe_expert_offloading_context(name=None, max_num_tokens=None, num_tokens_tensor=None): - """Get the fine-grained offload context""" - offload_manager = PackedOffloadManager.get_instance() - if not offload_manager.enabled: - return nullcontext() - offload_manager.max_num_tokens = max_num_tokens - assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) - offload_manager.num_tokens_tensor = num_tokens_tensor - offload_manager.set_current_layer_name(name) if name is not None else None - pack_unpack_context = PackedOffloadContext(offload_manager) - return pack_unpack_context - -def packed_moe_expert_offloading_group_commit(tensor, name=None): - """Mark the end of a layer group and prepare for offload/reload.""" - rank = torch.distributed.get_rank() - offload_manager = PackedOffloadManager.get_instance() - offload_manager.device = tensor.device - if not offload_manager.enabled: - return tensor - return PP_PostScheduleFunction.apply(tensor, offload_manager) - -def packed_moe_expert_offloading_init_chunk_handler(vp_size, vp_stage): - """Initialize the chunk handler, called at the start of a microbatch forward pass.""" - offload_manager = 
PackedOffloadManager.get_instance() - if not offload_manager.enabled: - return - offload_manager.current_vp_stage = vp_stage if vp_stage is not None else 0 - if vp_size is not None: - offload_manager.vp_size = vp_size - else: - offload_manager.vp_size = 1 - if offload_manager.max_tokens_per_vp_stage is None: - offload_manager.max_tokens_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] - offload_manager.temp_tokens_per_vp_stage = [{} for _ in range(offload_manager.vp_size)] - if offload_manager.max_tokens_across_vp_stages is None: - offload_manager.max_tokens_across_vp_stages = {} - offload_manager.temp_tokens_across_vp_stages = {} - -def packed_moe_expert_offloading_set_last_layer(is_last_layer=False): - """Set the last layer flag.""" - offload_manager = PackedOffloadManager.get_instance() - if not offload_manager.enabled: - return - offload_manager._last_layer = is_last_layer - -def packed_moe_expert_offloading_reset(enabled=True): - """Reset the chunk handler, called at the start of a training iteration.""" - offload_manager = PackedOffloadManager.get_instance() - offload_manager.enabled = enabled - offload_manager.iteration += 1 - # current layer and microbatch for each vp stage for forward pass - offload_manager.current_schedule_index = 0 - - if not enabled: - return - - set_ideal_affinity_for_current_gpu() # Set the ideal affinity for the current GPU - if offload_manager.status == 'begin': - offload_manager.status = 'capture' - elif offload_manager.status == 'capture': - offload_manager.status = 'captured' - stash_buffer_size_factor = float(os.getenv('STASH_BUFFER_SIZE_FACTOR', '1.10')) - offload_manager.allocate_offload_buffers(stash_buffer_size_factor=stash_buffer_size_factor) - elif offload_manager.status == 'captured': - pass - - if offload_manager.status == 'captured': - if not torch.cuda.is_current_stream_capturing(): - overflow = offload_manager.overflow.item() - assert overflow == 0, f"PackedOffloadManager overflow!!!" 
- - for vp_buffers in offload_manager.stash_buffers: - for dtype in vp_buffers.keys(): - for hidden_size in vp_buffers[dtype].keys(): - vp_buffers[dtype][hidden_size].reset() - offload_manager.overflow.zero_() - offload_manager.current_layer = [1 for _ in range(offload_manager.vp_size)] - offload_manager.current_microbatch = [1 for _ in range(offload_manager.vp_size)] - assert len(offload_manager.packed_tensors_to_offload) == 0, f"packed_tensors_to_offload is not empty {offload_manager.packed_tensors_to_offload}" - assert len(offload_manager.packed_tensors_offload_in_progress) == 0, f"packed_tensors_offload_in_progress is not empty {offload_manager.packed_tensors_offload_in_progress}" - diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 75824d4f3a7..5d99fa20a8a 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -12,8 +12,8 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.moe_packed_offload import ( - packed_moe_expert_offloading_reset, +from megatron.core.transformer.moe.paged_stash import ( + paged_stash_reset, ) from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( @@ -593,7 +593,7 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading and not forward_only) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only) no_sync_func = config.no_sync_func if no_sync_func is None: @@ -1054,7 +1054,7 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" - packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading and not forward_only) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only) if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -2238,7 +2238,7 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - packed_moe_expert_offloading_reset(enabled=config.packed_moe_expert_offloading and not forward_only) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only) # Disable async grad reductions no_sync_func = config.no_sync_func diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 25b43f2c3db..fab6d1039d2 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -33,11 +33,11 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.moe_packed_offload import ( - packed_moe_expert_offloading_group_start, - get_packed_moe_expert_offloading_context, - packed_moe_expert_offloading_reset, - packed_moe_expert_offloading_group_commit, +from megatron.core.transformer.moe.paged_stash import ( + paged_stash_group_start, + get_paged_stash_context, + paged_stash_reset, + paged_stash_group_commit, ) from 
megatron.core.tensor_parallel.layers import ( _initialize_affine_weight_cpu, @@ -705,9 +705,9 @@ def __init__( and "moe_act" in self.config.offload_modules ) - self.packed_offload_expert_fc1 = self.config.packed_moe_expert_offloading and "expert_fc1" in self.config.offload_modules - self.packed_offload_moe_act = self.config.packed_moe_expert_offloading and "moe_act" in self.config.offload_modules - self.packed_offload_expert_fc2 = self.config.packed_moe_expert_offloading and "expert_fc2" in self.config.offload_modules + self.moe_paged_stash_expert_fc1 = self.config.moe_paged_stash and "expert_fc1" in self.config.offload_modules + self.moe_paged_stash_moe_act = self.config.moe_paged_stash and "moe_act" in self.config.offload_modules + self.moe_paged_stash_expert_fc2 = self.config.moe_paged_stash and "expert_fc2" in self.config.offload_modules self.activation_recompute = ( self.config.recompute_granularity == 'selective' @@ -1012,10 +1012,10 @@ def forward( with off_interface( self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" ) as permuted_local_hidden_states: - if self.config.packed_moe_expert_offloading: - permuted_local_hidden_states = packed_moe_expert_offloading_group_start(permuted_local_hidden_states, name="expert_fc1") - if self.packed_offload_expert_fc1: - offload_context = get_packed_moe_expert_offloading_context(name="expert_fc1", max_num_tokens=permuted_local_hidden_states.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + if self.config.moe_paged_stash: + permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states, name="expert_fc1") + if self.moe_paged_stash_expert_fc1: + offload_context = get_paged_stash_context(name="expert_fc1", max_num_tokens=permuted_local_hidden_states.shape[0], num_tokens_tensor=tokens_per_expert.sum()) else: offload_context = nullcontext() with offload_context: @@ -1069,7 +1069,7 @@ def remove_glu_interleaving(x: torch.Tensor) -> torch.Tensor: bias_parallel, permuted_probs, self.config.activation_func_fp8_input_store, - tokens_per_expert.sum() if self.packed_offload_moe_act else None, + tokens_per_expert.sum() if self.moe_paged_stash_moe_act else None, ) elif self.activation_func == quick_gelu and self.config.gated_linear_unit: @@ -1122,8 +1122,8 @@ def glu(x): ) else: with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: - if self.packed_offload_moe_act: - offload_context = get_packed_moe_expert_offloading_context(name="moe_act", max_num_tokens=fc1_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + if self.moe_paged_stash_moe_act: + offload_context = get_paged_stash_context(name="moe_act", max_num_tokens=fc1_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) else: offload_context = nullcontext() with offload_context: @@ -1133,14 +1133,14 @@ def glu(x): bias_act_output, name="moe_act", forced_released_tensors=[fc1_output] ) - if self.packed_offload_expert_fc2: - offload_context = get_packed_moe_expert_offloading_context(name="expert_fc2", max_num_tokens=bias_act_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + if self.moe_paged_stash_expert_fc2: + offload_context = get_paged_stash_context(name="expert_fc2", max_num_tokens=bias_act_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) else: offload_context = nullcontext() with offload_context: output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert) - if self.config.packed_moe_expert_offloading: - output = packed_moe_expert_offloading_group_commit(output, 
name="expert_fc2") + if self.config.moe_paged_stash: + output = paged_stash_group_commit(output, name="expert_fc2") if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 9fb40446126..411a73e4681 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -17,8 +17,8 @@ from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.pipeline_parallel.moe_packed_offload import ( - packed_moe_expert_offloading_set_last_layer, +from megatron.core.transformer.moe.paged_stash import ( + paged_stash_set_last_layer, ) from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection @@ -895,8 +895,8 @@ def forward( mhc_manager.is_last_layer_in_recompute_block = ( mhc_is_last_in_recompute_block[l_no] ) - if self.config.packed_moe_expert_offloading: - packed_moe_expert_offloading_set_last_layer( + if self.config.moe_paged_stash: + paged_stash_set_last_layer( is_last_layer = (l_no == self.num_layers_per_pipeline_rank - 1) ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 6195c905a9a..803619519f2 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -993,8 +993,8 @@ class TransformerConfig(ModelParallelConfig): min_offloaded_tensor_size: int = 1024 * 1024 """The minimum size of the tensor to be offloaded.""" - packed_moe_expert_offloading: bool = False - """If True, enable packed moe expert offloading.""" + moe_paged_stash: bool = False + """If True, enable paged stash for MoE expert activations.""" def __post_init__(self): @@ -1451,10 +1451,10 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - if self.packed_moe_expert_offloading: + if self.moe_paged_stash: assert ( not self.cpu_offloading and not self.fine_grained_activation_offloading - ), "packed_moe_expert_offloading cannot be enabled with cpu_offloading." + ), "paged_stash cannot be enabled with cpu_offloading." assert self.offload_modules is not None and len(self.offload_modules) > 0 allowed_modules = {"expert_fc1", "expert_fc2", "moe_act"} invalid_modules = set(self.offload_modules) - allowed_modules diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 8ab4c7e7785..7cbbdfe4011 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1409,12 +1409,12 @@ def validate_args(args, defaults={}): if is_te_min_version("2.10.0"): assert os.getenv("NVTE_CPU_OFFLOAD_V1", "0") == "1", \ "For fine-grained activation offloading with TE >= 2.10.0, NVTE_CPU_OFFLOAD_V1 should be set to 1 to avoid offloading weights." 
-        assert not args.packed_moe_expert_offloading, "Fine-grained activation offloading and packed moe expert offloading cannot be enabled at the same time"
+        assert not args.moe_paged_stash, "Fine-grained activation offloading and paged stash cannot be enabled at the same time"

-    if args.packed_moe_expert_offloading:
+    if args.moe_paged_stash:
         assert args.transformer_impl == 'transformer_engine', \
-            "Packed moe expert offloading is only supported with transformer_engine implementation"
-        assert not args.fine_grained_activation_offloading, "Packed moe expert offloading and fine-grained activation offloading cannot be enabled at the same time"
+            "Paged stash is only supported with transformer_engine implementation"
+        assert not args.fine_grained_activation_offloading, "Paged stash and fine-grained activation offloading cannot be enabled at the same time"

     if args.mtp_num_layers:
         assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)."

From 922689a4825be8b2e5870df4d2622a2ed86a2ddd Mon Sep 17 00:00:00 2001
From: Nan Zheng
Date: Thu, 18 Dec 2025 15:58:55 +0800
Subject: [PATCH 29/57] Check in paged_stash.py that was omitted in the previous commit

---
 megatron/core/transformer/moe/paged_stash.py | 917 +++++++++++++++++++
 1 file changed, 917 insertions(+)
 create mode 100644 megatron/core/transformer/moe/paged_stash.py

diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py
new file mode 100644
index 00000000000..289d9b34879
--- /dev/null
+++ b/megatron/core/transformer/moe/paged_stash.py
@@ -0,0 +1,917 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import warnings
+from contextlib import nullcontext
+from typing import Any
+import os
+import torch
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
+import triton
+import triton.language as tl
+
+GLOBAL_BLOCK_SIZE = 1024
+
+class PagedStashBuffer:
+    """
+    A paged stash buffer with page-level memory management.
+
+    The buffer is organized as [num_pages, page_size, hidden_size].
+    Uses a free list (circular buffer) to track available pages.
+    """
+
+    def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype):
+        """
+        Args:
+            num_tokens: Maximum number of tokens the buffer can hold
+            hidden_size: Hidden dimension size
+            page_size: Number of tokens per page
+            device: Device for the buffer
+            overflow: Overflow flag tensor (shared across all buffers)
+            dtype: Data type
+        """
+        self.hidden_size = hidden_size
+        self.page_size = page_size
+        self.num_pages = (num_tokens + page_size - 1) // page_size  # Ceiling division
+        self.total_tokens = self.num_pages * page_size
+
+        # Create 2D buffer [total_tokens, hidden_size]
+        # Organized as pages: [page_0_tokens, page_1_tokens, ...]
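+        # Example: with page_size=64, page p occupies rows [p*64, (p+1)*64) of this
+        # buffer, and a tensor holding 150 valid tokens consumes ceil(150/64) = 3
+        # pages taken from the free list.
+        # PAGED_STASH_TO_CPU=1 keeps the buffer in pinned host memory; otherwise it
+        # is allocated on `device`.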
+ if os.getenv('PAGED_STASH_TO_CPU', '0') == '1': + self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True) + else: + self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device=device) + + self.overflow = overflow # GPU flag (shared) + self.device = device + self.dtype = dtype + + # Free list as circular buffer: stores available page IDs + self.free_list = torch.arange(self.num_pages, dtype=torch.int64, device=device) + + # Head and tail pointers for free_list circular buffer + self.free_list_head = torch.zeros(1, dtype=torch.int64, device=device) # Read pointer (allocation) + self.free_list_tail = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) # Write pointer (deallocation) + + # Capacity of free list + self.free_list_capacity = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) + + def reset(self): + """Reset the paged buffer - reinitialize free list.""" + self.free_list.copy_(torch.arange(self.num_pages, dtype=torch.int64, device=self.device)) + self.free_list_head.zero_() + self.free_list_tail.fill_(self.num_pages) + + def __repr__(self): + return f"PagedStashBuffer(num_pages={self.num_pages}, page_size={self.page_size}, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" + + +@triton.jit +def _paged_stash_copy_kernel( + src_ptr, + dst_ptr, + num_tokens_ptr, + free_list_ptr, + free_list_head_ptr, # Read-only: current head position + free_list_tail_ptr, # Read-only: current tail position (for overflow check) + free_list_capacity_ptr, + page_record_ptr, # Output: records which pages were used + overflow_ptr, + new_free_list_head_ptr, # Output: new head position (written by kernel) + PAGE_SIZE: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel to copy tokens to paged stash buffer. + + Allocates pages from free list (reads from head, advances head). + Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. + Grid: (num_blocks,) where blocks process tokens in a strided pattern. + Writes new head to temporary tensor to avoid race conditions. + """ + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load parameters + num_tokens = tl.load(num_tokens_ptr) + free_list_head = tl.load(free_list_head_ptr) + free_list_tail = tl.load(free_list_tail_ptr) + free_list_capacity = tl.load(free_list_capacity_ptr) + + # Check available pages (unwrapped indices: simple subtraction, no modulo needed) + avail_pages = free_list_tail - free_list_head + + # Calculate required pages + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + overflow_detected = avail_pages < required_pages + + # Only block 0 writes overflow flag + if pid == 0 and overflow_detected: + tl.store(overflow_ptr, 1) + + # All blocks return early if overflow + if overflow_detected: + return + + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] 
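+    # Example: with num_blocks = 4 and num_tokens = 10, block 0 copies tokens
+    # 0, 4, 8 and block 1 copies tokens 1, 5, 9. Token t falls in page slot
+    # t // PAGE_SIZE, whose page ID is read from
+    # free_list[(free_list_head + t // PAGE_SIZE) % free_list_capacity] and
+    # recorded once per page in page_record so the pop kernel can find it again.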
+ token_idx = pid + while token_idx < num_tokens: + # Determine which page this token belongs to + page_slot = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + + # Read page ID from free list (with wraparound) + free_list_idx = (free_list_head + page_slot) % free_list_capacity + page_id = tl.load(free_list_ptr + free_list_idx) + + # First token in page: record the page ID (only if this block handles token 0 of the page) + if token_in_page == 0: + tl.store(page_record_ptr + page_slot, page_id) + + # Calculate destination address in paged buffer + dst_token_idx = page_id * PAGE_SIZE + token_in_page + + # Copy token data (2D: hidden dimension) + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + + src_base = src_ptr + token_idx * HIDDEN_SIZE + dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE + + if need_mask: + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + + # Stride to next token for this block + token_idx += num_blocks + + # Calculate and store new free list head (only block 0) + # We consumed pages, so advance head forward (unwrapped: no modulo) + # Write to temporary tensor to avoid race conditions + if pid == 0: + new_head = free_list_head + required_pages + tl.store(new_free_list_head_ptr, new_head) + + +@triton.jit +def _paged_stash_pop_kernel( + src_ptr, + dst_ptr, + num_tokens_ptr, + page_record_ptr, # Input: which pages to read + free_list_ptr, + free_list_head_ptr, # Read-only: current head position (not used) + free_list_tail_ptr, # Read-only: current tail position + free_list_capacity_ptr, + new_free_list_tail_ptr, # Output: new tail position (written by kernel) + PAGE_SIZE: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel to reload tokens from paged stash buffer. + + Returns pages to free list (writes to tail, advances tail). + Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. + Grid: (num_blocks,) where blocks process tokens in a strided pattern. + Writes new tail to temporary tensor to avoid race conditions. + """ + pid = tl.program_id(axis=0) + num_blocks = tl.num_programs(axis=0) + + # Load parameters + num_tokens = tl.load(num_tokens_ptr) + free_list_tail = tl.load(free_list_tail_ptr) + free_list_capacity = tl.load(free_list_capacity_ptr) + + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] 
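+    # Pages are located via the page_record captured at stash time, so free_list_head is not consulted here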
+ token_idx = pid + while token_idx < num_tokens: + # Determine which page this token belongs to + page_slot = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + + # Read page ID from page record + page_id = tl.load(page_record_ptr + page_slot) + + # Calculate source address in paged buffer + src_token_idx = page_id * PAGE_SIZE + token_in_page + + # Copy token data (2D: hidden dimension) + elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE + need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 + num_iters = elements_per_thread + (1 if need_mask else 0) + + src_base = src_ptr + src_token_idx * HIDDEN_SIZE + dst_base = dst_ptr + token_idx * HIDDEN_SIZE + + if need_mask: + for iter in range(num_iters): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + hidden_mask = hidden_offsets < HIDDEN_SIZE + data = tl.load(src_base + hidden_offsets, mask=hidden_mask, other=0) + tl.store(dst_base + hidden_offsets, data, mask=hidden_mask) + else: + for iter in range(elements_per_thread): + hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE + data = tl.load(src_base + hidden_offsets) + tl.store(dst_base + hidden_offsets, data) + + # First token in page: release page back to free list + if token_in_page == 0: + # Write page ID back to free list at tail position (with wraparound) + write_idx = (free_list_tail + page_slot) % free_list_capacity + tl.store(free_list_ptr + write_idx, page_id) + + # Stride to next token for this block + token_idx += num_blocks + + # Calculate and store new free list tail (only block 0) + # We returned pages, so advance tail forward (unwrapped: no modulo) + # Write to temporary tensor to avoid race conditions + if pid == 0: + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + new_tail = free_list_tail + required_pages + tl.store(new_free_list_tail_ptr, new_tail) + + +class PagedTensor: + """ + A paged tensor that stores data in pages within a paged stash buffer. 
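+    Offloading copies the data into fixed-size pages of a shared PagedStashBuffer and records the
+    page IDs in page_record; reloading allocates a fresh tensor and copies those pages back.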
+ """ + + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0): + """ + Args: + tensor: The tensor to store + num_tokens_tensor: Scalar tensor containing actual number of tokens + vp_stage: Virtual pipeline stage + layer_name: Name of the layer + max_tokens: Maximum number of tokens + page_size: Number of tokens per page + num_d2d_pages: Number of pages to copy using native PyTorch (rest uses Triton) + """ + self._tensor = tensor + self._original_tensor = None + assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1 + self.num_tokens_tensor = num_tokens_tensor.clone() + self.vp_stage = vp_stage + self.schedule_layer_no = schedule_layer_no + self.layer_name = layer_name + self.max_tokens = max_tokens + self.page_size = page_size + self.num_d2d_pages = num_d2d_pages + + # Original tensor information + self.original_shape = list(tensor.shape) + self.max_num_tokens = self.original_shape[0] + self.element_size = tensor.element_size() + self.hidden_size = self.original_shape[1] + self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype + self.device = tensor.device + + # Calculate number of pages needed + self.max_num_pages = (self.max_num_tokens + page_size - 1) // page_size # Ceiling division + + # Page record: stores which pages are being used for this tensor + self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) + + # Static tensor for D2D pages (allocate upfront if needed) + d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) + if d2d_tokens > 0: + self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device) + else: + self.static_tensor = None + + @property + def schedule_layer(self): + """Get the schedule layer.""" + return self.schedule_layer_no + + def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): + """Offload the paged tensor to paged stash buffer. 
+ + Args: + paged_stash_buffer: The paged stash buffer to offload to + max_blocks: Maximum number of blocks for Triton kernel + """ + self._tensor = self._tensor.contiguous() + if self.num_tokens_tensor.dim() == 0: + self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) + + # Get 2D tensor + if isinstance(self._tensor, MXFP8Tensor): + tensor_to_copy = self._tensor._columnwise_data + else: + tensor_to_copy = self._tensor + + # Split tensor into two parts: D2D portion and Triton portion + # Use max_num_tokens for consistent size across iterations + d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) + triton_tokens = self.max_num_tokens - d2d_tokens + + # Perform both D2D copy and Triton kernel together + # Part 1: Copy first d2d_tokens to static_tensor using native PyTorch + if d2d_tokens > 0: + self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens] + # Part 2: Copy remaining tokens using Triton kernel + if triton_tokens > 0: + triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens] + # Use actual num_tokens for the kernel (how many tokens to actually copy) + triton_num_tokens = self.num_tokens_tensor - d2d_tokens + + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(triton_tokens, max_blocks) + grid = (num_blocks,) + + # Create temporary tensor for new head + new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash copy kernel + _paged_stash_copy_kernel[grid]( + triton_tensor, + paged_stash_buffer.buffer, + triton_num_tokens, + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + self.page_record, # Triton kernel will populate page_record + paged_stash_buffer.overflow, + new_free_list_head, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Update free list head + paged_stash_buffer.free_list_head.copy_(new_free_list_head) + + # Save reference to original tensor + self._original_tensor = self._tensor + self._tensor = None + + def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): + """Reload the paged tensor from paged stash buffer. 
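+        A new tensor of original_shape is allocated and filled from the recorded pages; the pages
+        are then returned to the stash buffer's free list.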
+ + Args: + paged_stash_buffer: The paged stash buffer to reload from + max_blocks: Maximum number of blocks for Triton kernel + """ + # Allocate output tensor + if isinstance(self._original_tensor, MXFP8Tensor): + columnwise_data = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) + self._tensor = MXFP8Tensor( + shape=self._original_tensor.shape, + dtype=self._original_tensor.dtype, + fp8_dtype=self._original_tensor._fp8_dtype, + rowwise_data=self._original_tensor._rowwise_data, + rowwise_scale_inv=self._original_tensor._rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=self._original_tensor._columnwise_scale_inv, + quantizer=self._original_tensor._quantizer, + ) + tensor_to_reload = self._tensor._columnwise_data + else: + self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) + tensor_to_reload = self._tensor + + # Split tensor into two parts: D2D portion and Triton portion + # Use max_num_tokens for consistency with stash + d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) + triton_tokens = self.max_num_tokens - d2d_tokens + + # Perform both D2D copy and Triton kernel together + # Part 1: Copy first d2d_tokens from static_tensor using native PyTorch + if d2d_tokens > 0 and self.static_tensor is not None: + tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens] + + # Part 2: Copy remaining tokens using Triton kernel + if triton_tokens > 0: + triton_tensor = tensor_to_reload[d2d_tokens:self.max_num_tokens] + # Use actual num_tokens for the kernel (how many tokens to actually copy) + triton_num_tokens = self.num_tokens_tensor - d2d_tokens + + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(triton_tokens, max_blocks) + grid = (num_blocks,) + + # Create temporary tensor for new tail + new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash pop kernel + _paged_stash_pop_kernel[grid]( + paged_stash_buffer.buffer, + triton_tensor, + triton_num_tokens, + self.page_record, # Triton kernel will read from page_record + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + new_free_list_tail, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Update free list tail + paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) + + +class PP_PreScheduleFunction(torch.autograd.Function): + """ + This function is used to update the pp schedule. + """ + + @staticmethod + def forward(ctx, tensor, stash_manager): # after forward + # pylint: disable=missing-function-docstring + ctx.stash_manager = stash_manager + # Wait for stash to complete before starting the next layer + stash_manager.wait_for_stash_to_complete() + return tensor + + @staticmethod + def backward(ctx, *grad_output): # before backward + # pylint: disable=missing-function-docstring + # Initiate reload for next layer + if ctx.stash_manager.status == 'captured' and ctx.stash_manager.current_schedule_index < len(ctx.stash_manager._pp_schedule): + next_schedule_layer = ctx.stash_manager._pp_schedule[ctx.stash_manager.current_schedule_index] + if next_schedule_layer < 0: + ctx.stash_manager.reload_paged_tensors(-next_schedule_layer) + + return grad_output + (None, None) + +class PP_PostScheduleFunction(torch.autograd.Function): + """ + This function is used to update the pp schedule. 
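+    The forward hook advances the recorded schedule and, unless the same layer's backward is next,
+    starts stashing the tensors saved by that layer; the backward hook mirrors the schedule update
+    and waits for any in-flight stash or reload to finish.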
+ """ + + @staticmethod + def forward(ctx, tensor, stash_manager): # after forward + # pylint: disable=missing-function-docstring + + ctx.stash_manager = stash_manager + ctx.vp_stage = stash_manager.current_vp_stage + if ctx.vp_stage is None: + ctx.vp_stage = 0 + ctx.layer_no, ctx.microbatch_no = stash_manager.update_pp_schedule(ctx.vp_stage+1) + + # Initiate stash for current layer and reload for next layer + if stash_manager.status == 'captured': + current_schedule_layer = stash_manager.get_schedule_layer(ctx.vp_stage+1, ctx.layer_no, ctx.microbatch_no) + next_schedule_layer = ctx.stash_manager._pp_schedule[ctx.stash_manager.current_schedule_index+1] + if current_schedule_layer != -next_schedule_layer: + # Start stash for current layer + ctx.stash_manager.stash_paged_tensors(current_schedule_layer) + if next_schedule_layer < 0: + # reload for next backward layer + ctx.stash_manager.reload_paged_tensors(-next_schedule_layer, no_wait=True) + else: + ctx.stash_manager.remove_paged_tensor_from_stash() + + ctx.stash_manager.current_schedule_index += 1 + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, *grad_output): # before backward + # pylint: disable=missing-function-docstring + if ctx.vp_stage is not None: + ctx.stash_manager.update_pp_schedule(-(ctx.vp_stage+1), -ctx.layer_no, -ctx.microbatch_no) + ctx.stash_manager.current_schedule_index += 1 + current_stream = torch.cuda.current_stream() + + ctx.stash_manager.wait_for_stash_to_complete() + if ctx.stash_manager._unpack_stream_status == 'reloading': + current_stream.wait_stream(ctx.stash_manager.unpack_stream) + ctx.stash_manager._unpack_stream_status = 'idle' + + return grad_output + (None, None) + +class PagedStashManager: + """ + Singleton manager for coordinating paged stashing across pipeline stages. 
+ Manages chunk handlers, synchronizes GPU-GPU transfers, + and handles virtual pipeline parallelism + """ + + STASH_MGR = None + + @classmethod + def get_instance(cls): + """Get the singleton instance of PagedStashManager.""" + if cls.STASH_MGR is None: + cls.STASH_MGR = PagedStashManager() + return cls.STASH_MGR + + def __init__(self): + """Initialize the manager with queues and dedicated CUDA streams.""" + # allocate streams and events for synchronization + self.enabled = False + self._pack_stream = torch.cuda.Stream() + # Currently paged stashing is not stream-safe, so use the same stream for packing + # and unpacking + self._unpack_stream = self._pack_stream + self._pack_stream_status = 'idle' # idle, stashing + self._unpack_stream_status = 'idle' # idle, reloading + self.paged_tensors_to_stash = [] + self.paged_tensors_stash_in_progress = [] + self.paged_tensors_to_reload = {} + + self.iteration = 0 + self._current_layer_name = None + self.vp_size = None + self.current_vp_stage = None + self._last_layer = False + self.status = 'begin' # begin, capture, captured + self._pp_schedule = None # If element is +ve, it denotes forward pass of vp stage, if -ve, it denotes backward pass of vp stage + self.current_layer = None + self.current_microbatch = None + self.current_schedule_index = None + + # Track max tokens needed per vp_stage, dtype, and hidden_size + self.max_tokens_per_vp_stage = None + self.temp_tokens_per_vp_stage = None + # Track max tokens needed across all vp_stages grouped by dtype and hidden_size + self.max_tokens_across_vp_stages = None + self.temp_tokens_across_vp_stages = None + + self.num_tokens_tensor = None + self.max_num_tokens = None + self.stash_buffers = None + self.overflow = None + self.device = None + + # Page size for paged memory management + self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page + + # Number of pages to copy using native PyTorch (D2D) + self.num_d2d_pages = int(os.getenv('NUM_D2D_PAGES', '0')) # Default 0 (all Triton) + + @property + def pack_stream(self): + """Get the pack stream.""" + return self._pack_stream + + @property + def unpack_stream(self): + """Get the unpack stream.""" + return self._unpack_stream + + def set_current_layer_name(self, name): + """Set the current layer name.""" + self._current_layer_name = name + + def get_schedule_layer(self, vp_stage, layer_no, microbatch_no): + """Get the schedule layer.""" + return vp_stage*1000000 + layer_no*1000 + microbatch_no + + def add_paged_tensor_to_stash(self, paged_tensor): + """Add a paged tensor to the stash list.""" + if self.status == 'captured': + self.paged_tensors_to_stash.append(paged_tensor) + else: + pass + + def remove_paged_tensor_from_stash(self): + """Remove all paged tensors from the stash list.""" + if self.status == 'captured': + while len(self.paged_tensors_to_stash) > 0: + paged_tensor = self.paged_tensors_to_stash.pop(0) + assert len(self.paged_tensors_to_stash) == 0, f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" + else: + pass + + def stash_paged_tensors(self, pp_schedule_layer): + """Stash the paged tensors.""" + current_stream = torch.cuda.current_stream() + self.pack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.pack_stream): + if self.status == 'captured': + self._pack_stream_status = 'stashing' + if pp_schedule_layer not in self.paged_tensors_to_reload: + self.paged_tensors_to_reload[pp_schedule_layer] = [] + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, 
f"paged_tensors_to_reload {pp_schedule_layer} is not empty {self.paged_tensors_to_reload[pp_schedule_layer]}" + + while len(self.paged_tensors_to_stash) > 0: + paged_tensor = self.paged_tensors_to_stash.pop(0) + stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] + paged_tensor.offload_to_stash(stash_buffer) + self.paged_tensors_to_reload[pp_schedule_layer].append(paged_tensor) + self.paged_tensors_stash_in_progress.append(paged_tensor) + else: + pass + assert len(self.paged_tensors_to_stash) == 0, f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" + + def wait_for_stash_to_complete(self): + """Wait for stash to complete.""" + current_stream = torch.cuda.current_stream() + if self._pack_stream_status == 'stashing': + current_stream.wait_stream(self.pack_stream) + self._pack_stream_status = 'idle' + + # Deallocate original tensor after stash is complete + while len(self.paged_tensors_stash_in_progress) > 0: + paged_tensor = self.paged_tensors_stash_in_progress.pop(0) + if isinstance(paged_tensor._original_tensor, MXFP8Tensor): + paged_tensor._original_tensor._columnwise_data = None + else: + paged_tensor._original_tensor = None + + def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): + """Reload the paged tensors.""" + # Avoid waiting for main stream if reload is immediately after stash since stash is already waiting for main stream + if not no_wait or self.unpack_stream != self.pack_stream: + current_stream = torch.cuda.current_stream() + self.unpack_stream.wait_stream(current_stream) + + with torch.cuda.stream(self.unpack_stream): + if self.status == 'captured': + self._unpack_stream_status = 'reloading' + count = 0 + for item in self.paged_tensors_to_reload: + if len(self.paged_tensors_to_reload[item]) > 0: + count += 1 + + while len(self.paged_tensors_to_reload[pp_schedule_layer]) > 0: + paged_tensor = self.paged_tensors_to_reload[pp_schedule_layer].pop(0) + stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] + paged_tensor.reload_from_stash(stash_buffer) + else: + pass + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, f"paged_tensors_to_reload {pp_schedule_layer} is not empty {self.paged_tensors_to_reload[pp_schedule_layer]}" + + + def allocate_stash_buffers(self, stash_buffer_size_factor=1.10): + """Allocate stash buffers organized by [dtype][hidden_size].""" + self.stash_buffers = {} + self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) + + for dtype, hidden_size in self.max_tokens_across_vp_stages: + if dtype not in self.stash_buffers: + self.stash_buffers[dtype] = {} + assert hidden_size not in self.stash_buffers[dtype] + num_tokens = int( + self.max_tokens_across_vp_stages[dtype, hidden_size] * stash_buffer_size_factor + ) + self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( + num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype + ) + + def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): + """Update the pp schedule.""" + if self._pp_schedule is None: + self._pp_schedule = [] + # current layer and microbatch for each vp stage for forward pass + self.current_layer = [1 for _ in range(self.vp_size)] + self.current_microbatch = [1 for _ in range(self.vp_size)] + + assert self.vp_size is not None + if layer_no is None: + # forward pass + layer_no = self.current_layer[vp_stage-1] + self.current_layer[vp_stage-1] += 1 + microbatch_no = self.current_microbatch[vp_stage-1] + if self._last_layer: + 
self.current_layer[vp_stage-1] = 1 + self.current_microbatch[vp_stage-1] += 1 + + if self.status == 'capture': + self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) + num_tokens = self.num_tokens_tensor.item() + + assert self._pp_schedule[self.current_schedule_index] == self.get_schedule_layer(vp_stage, layer_no, microbatch_no), f"schedule {self._pp_schedule[self.current_schedule_index]} != {self.get_schedule_layer(vp_stage, layer_no, microbatch_no)}" + + + return layer_no, microbatch_no + #self._pp_schedule.append(vp_size) + #self._pp_schedule.append(vp_stage) + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """ + Hook called when autograd saves a tensor for backward pass. + Returns a tag to identify the tensor later. + """ + + # Handle 0-dim tensors (torch.Size([])) - they have no size(0) + if self.max_num_tokens is None or tensor.dim() == 0 or tensor.size(0) != self.max_num_tokens: + return tensor.detach() + if isinstance(tensor, MXFP8Tensor): + assert tensor._rowwise_data is None, f"rowwise_data is not None; Only columnwise data is supported for paged stashing" + + if self.status == 'capture': + + self.num_tokens = self.num_tokens_tensor.item() + + dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype + # Get hidden_size from tensor shape + if isinstance(tensor, MXFP8Tensor): + hidden_size = tensor._columnwise_data.shape[1] if tensor._columnwise_data.ndim > 1 else tensor._columnwise_data.numel() + else: + hidden_size = tensor.shape[1] if tensor.ndim > 1 else tensor.numel() + + if dtype not in self.temp_tokens_per_vp_stage[self.current_vp_stage]: + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} + if hidden_size not in self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype]: + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 + + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] += self.num_tokens + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = max( + self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] + ) + if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: + self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 + + self.temp_tokens_across_vp_stages[dtype, hidden_size] += self.num_tokens + self.max_tokens_across_vp_stages[dtype, hidden_size] = max( + self.max_tokens_across_vp_stages[dtype, hidden_size], + self.temp_tokens_across_vp_stages[dtype, hidden_size] + ) + # Since capture stage does not use CUDA graph, we can truncate the saved tensor to actual num_tokens + # Truncate the tensor to the actual number of tokens + new_size = (self.num_tokens, *tensor.shape[1:]) + + if isinstance(tensor, MXFP8Tensor): + tensor_truncated = torch.empty(new_size, dtype=tensor._columnwise_data.dtype, device=tensor.device) + tensor_truncated.copy_(tensor._columnwise_data[:self.num_tokens, ...]) + tensor._columnwise_data = tensor_truncated + else: + tensor_truncated = torch.empty(new_size, dtype=tensor.dtype, device=tensor.device) + tensor_truncated.copy_(tensor[:self.num_tokens, ...]) + tensor = tensor_truncated + + + paged_tensor = PagedTensor( + tensor, + num_tokens_tensor=self.num_tokens_tensor, + 
vp_stage=self.current_vp_stage, + schedule_layer_no=self._pp_schedule[self.current_schedule_index] if self._pp_schedule is not None and self.current_schedule_index < len(self._pp_schedule) else None, + layer_name=self._current_layer_name, + max_tokens=self.max_num_tokens, + page_size=self.page_size, + num_d2d_pages=self.num_d2d_pages + ) + + if self.status == 'captured': + self.add_paged_tensor_to_stash(paged_tensor) + return paged_tensor + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """ + Hook called when autograd retrieves a saved tensor during backward pass. + Returns the actual tensor (potentially reloading from CPU). + """ + if isinstance(saved_state, (PagedTensor)): + if self.status == 'capture': + num_tokens = saved_state.num_tokens_tensor.item() + self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][saved_state.hidden_size] -= num_tokens + self.temp_tokens_across_vp_stages[saved_state.dtype, saved_state.hidden_size] -= num_tokens + # Pad the tensor to the max number of tokens + npad = self.max_num_tokens - num_tokens + pad = () + for _ in range(saved_state._tensor.ndim-1): + pad = pad + (0, 0) + pad = pad + (0, npad) + if isinstance(saved_state._tensor, MXFP8Tensor): + saved_state._tensor._columnwise_data = torch.nn.functional.pad(saved_state._tensor._columnwise_data, pad) + else: + saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad) + + assert saved_state._tensor is not None, f"saved_state._tensor is None {saved_state._tensor}" + return saved_state._tensor + + return saved_state + +class PagedStashContext: + """Wrapper context manager that adds custom enter/exit behavior around saved_tensors_hooks.""" + + def __init__(self, stash_manager): + self.stash_manager = stash_manager + self.saved_tensors_context = torch.autograd.graph.saved_tensors_hooks( + stash_manager.on_save_for_backward, + stash_manager.on_get_saved_tensor + ) + + def __enter__(self): + from megatron.core.extensions.transformer_engine import cpu_offload + + if cpu_offload is not None: + cpu_offload.CPUOffloadEnabled = True + # Call the underlying context manager's __enter__ + result = self.saved_tensors_context.__enter__() + + # Add more custom logic after entering if needed + return result + + def __exit__(self, *args: Any): + # Call the underlying context manager's __exit__ + result = self.saved_tensors_context.__exit__(*args) + from megatron.core.extensions.transformer_engine import cpu_offload + + if cpu_offload is not None: + cpu_offload.CPUOffloadEnabled = False + return result + +def paged_stash_group_start(tensor, name=None): + """Mark the start of a layer group and prepare for stash/reload.""" + rank = torch.distributed.get_rank() + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return tensor + return PP_PreScheduleFunction.apply(tensor, stash_manager) + +def get_paged_stash_context(name=None, max_num_tokens=None, num_tokens_tensor=None): + """Get the paged stash context""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return nullcontext() + stash_manager.max_num_tokens = max_num_tokens + assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) + stash_manager.num_tokens_tensor = num_tokens_tensor + stash_manager.set_current_layer_name(name) if name is not None else None + pack_unpack_context = PagedStashContext(stash_manager) + return pack_unpack_context + +def paged_stash_group_commit(tensor, name=None): + """Mark the end of a layer group and 
prepare for stash/reload.""" + rank = torch.distributed.get_rank() + stash_manager = PagedStashManager.get_instance() + stash_manager.device = tensor.device + if not stash_manager.enabled: + return tensor + return PP_PostScheduleFunction.apply(tensor, stash_manager) + +def paged_stash_init_chunk_handler(vp_size, vp_stage): + """Initialize the chunk handler, called at the start of a microbatch forward pass.""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return + stash_manager.current_vp_stage = vp_stage if vp_stage is not None else 0 + if vp_size is not None: + stash_manager.vp_size = vp_size + else: + stash_manager.vp_size = 1 + if stash_manager.max_tokens_per_vp_stage is None: + stash_manager.max_tokens_per_vp_stage = [{} for _ in range(stash_manager.vp_size)] + stash_manager.temp_tokens_per_vp_stage = [{} for _ in range(stash_manager.vp_size)] + if stash_manager.max_tokens_across_vp_stages is None: + stash_manager.max_tokens_across_vp_stages = {} + stash_manager.temp_tokens_across_vp_stages = {} + +def paged_stash_set_last_layer(is_last_layer=False): + """Set the last layer flag.""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled: + return + stash_manager._last_layer = is_last_layer + +def paged_stash_reset(enabled=True): + """Reset the chunk handler, called at the start of a training iteration.""" + stash_manager = PagedStashManager.get_instance() + stash_manager.enabled = enabled + stash_manager.iteration += 1 + # current layer and microbatch for each vp stage for forward pass + stash_manager.current_schedule_index = 0 + + if not enabled: + return + + if stash_manager.status == 'begin': + stash_manager.status = 'capture' + elif stash_manager.status == 'capture': + stash_manager.status = 'captured' + stash_buffer_size_factor = float(os.getenv('STASH_BUFFER_SIZE_FACTOR', '1.10')) + stash_manager.allocate_stash_buffers(stash_buffer_size_factor=stash_buffer_size_factor) + elif stash_manager.status == 'captured': + pass + + if stash_manager.status == 'captured': + if not torch.cuda.is_current_stream_capturing(): + overflow = stash_manager.overflow.item() + assert overflow == 0, f"PagedStashManager overflow!!!" 
+ + for dtype in stash_manager.stash_buffers.keys(): + for hidden_size in stash_manager.stash_buffers[dtype].keys(): + stash_manager.stash_buffers[dtype][hidden_size].reset() + stash_manager.overflow.zero_() + stash_manager.current_layer = [1 for _ in range(stash_manager.vp_size)] + stash_manager.current_microbatch = [1 for _ in range(stash_manager.vp_size)] + assert len(stash_manager.paged_tensors_to_stash) == 0, f"paged_tensors_to_stash is not empty {stash_manager.paged_tensors_to_stash}" + assert len(stash_manager.paged_tensors_stash_in_progress) == 0, f"paged_tensors_stash_in_progress is not empty {stash_manager.paged_tensors_stash_in_progress}" + From 25e1f82bc7dc3715442227ceb58cb3b99c179803 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Thu, 18 Dec 2025 16:07:50 +0800 Subject: [PATCH 30/57] Remove d2d page feature for now Remove unused triton kernel for dropping token in case overflow happens --- megatron/core/transformer/moe/moe_utils.py | 96 ----------- megatron/core/transformer/moe/paged_stash.py | 152 +++++++----------- .../core/transformer/moe/token_dispatcher.py | 1 - 3 files changed, 54 insertions(+), 195 deletions(-) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 472477661eb..dbcc25a905c 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -44,10 +44,6 @@ HAVE_TE = False -import triton -import triton.language as tl - - def switch_load_balancing_loss_func( probs: torch.Tensor, tokens_per_expert: torch.Tensor, @@ -1544,95 +1540,3 @@ def wrapped_func(moe_layer, *args, **kwargs): return wrapped_func return decorator - -@triton.jit -def _drop_routing_map_kernel( - routing_map_ptr, - over_budget_ptr, - routing_map_dropped_ptr, - num_elements: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """Triton kernel to drop routing map based on budget constraints. - - Args: - routing_map_ptr: Pointer to the input routing_map tensor - over_budget_ptr: Pointer to the boolean tensor indicating if any EP rank is over budget - routing_map_dropped_ptr: Pointer to the output routing_map tensor - num_elements: Total number of elements to process - BLOCK_SIZE: Block size for Triton kernel - """ - # Get the program ID - pid = tl.program_id(axis=0) - - # Read the over_budget value (scalar tensor with single element) - over_budget_val = tl.load(over_budget_ptr) - - # Calculate the offset for this program - offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - - # Load the routing_map values - mask = offset < num_elements - routing_map_val = tl.load(routing_map_ptr + offset, mask=mask, other=0.0) - - # If over_budget is 1 (True), output is 0 (drop); if over_budget is 0 (False), output is routing_map_val (keep) - output_val = routing_map_val * (1 - over_budget_val) - - # Store the result - tl.store(routing_map_dropped_ptr + offset, output_val, mask=mask) - - -def drop_routing_map_triton( - routing_map: torch.Tensor, - budget: torch.Tensor, - num_tokens_per_ep_rank: torch.Tensor -) -> torch.Tensor: - """Drop tokens from routing_map that exceed the budget per EP rank using Triton. - - Args: - routing_map: Tensor indicating which tokens are assigned to each expert. - budget: Integer tensor with the maximum number of tokens per EP rank. - num_tokens_per_ep_rank: Tensor with actual number of tokens per EP rank. - - Returns: - Modified routing_map with tokens exceeding budget zeroed out if any EP rank - exceeds budget, otherwise returns the original routing_map. 
- """ - - # Calculate boolean tensor: over_budget is True if ANY EP rank exceeds budget - over_budget = (num_tokens_per_ep_rank > budget).any() - - # Convert boolean to int8 - over_budget_int = over_budget.to(torch.int8) - - # Convert routing_map to numeric type if it's boolean - if routing_map.dtype == torch.bool: - routing_map_numeric = routing_map.to(torch.int8) - else: - routing_map_numeric = routing_map - - # Create output tensor with same dtype as input - routing_map_dropped = torch.empty_like(routing_map_numeric) - - # Flatten tensors for kernel processing - routing_map_flat = routing_map_numeric.flatten() - num_elements = routing_map_flat.numel() - - # Determine grid size - BLOCK_SIZE = 1024 - grid = (triton.cdiv(num_elements, BLOCK_SIZE),) - - # Launch kernel with over_budget tensor pointer (as int8) - _drop_routing_map_kernel[grid]( - routing_map_flat, - over_budget_int, - routing_map_dropped.flatten(), - num_elements, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Convert back to boolean if original was boolean - if routing_map.dtype == torch.bool: - routing_map_dropped = routing_map_dropped.to(torch.bool) - - return routing_map_dropped, over_budget.to(torch.bool) diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 289d9b34879..c13718bb003 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -247,7 +247,7 @@ class PagedTensor: A paged tensor that stores data in pages within a paged stash buffer. """ - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0): + def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64): """ Args: tensor: The tensor to store @@ -256,7 +256,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer layer_name: Name of the layer max_tokens: Maximum number of tokens page_size: Number of tokens per page - num_d2d_pages: Number of pages to copy using native PyTorch (rest uses Triton) """ self._tensor = tensor self._original_tensor = None @@ -267,7 +266,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer self.layer_name = layer_name self.max_tokens = max_tokens self.page_size = page_size - self.num_d2d_pages = num_d2d_pages # Original tensor information self.original_shape = list(tensor.shape) @@ -282,13 +280,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer # Page record: stores which pages are being used for this tensor self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) - - # Static tensor for D2D pages (allocate upfront if needed) - d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) - if d2d_tokens > 0: - self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device) - else: - self.static_tensor = None @property def schedule_layer(self): @@ -312,48 +303,33 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 else: tensor_to_copy = self._tensor - # Split tensor into two parts: D2D portion and Triton portion - # Use max_num_tokens for consistent size across iterations - d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) - triton_tokens = self.max_num_tokens - d2d_tokens - - # Perform both D2D copy and Triton kernel together 
- # Part 1: Copy first d2d_tokens to static_tensor using native PyTorch - if d2d_tokens > 0: - self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens] - # Part 2: Copy remaining tokens using Triton kernel - if triton_tokens > 0: - triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens] - # Use actual num_tokens for the kernel (how many tokens to actually copy) - triton_num_tokens = self.num_tokens_tensor - d2d_tokens - - # Determine grid size - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(triton_tokens, max_blocks) - grid = (num_blocks,) - - # Create temporary tensor for new head - new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch paged stash copy kernel - _paged_stash_copy_kernel[grid]( - triton_tensor, - paged_stash_buffer.buffer, - triton_num_tokens, - paged_stash_buffer.free_list, - paged_stash_buffer.free_list_head, - paged_stash_buffer.free_list_tail, - paged_stash_buffer.free_list_capacity, - self.page_record, # Triton kernel will populate page_record - paged_stash_buffer.overflow, - new_free_list_head, - PAGE_SIZE=self.page_size, - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Update free list head - paged_stash_buffer.free_list_head.copy_(new_free_list_head) + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(self.max_num_tokens, max_blocks) + grid = (num_blocks,) + + # Create temporary tensor for new head + new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash copy kernel + _paged_stash_copy_kernel[grid]( + tensor_to_copy, + paged_stash_buffer.buffer, + self.num_tokens_tensor, + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + self.page_record, # Triton kernel will populate page_record + paged_stash_buffer.overflow, + new_free_list_head, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Update free list head + paged_stash_buffer.free_list_head.copy_(new_free_list_head) # Save reference to original tensor self._original_tensor = self._tensor @@ -384,48 +360,32 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) tensor_to_reload = self._tensor - # Split tensor into two parts: D2D portion and Triton portion - # Use max_num_tokens for consistency with stash - d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens) - triton_tokens = self.max_num_tokens - d2d_tokens - - # Perform both D2D copy and Triton kernel together - # Part 1: Copy first d2d_tokens from static_tensor using native PyTorch - if d2d_tokens > 0 and self.static_tensor is not None: - tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens] - - # Part 2: Copy remaining tokens using Triton kernel - if triton_tokens > 0: - triton_tensor = tensor_to_reload[d2d_tokens:self.max_num_tokens] - # Use actual num_tokens for the kernel (how many tokens to actually copy) - triton_num_tokens = self.num_tokens_tensor - d2d_tokens - - # Determine grid size - BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(triton_tokens, max_blocks) - grid = (num_blocks,) - - # Create temporary tensor for new tail - new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch paged stash pop kernel - _paged_stash_pop_kernel[grid]( - paged_stash_buffer.buffer, - triton_tensor, - triton_num_tokens, - 
self.page_record, # Triton kernel will read from page_record - paged_stash_buffer.free_list, - paged_stash_buffer.free_list_head, - paged_stash_buffer.free_list_tail, - paged_stash_buffer.free_list_capacity, - new_free_list_tail, - PAGE_SIZE=self.page_size, - HIDDEN_SIZE=self.hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - # Update free list tail - paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) + # Determine grid size + BLOCK_SIZE = GLOBAL_BLOCK_SIZE + num_blocks = min(self.max_num_tokens, max_blocks) + grid = (num_blocks,) + + # Create temporary tensor for new tail + new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) + + # Launch paged stash pop kernel + _paged_stash_pop_kernel[grid]( + paged_stash_buffer.buffer, + tensor_to_reload, + self.num_tokens_tensor, + self.page_record, # Triton kernel will read from page_record + paged_stash_buffer.free_list, + paged_stash_buffer.free_list_head, + paged_stash_buffer.free_list_tail, + paged_stash_buffer.free_list_capacity, + new_free_list_tail, + PAGE_SIZE=self.page_size, + HIDDEN_SIZE=self.hidden_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Update free list tail + paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) class PP_PreScheduleFunction(torch.autograd.Function): @@ -555,9 +515,6 @@ def __init__(self): # Page size for paged memory management self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page - - # Number of pages to copy using native PyTorch (D2D) - self.num_d2d_pages = int(os.getenv('NUM_D2D_PAGES', '0')) # Default 0 (all Triton) @property def pack_stream(self): @@ -765,7 +722,6 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: layer_name=self._current_layer_name, max_tokens=self.max_num_tokens, page_size=self.page_size, - num_d2d_pages=self.num_d2d_pages ) if self.status == 'captured': diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 41a2f59e449..60da1fb47d1 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -33,7 +33,6 @@ permute, sort_chunks_by_idxs, unpermute, - drop_routing_map_triton, ) from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_config import TransformerConfig From de34d7ba3066d1a41b1fa003b282c8413cbab9d3 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Thu, 18 Dec 2025 18:38:16 +0800 Subject: [PATCH 31/57] Update added arguments and add compatibility check --- megatron/core/transformer/moe/experts.py | 11 +++-- .../core/transformer/moe/token_dispatcher.py | 10 ++--- .../core/transformer/transformer_config.py | 45 +++++++++++++++++-- megatron/training/arguments.py | 2 +- 4 files changed, 55 insertions(+), 13 deletions(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index fab6d1039d2..ec66507f2c1 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -705,9 +705,10 @@ def __init__( and "moe_act" in self.config.offload_modules ) - self.moe_paged_stash_expert_fc1 = self.config.moe_paged_stash and "expert_fc1" in self.config.offload_modules - self.moe_paged_stash_moe_act = self.config.moe_paged_stash and "moe_act" in self.config.offload_modules - self.moe_paged_stash_expert_fc2 = self.config.moe_paged_stash and "expert_fc2" in self.config.offload_modules + stash_modules = self.config.stash_modules or [] + self.moe_paged_stash_expert_fc1 
= self.config.moe_paged_stash and "expert_fc1" in stash_modules + self.moe_paged_stash_moe_act = self.config.moe_paged_stash and "moe_act" in stash_modules + self.moe_paged_stash_expert_fc2 = self.config.moe_paged_stash and "expert_fc2" in stash_modules self.activation_recompute = ( self.config.recompute_granularity == 'selective' @@ -1069,7 +1070,9 @@ def remove_glu_interleaving(x: torch.Tensor) -> torch.Tensor: bias_parallel, permuted_probs, self.config.activation_func_fp8_input_store, - tokens_per_expert.sum() if self.moe_paged_stash_moe_act else None, + tokens_per_expert.sum() + if (isinstance(tokens_per_expert, torch.Tensor) and tokens_per_expert.is_cuda) + else None, ) elif self.activation_func == quick_gelu and self.config.gated_linear_unit: diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 60da1fb47d1..b742cbbbb3f 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1023,7 +1023,7 @@ def __init__( "https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep." ) - self.packed_offloading_capacity_factor = self.config.moe_expert_capacity_factor_for_packed_offloading + self.moe_expert_rank_capacity_factor = self.config.moe_expert_rank_capacity_factor self.over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): @@ -1031,9 +1031,9 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): self.routing_map = routing_map.reshape(num_tokens, self.num_experts) self.token_probs = probs.reshape(num_tokens, self.num_experts) - if self.packed_offloading_capacity_factor is not None: + if self.moe_expert_rank_capacity_factor is not None: pad_multiple = get_align_size_for_quantization(self.config) - budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.packed_offloading_capacity_factor) + budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.moe_expert_rank_capacity_factor) budget += -budget % pad_multiple self.num_permuted_tokens = budget # Compute the capacity for each expert at the drop_and_pad mode @@ -1080,7 +1080,7 @@ def dispatch( pad_multiple=self.pad_multiple, ) ) - if self.packed_offloading_capacity_factor is not None: + if self.moe_expert_rank_capacity_factor is not None: over_budget = self.handle[8] != 0 # this is overflow_flag self.over_budget |= over_budget @@ -1088,7 +1088,7 @@ def dispatch( self.tokens_per_expert = tokens_per_expert.to(torch.int64) # self.num_permuted_tokens is necessary to allocate the output tensor for permute self.num_permuted_tokens = self.tokens_per_expert.sum() - if self.config.moe_expert_capacity_factor_for_packed_offloading is not None: + if self.moe_expert_rank_capacity_factor is not None: self.tokens_per_expert = tokens_per_expert.to(torch.int64) return dispatched_hidden diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 803619519f2..176ed2bb49a 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -769,6 +769,7 @@ class TransformerConfig(ModelParallelConfig): """Number of SMs to use for HybridEP. In pure NVL scenarios, 16 SMs can generally achieve good bandwidth.""" +<<<<<<< HEAD moe_mlp_glu_interleave_size: Optional[int] = None """When set, GLU activations in the MoE grouped MLP layer will use a block interleaved format. 
Instead of interpreting the input tensor @@ -778,6 +779,11 @@ class TransformerConfig(ModelParallelConfig): This data format is experimental and primarily intended to enable advanced fused kernels.""" +======= + moe_expert_rank_capacity_factor: Optional[float] = None + """moe_expert_rank_capacity_factor (float): The capacity factor for each expert, None means no token + will be dropped. The default is None.""" +>>>>>>> f52bf7f51 (Update added arguments and add compatibility check) ################## # Context Parallel ################## @@ -996,6 +1002,14 @@ class TransformerConfig(ModelParallelConfig): moe_paged_stash: bool = False """If True, enable paged stash for MoE expert activations.""" + stash_modules: Optional[list[str]] = None + """The MoE submodules to stash activations for. + choices: "expert_fc1", "moe_act", "expert_fc2". + "expert_fc1": stash the input of the expert fc1 part. + "moe_act": stash the input of the moe activation part. + "expert_fc2": stash the input of the expert fc2 part. + """ + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. @@ -1250,6 +1264,18 @@ def __post_init__(self): "moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity" ) + if self.moe_expert_rank_capacity_factor is not None: + if not self.moe_use_device_initiated_grouped_gemm: + raise ValueError( + "moe_expert_rank_capacity_factor requires moe_use_device_initiated_grouped_gemm " + "to be enabled." + ) + if self.moe_flex_dispatcher_backend != "hybridep": + raise ValueError( + "moe_expert_rank_capacity_factor requires moe_flex_dispatcher_backend to be " + "'hybridep'." + ) + if self.cpu_offloading and ( self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers ): @@ -1455,13 +1481,26 @@ def __post_init__(self): assert ( not self.cpu_offloading and not self.fine_grained_activation_offloading ), "paged_stash cannot be enabled with cpu_offloading." - assert self.offload_modules is not None and len(self.offload_modules) > 0 + assert self.stash_modules is not None and len(self.stash_modules) > 0, ( + "stash_modules must be specified when moe_paged_stash is enabled." + ) allowed_modules = {"expert_fc1", "expert_fc2", "moe_act"} - invalid_modules = set(self.offload_modules) - allowed_modules + invalid_modules = set(self.stash_modules) - allowed_modules assert not invalid_modules, ( - f'Invalid choices for offload_modules: {invalid_modules}. ' + f'Invalid choices for stash_modules: {invalid_modules}. ' f'Allowed modules are: {allowed_modules}' ) + assert self.moe_expert_rank_capacity_factor is not None, ( + "moe_expert_rank_capacity_factor must be set when moe_paged_stash is enabled." + ) + + # Check that no module is both stashed and offloaded + if self.stash_modules and self.offload_modules: + overlap = set(self.stash_modules) & set(self.offload_modules) + assert not overlap, ( + f"A module cannot be stashed and offloaded at the same time. " + f"Found overlapping modules: {overlap}" + ) if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 7cbbdfe4011..6a7110e8590 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1651,7 +1651,7 @@ def _add_inference_args(parser): group.add_argument('--use-legacy-static-engine', action='store_true', default=False, help='Use legacy static engine. 
(Current static engine uses dynamic engine under the hood)', dest='use_legacy_static_engine') - group.add_argument('--moe-expert-capacity-factor-for-packed-offloading', type=float, default=None, + group.add_argument('--moe-expert-rank-capacity-factor', type=float, default=None, help='The capacity factor for each EP rank when packed offloading is enabled.') group.add_argument('--inference-max-requests', type=int, default=8, help='Maximum number of requests for inference.', From 1bf3e433f90926d59ccbc672390b2d5917b19ad3 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Thu, 18 Dec 2025 20:10:55 +0800 Subject: [PATCH 32/57] refine overflow check resolve accidental change in fused_a2a.py --- megatron/core/full_cuda_graph.py | 16 ++++++++++++---- megatron/core/transformer/moe/fused_a2a.py | 8 +++++++- megatron/core/transformer/moe/paged_stash.py | 8 ++++++++ megatron/training/arguments.py | 5 ----- megatron/training/training.py | 12 ++++++++++-- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 40fec15e67b..6502cf62b4b 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -9,6 +9,7 @@ from megatron.core.tensor_parallel.random import get_all_rng_states from megatron.core.transformer.moe.paged_stash import ( paged_stash_reset, + check_paged_stash_overflow, ) logger = logging.getLogger(__name__) @@ -101,11 +102,18 @@ class FullCudaGraphWrapper: cuda_graph = {'training': None, 'validation': None} result = {'training': None, 'validation': None} - def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, moe_paged_stash=False): + def __init__( + self, + forward_backward_func, + cuda_graph_warmup_steps=1, + moe_paged_stash=False, + moe_expert_rank_capacity_factor=None, + ): self.forward_backward_func = forward_backward_func self.static_loader = StaticBufferLoader() self.cuda_graph_warmup_steps = cuda_graph_warmup_steps self.moe_paged_stash = moe_paged_stash + self.moe_expert_rank_capacity_factor = moe_expert_rank_capacity_factor def data_read(self, data_iterator, model, training, num_microbatches): """Read all microbatch inputs from Dataloader and copy to static buffers.""" @@ -184,19 +192,19 @@ def __call__(self, *args, **kwargs): torch.cuda.synchronize() torch.distributed.barrier() logger.info(f'CUDA graph capture done for {training_str}!!!') - + paged_stash_reset(enabled=self.moe_paged_stash and training) if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: - paged_stash_reset(enabled=self.moe_paged_stash and training) FullCudaGraphWrapper.cuda_graph[training_str].replay() + check_paged_stash_overflow() self.speculative_cuda_graph_check(model) self.next_iter(training_str) return FullCudaGraphWrapper.result[training_str] def speculative_cuda_graph_check(self, model): ''' check speculative execution modules ''' - if self.moe_paged_stash: + if self.moe_expert_rank_capacity_factor is not None: # Check if there is any overflow in the receiving buffer over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') for model_chunk in model: diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py index ad935cea66e..6d5f14eb121 100644 --- a/megatron/core/transformer/moe/fused_a2a.py +++ b/megatron/core/transformer/moe/fused_a2a.py @@ -3,6 +3,8 @@ # Copyright (c) 2025 DeepSeek # Licensed under the MIT License - 
https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE +from megatron.core.utils import internal_api + try: from deep_ep import Buffer from deep_ep.utils import EventHandle, EventOverlap @@ -327,6 +329,7 @@ def reset_hybrid_ep_buffer(): _hybrid_ep_buffer = None +@internal_api class HybridEPDispatch(torch.autograd.Function): ''' Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend @@ -403,6 +406,7 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper return combined_hidden, None, combined_probs, None, None, None, None, None, None, None +@internal_api class HybridEPCombine(torch.autograd.Function): ''' Fused combine operation for permute + combine a2a + permute using the HybridEP backend @@ -439,6 +443,7 @@ def backward(ctx, grad_x): if HAVE_HYBRIDEP: + @internal_api def hybrid_ep_dispatch( x, routing_map, @@ -489,6 +494,7 @@ def hybrid_ep_dispatch( pad_multiple, ) + @internal_api def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): ''' Perform fused combine operation for unpermute + combine a2a + unpermute @@ -510,4 +516,4 @@ def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): else: hybrid_ep_dispatch = None - hybrid_ep_combine = None + hybrid_ep_combine = None \ No newline at end of file diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index c13718bb003..34733a0700d 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -871,3 +871,11 @@ def paged_stash_reset(enabled=True): assert len(stash_manager.paged_tensors_to_stash) == 0, f"paged_tensors_to_stash is not empty {stash_manager.paged_tensors_to_stash}" assert len(stash_manager.paged_tensors_stash_in_progress) == 0, f"paged_tensors_stash_in_progress is not empty {stash_manager.paged_tensors_stash_in_progress}" +def check_paged_stash_overflow(): + """Check if paged stash overflow""" + stash_manager = PagedStashManager.get_instance() + if not stash_manager.enabled or stash_manager.overflow is None: + return + overflow = stash_manager.overflow.item() + if overflow != 0: + raise RuntimeError("PagedStashManager overflow!!!") diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 6a7110e8590..5627c491a0e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1411,11 +1411,6 @@ def validate_args(args, defaults={}): "For fine-grained activation offloading with TE >= 2.10.0, NVTE_CPU_OFFLOAD_V1 should be set to 1 to avoid offloading weights." assert not args.moe_paged_stash, "Fine-grained activation offloading and paged stash cannot be enabled at the same time" - if args.moe_paged_stash: - assert args.transformer_impl == 'transformer_engine', \ - "Paged stash is only supported with transformer_engine implementation" - assert not args.fine_grained_activation_offloading, "Paged stash and fine-grained activation offloading cannot be enabled at the same time" - if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." # MTP is compatible with position embedding types that use position_ids. 
diff --git a/megatron/training/training.py b/megatron/training/training.py index c5715e96aed..454b84d274a 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2716,7 +2716,11 @@ def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs): # Wrap forward_backward_func for Full iteration CUDA graph forward_backward_func = get_forward_backward_func() if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: - forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + forward_backward_func = FullCudaGraphWrapper( + forward_backward_func, + cuda_graph_warmup_steps=args.cuda_graph_warmup_steps, + moe_expert_rank_capacity_factor=args.moe_expert_rank_capacity_factor, + ) def get_e2e_base_metrics(): """Get base metrics values for one-logger to calculate E2E tracking metrics.""" @@ -3202,7 +3206,11 @@ def evaluate( eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size) forward_backward_func = get_forward_backward_func() if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: - forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) + forward_backward_func = FullCudaGraphWrapper( + forward_backward_func, + cuda_graph_warmup_steps=args.cuda_graph_warmup_steps, + moe_expert_rank_capacity_factor=args.moe_expert_rank_capacity_factor, + ) if has_nvidia_modelopt: # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors From db0b5c930a690c3f75d954ed126b8349b93bf9b4 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Fri, 19 Dec 2025 12:00:45 +0800 Subject: [PATCH 33/57] Fixing lint issues --- megatron/core/full_cuda_graph.py | 11 +- megatron/core/fusions/fused_bias_swiglu.py | 93 +++-- .../common/model_chunk_schedule_plan.py | 4 +- megatron/core/models/gpt/gpt_model.py | 8 +- megatron/core/pipeline_parallel/schedules.py | 7 +- megatron/core/transformer/moe/experts.py | 53 ++- megatron/core/transformer/moe/fused_a2a.py | 2 +- megatron/core/transformer/moe/paged_stash.py | 361 +++++++++++------- .../core/transformer/moe/token_dispatcher.py | 12 +- .../core/transformer/transformer_block.py | 3 +- .../core/transformer/transformer_config.py | 25 +- 11 files changed, 341 insertions(+), 238 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 6502cf62b4b..28836b10b2a 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -7,10 +7,7 @@ import torch from megatron.core.tensor_parallel.random import get_all_rng_states -from megatron.core.transformer.moe.paged_stash import ( - paged_stash_reset, - check_paged_stash_overflow, -) +from megatron.core.transformer.moe.paged_stash import check_paged_stash_overflow, paged_stash_reset logger = logging.getLogger(__name__) @@ -203,14 +200,16 @@ def __call__(self, *args, **kwargs): return FullCudaGraphWrapper.result[training_str] def speculative_cuda_graph_check(self, model): - ''' check speculative execution modules ''' + '''check speculative execution modules''' if self.moe_expert_rank_capacity_factor is not None: # Check if there is any overflow in the receiving buffer over_budget = torch.zeros(1, dtype=torch.bool, device='cuda') for model_chunk in model: for layer in model_chunk.module.module.decoder.layers: mlp = layer.mlp - if hasattr(mlp, 'token_dispatcher') and hasattr(mlp.token_dispatcher, 
'check_over_budget'): + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher, 'check_over_budget' + ): over_budget |= mlp.token_dispatcher.check_over_budget() if over_budget.item(): raise Exception(f"Rank {torch.distributed.get_rank()} overbudget") diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py index 3f6c70c75d8..b15081343f9 100644 --- a/megatron/core/fusions/fused_bias_swiglu.py +++ b/megatron/core/fusions/fused_bias_swiglu.py @@ -194,7 +194,7 @@ class WeightedSwiGLUFunction(torch.autograd.Function): @staticmethod def forward(ctx, input, weights, fp8_input_store, num_tokens_tensor=None): """Forward pass for weighted SwiGLU. - + Args: input: [total_tokens, hidden_size * 2] weights: [total_tokens, 1] @@ -204,7 +204,7 @@ def forward(ctx, input, weights, fp8_input_store, num_tokens_tensor=None): """ # Convert input for backward pass input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input - + # Use Triton implementation if num_tokens_tensor provided and available if num_tokens_tensor is not None and input.dim() == 2: output = weighted_swiglu_triton(input, weights, num_tokens_tensor) @@ -215,7 +215,7 @@ def forward(ctx, input, weights, fp8_input_store, num_tokens_tensor=None): output = weighted_swiglu(input, weights) ctx.save_for_backward(input_for_backward, weights) ctx.use_triton = False - + ctx.ori_input_dtype = input.dtype ctx.fp8_input_store = fp8_input_store return output @@ -287,6 +287,7 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, num_t # bias_swiglu_impl = BiasSwiGLUFunction.apply # swiglu_impl = SwiGLUFunction.apply + @triton.jit def _weighted_swiglu_fwd_kernel( input_ptr, @@ -297,26 +298,26 @@ def _weighted_swiglu_fwd_kernel( BLOCK_SIZE: tl.constexpr, ): """Triton kernel for weighted SwiGLU forward pass. - + Processes tokens in strided pattern, only operating on valid tokens. Formula: output = SiLU(input[:, :H]) * input[:, H:] * weights """ pid = tl.program_id(axis=0) num_blocks = tl.num_programs(axis=0) - + # Load actual number of tokens num_tokens = tl.load(num_tokens_ptr) - + # Strided access: each block handles tokens [pid, pid+num_blocks, ...] token_idx = pid while token_idx < num_tokens: # Load weight for this token weight = tl.load(weights_ptr + token_idx) - + # Process hidden dimension for h_offset in range(0, hidden_size, BLOCK_SIZE): h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size - + # Load input chunks (gate and value) input_offset_1 = token_idx * (hidden_size * 2) + h_offset input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset @@ -334,11 +335,11 @@ def _weighted_swiglu_fwd_kernel( y1_fp32 = y1.to(tl.float32) y2_fp32 = y2.to(tl.float32) weight_fp32 = weight.to(tl.float32) - + sigmoid_y1 = tl.sigmoid(y1_fp32) silu_y1 = y1_fp32 * sigmoid_y1 result = silu_y1 * y2_fp32 * weight_fp32 - + # Store output (cast back to original dtype) output_offset = token_idx * hidden_size + h_offset tl.store( @@ -346,10 +347,11 @@ def _weighted_swiglu_fwd_kernel( result.to(y1.dtype), mask=h_mask, ) - + # Stride to next token token_idx += num_blocks + @triton.jit def _weighted_swiglu_bwd_kernel( grad_output_ptr, @@ -362,34 +364,32 @@ def _weighted_swiglu_bwd_kernel( BLOCK_SIZE: tl.constexpr, ): """Triton kernel for weighted SwiGLU backward pass. - + Computes gradients with respect to input and weights for valid tokens only. 
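    For reference, an eager-mode sketch of the same math (y1, y2 are the two halves of
    the input row, w is the per-token weight, s = sigmoid(y1)):
        grad_y1 = grad_out * w * y2 * s * (1 + y1 * (1 - s))
        grad_y2 = grad_out * w * (y1 * s)
        grad_w  = sum over the hidden dim of (grad_out * y1 * s * y2)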
""" pid = tl.program_id(axis=0) num_blocks = tl.num_programs(axis=0) - + # Load actual number of tokens num_tokens = tl.load(num_tokens_ptr) - + # Strided access token_idx = pid while token_idx < num_tokens: # Load weight for this token weight = tl.load(weights_ptr + token_idx) - + # Accumulator for weight gradient (fp32 for precision) weight_grad_acc = 0.0 - + # Process hidden dimension for h_offset in range(0, hidden_size, BLOCK_SIZE): h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size - + # Load grad_output grad_out_offset = token_idx * hidden_size + h_offset grad_out = tl.load( - grad_output_ptr + grad_out_offset + tl.arange(0, BLOCK_SIZE), - mask=h_mask, - other=0.0, + grad_output_ptr + grad_out_offset + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 ) # Load input chunks @@ -402,25 +402,25 @@ def _weighted_swiglu_bwd_kernel( y2 = tl.load( input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 ) - + # Cast to fp32 for sigmoid computation (required by Triton) y1_fp32 = y1.to(tl.float32) y2_fp32 = y2.to(tl.float32) grad_out_fp32 = grad_out.to(tl.float32) weight_fp32 = weight.to(tl.float32) - + # Forward calculations sigmoid_y1 = tl.sigmoid(y1_fp32) silu_y1 = y1_fp32 * sigmoid_y1 - + # Gradient for y1 (gate): d(SiLU(y1))/dy1 * y2 * weight * grad_out # d(SiLU(y1))/dy1 = sigmoid(y1) * (1 + y1 * (1 - sigmoid(y1))) dsilu_dy1 = sigmoid_y1 * (1.0 + y1_fp32 * (1.0 - sigmoid_y1)) grad_y1 = grad_out_fp32 * weight_fp32 * dsilu_dy1 * y2_fp32 - + # Gradient for y2 (value): SiLU(y1) * weight * grad_out grad_y2 = grad_out_fp32 * weight_fp32 * silu_y1 - + # Store input gradients (cast back to original dtype) tl.store( grad_input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), @@ -432,79 +432,76 @@ def _weighted_swiglu_bwd_kernel( grad_y2.to(y2.dtype), mask=h_mask, ) - + # Accumulate weight gradient: swiglu(y) * grad_out # swiglu(y) = silu_y1 * y2 weight_grad_contribution = silu_y1 * y2_fp32 * grad_out_fp32 weight_grad_acc += tl.sum(weight_grad_contribution) - + # Store weight gradient after processing all chunks tl.store(grad_weights_ptr + token_idx, weight_grad_acc) - + # Stride to next token token_idx += num_blocks + def weighted_swiglu_triton(input, weights, num_tokens_tensor): """Triton implementation of weighted SwiGLU forward pass. - + Args: input: [total_tokens, hidden_size * 2] weights: [total_tokens, 1] num_tokens_tensor: Scalar tensor with actual token count - + Returns: output: [total_tokens, hidden_size] """ assert input.dim() == 2, "Input must be 2D [total_tokens, hidden_size*2]" assert weights.dim() == 2 and weights.size(1) == 1, "Weights must be [total_tokens, 1]" - + total_tokens, hidden_size_2 = input.shape hidden_size = hidden_size_2 // 2 - + # Allocate output output = torch.empty((total_tokens, hidden_size), dtype=input.dtype, device=input.device) - + # Launch kernel BLOCK_SIZE = 128 num_blocks = min(total_tokens, 4096) grid = (num_blocks,) - + _weighted_swiglu_fwd_kernel[grid]( - input, - weights, - output, - num_tokens_tensor, - hidden_size=hidden_size, - BLOCK_SIZE=BLOCK_SIZE, + input, weights, output, num_tokens_tensor, hidden_size=hidden_size, BLOCK_SIZE=BLOCK_SIZE ) - + return output + def weighted_swiglu_triton_back(grad_output, input, weights, num_tokens_tensor): """Triton implementation of weighted SwiGLU backward pass. 
- + Args: grad_output: [total_tokens, hidden_size] input: [total_tokens, hidden_size * 2] weights: [total_tokens, 1] num_tokens_tensor: Scalar tensor with actual token count - + Returns: grad_input: [total_tokens, hidden_size * 2] grad_weights: [total_tokens, 1] """ total_tokens, hidden_size_2 = input.shape hidden_size = hidden_size_2 // 2 - + # Allocate gradients grad_input = torch.empty_like(input) grad_weights = torch.empty_like(weights) - + # Launch kernel BLOCK_SIZE = 128 num_blocks = min(total_tokens, 4096) grid = (num_blocks,) - + _weighted_swiglu_bwd_kernel[grid]( grad_output, input, @@ -515,5 +512,5 @@ def weighted_swiglu_triton_back(grad_output, input, weights, num_tokens_tensor): hidden_size=hidden_size, BLOCK_SIZE=BLOCK_SIZE, ) - - return grad_input, grad_weights \ No newline at end of file + + return grad_input, grad_weights diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index a745ffe2294..c5cf05a8f6e 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -8,9 +8,6 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp8_utils import get_fp8_context -from megatron.core.transformer.moe.paged_stash import ( - paged_stash_set_last_layer, -) from megatron.core.pipeline_parallel.utils import ( AbstractSchedulePlan, NoopScheduleNode, @@ -18,6 +15,7 @@ get_comp_stream, ) from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.moe.paged_stash import paged_stash_set_last_layer class ModelChunkState: diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 2a829a24929..df732ef8d94 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -21,14 +21,12 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.transformer.moe.paged_stash import ( - paged_stash_init_chunk_handler, -) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.quantization.utils import get_quant_config_or_none from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer.enums import CudaGraphScope, ModelType from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule +from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, mtp_on_this_rank, @@ -479,8 +477,7 @@ def preprocess_for_fine_grained_offloading(self): def preprocess_for_paged_stash(self): """Preprocess for paged stash.""" return paged_stash_init_chunk_handler( - vp_size=self.config.virtual_pipeline_model_parallel_size, - vp_stage=self.vp_stage, + vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage ) def forward( @@ -540,7 +537,6 @@ def forward( rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None - # Run decoder. 
hidden_states = self.decoder( hidden_states=decoder_input, diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 5d99fa20a8a..976c8e6018f 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -12,9 +12,6 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.transformer.moe.paged_stash import ( - paged_stash_reset, -) from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, @@ -25,6 +22,7 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import create_cudagraphs from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.moe.paged_stash import paged_stash_reset from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler from megatron.core.utils import ( drain_embedding_wgrad_compute, @@ -1710,8 +1708,6 @@ def forward_backward_helper_wrapper( # Forward pass. forward_k = k + num_warmup_microbatches - - # Decide to checkpoint all layers' activations of the current micro-batch. if max_outstanding_backprops is not None: checkpoint_activations_microbatch = ( @@ -1725,6 +1721,7 @@ def forward_backward_helper_wrapper( if config.overlap_p2p_comm: backward_k = k + # Sync forward recv def pp_pre_forward(vp_stage=None): if vp_stage is None: diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index ec66507f2c1..c309fcd84f3 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -33,12 +33,6 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.transformer.moe.paged_stash import ( - paged_stash_group_start, - get_paged_stash_context, - paged_stash_reset, - paged_stash_group_commit, -) from megatron.core.tensor_parallel.layers import ( _initialize_affine_weight_cpu, _initialize_affine_weight_gpu, @@ -57,6 +51,11 @@ ProcessGroupCollection, get_align_size_for_quantization, ) +from megatron.core.transformer.moe.paged_stash import ( + get_paged_stash_context, + paged_stash_group_commit, + paged_stash_group_start, +) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( ensure_metadata_has_dp_cp_group, @@ -706,9 +705,13 @@ def __init__( ) stash_modules = self.config.stash_modules or [] - self.moe_paged_stash_expert_fc1 = self.config.moe_paged_stash and "expert_fc1" in stash_modules + self.moe_paged_stash_expert_fc1 = ( + self.config.moe_paged_stash and "expert_fc1" in stash_modules + ) self.moe_paged_stash_moe_act = self.config.moe_paged_stash and "moe_act" in stash_modules - self.moe_paged_stash_expert_fc2 = self.config.moe_paged_stash and "expert_fc2" in stash_modules + self.moe_paged_stash_expert_fc2 = ( + self.config.moe_paged_stash and "expert_fc2" in stash_modules + ) self.activation_recompute = ( self.config.recompute_granularity == 'selective' @@ -1014,9 +1017,15 @@ def forward( self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" ) as permuted_local_hidden_states: if self.config.moe_paged_stash: - permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states, name="expert_fc1") + 
permuted_local_hidden_states = paged_stash_group_start( + permuted_local_hidden_states, name="expert_fc1" + ) if self.moe_paged_stash_expert_fc1: - offload_context = get_paged_stash_context(name="expert_fc1", max_num_tokens=permuted_local_hidden_states.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + offload_context = get_paged_stash_context( + name="expert_fc1", + max_num_tokens=permuted_local_hidden_states.shape[0], + num_tokens_tensor=tokens_per_expert.sum(), + ) else: offload_context = nullcontext() with offload_context: @@ -1070,10 +1079,14 @@ def remove_glu_interleaving(x: torch.Tensor) -> torch.Tensor: bias_parallel, permuted_probs, self.config.activation_func_fp8_input_store, - tokens_per_expert.sum() - if (isinstance(tokens_per_expert, torch.Tensor) and tokens_per_expert.is_cuda) - else None, - + ( + tokens_per_expert.sum() + if ( + isinstance(tokens_per_expert, torch.Tensor) + and tokens_per_expert.is_cuda + ) + else None + ), ) elif self.activation_func == quick_gelu and self.config.gated_linear_unit: intermediate_parallel = weighted_bias_quick_geglu_impl( @@ -1126,7 +1139,11 @@ def glu(x): else: with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: if self.moe_paged_stash_moe_act: - offload_context = get_paged_stash_context(name="moe_act", max_num_tokens=fc1_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + offload_context = get_paged_stash_context( + name="moe_act", + max_num_tokens=fc1_output.shape[0], + num_tokens_tensor=tokens_per_expert.sum(), + ) else: offload_context = nullcontext() with offload_context: @@ -1137,7 +1154,11 @@ def glu(x): ) if self.moe_paged_stash_expert_fc2: - offload_context = get_paged_stash_context(name="expert_fc2", max_num_tokens=bias_act_output.shape[0], num_tokens_tensor=tokens_per_expert.sum()) + offload_context = get_paged_stash_context( + name="expert_fc2", + max_num_tokens=bias_act_output.shape[0], + num_tokens_tensor=tokens_per_expert.sum(), + ) else: offload_context = nullcontext() with offload_context: diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py index 6d5f14eb121..aa13b9b5b5b 100644 --- a/megatron/core/transformer/moe/fused_a2a.py +++ b/megatron/core/transformer/moe/fused_a2a.py @@ -516,4 +516,4 @@ def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple): else: hybrid_ep_dispatch = None - hybrid_ep_combine = None \ No newline at end of file + hybrid_ep_combine = None diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 34733a0700d..daf3065bfd4 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -1,24 +1,25 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -import warnings +import os from contextlib import nullcontext from typing import Any -import os + import torch -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor import triton import triton.language as tl +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor GLOBAL_BLOCK_SIZE = 1024 + class PagedStashBuffer: """ A paged stash buffer with page-level memory management. - + The buffer is organized as [num_pages, page_size, hidden_size]. Uses a free list (circular buffer) to track available pages. 
""" - + def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype): """ Args: @@ -33,36 +34,45 @@ def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype): self.page_size = page_size self.num_pages = (num_tokens + page_size - 1) // page_size # Ceiling division self.total_tokens = self.num_pages * page_size - + # Create 2D buffer [total_tokens, hidden_size] # Organized as pages: [page_0_tokens, page_1_tokens, ...] if os.getenv('PAGED_STASH_TO_CPU', '0') == '1': - self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True) + self.buffer = torch.empty( + (self.total_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True + ) else: self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device=device) - + self.overflow = overflow # GPU flag (shared) self.device = device self.dtype = dtype - + # Free list as circular buffer: stores available page IDs self.free_list = torch.arange(self.num_pages, dtype=torch.int64, device=device) - + # Head and tail pointers for free_list circular buffer - self.free_list_head = torch.zeros(1, dtype=torch.int64, device=device) # Read pointer (allocation) - self.free_list_tail = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) # Write pointer (deallocation) - + self.free_list_head = torch.zeros( + 1, dtype=torch.int64, device=device + ) # Read pointer (allocation) + self.free_list_tail = self.num_pages * torch.ones( + 1, dtype=torch.int64, device=device + ) # Write pointer (deallocation) + # Capacity of free list - self.free_list_capacity = self.num_pages*torch.ones(1, dtype=torch.int64, device=device) - + self.free_list_capacity = self.num_pages * torch.ones(1, dtype=torch.int64, device=device) + def reset(self): """Reset the paged buffer - reinitialize free list.""" self.free_list.copy_(torch.arange(self.num_pages, dtype=torch.int64, device=self.device)) self.free_list_head.zero_() self.free_list_tail.fill_(self.num_pages) - + def __repr__(self): - return f"PagedStashBuffer(num_pages={self.num_pages}, page_size={self.page_size}, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" + return ( + f"PagedStashBuffer(num_pages={self.num_pages}, page_size={self.page_size}, " + f"hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" + ) @triton.jit @@ -82,7 +92,7 @@ def _paged_stash_copy_kernel( BLOCK_SIZE: tl.constexpr, ): """Triton kernel to copy tokens to paged stash buffer. - + Allocates pages from free list (reads from head, advances head). Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. Grid: (num_blocks,) where blocks process tokens in a strided pattern. 
@@ -90,54 +100,54 @@ def _paged_stash_copy_kernel( """ pid = tl.program_id(axis=0) num_blocks = tl.num_programs(axis=0) - + # Load parameters num_tokens = tl.load(num_tokens_ptr) free_list_head = tl.load(free_list_head_ptr) free_list_tail = tl.load(free_list_tail_ptr) free_list_capacity = tl.load(free_list_capacity_ptr) - + # Check available pages (unwrapped indices: simple subtraction, no modulo needed) avail_pages = free_list_tail - free_list_head - + # Calculate required pages required_pages = tl.cdiv(num_tokens, PAGE_SIZE) overflow_detected = avail_pages < required_pages - + # Only block 0 writes overflow flag if pid == 0 and overflow_detected: tl.store(overflow_ptr, 1) - + # All blocks return early if overflow if overflow_detected: return - + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] token_idx = pid while token_idx < num_tokens: # Determine which page this token belongs to page_slot = token_idx // PAGE_SIZE token_in_page = token_idx % PAGE_SIZE - + # Read page ID from free list (with wraparound) free_list_idx = (free_list_head + page_slot) % free_list_capacity page_id = tl.load(free_list_ptr + free_list_idx) - + # First token in page: record the page ID (only if this block handles token 0 of the page) if token_in_page == 0: tl.store(page_record_ptr + page_slot, page_id) - + # Calculate destination address in paged buffer dst_token_idx = page_id * PAGE_SIZE + token_in_page - + # Copy token data (2D: hidden dimension) elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 num_iters = elements_per_thread + (1 if need_mask else 0) - + src_base = src_ptr + token_idx * HIDDEN_SIZE dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE - + if need_mask: for iter in range(num_iters): hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE @@ -149,10 +159,10 @@ def _paged_stash_copy_kernel( hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE data = tl.load(src_base + hidden_offsets) tl.store(dst_base + hidden_offsets, data) - + # Stride to next token for this block token_idx += num_blocks - + # Calculate and store new free list head (only block 0) # We consumed pages, so advance head forward (unwrapped: no modulo) # Write to temporary tensor to avoid race conditions @@ -177,7 +187,7 @@ def _paged_stash_pop_kernel( BLOCK_SIZE: tl.constexpr, ): """Triton kernel to reload tokens from paged stash buffer. - + Returns pages to free list (writes to tail, advances tail). Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. Grid: (num_blocks,) where blocks process tokens in a strided pattern. @@ -185,33 +195,33 @@ def _paged_stash_pop_kernel( """ pid = tl.program_id(axis=0) num_blocks = tl.num_programs(axis=0) - + # Load parameters num_tokens = tl.load(num_tokens_ptr) free_list_tail = tl.load(free_list_tail_ptr) free_list_capacity = tl.load(free_list_capacity_ptr) - + # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] 
token_idx = pid while token_idx < num_tokens: # Determine which page this token belongs to page_slot = token_idx // PAGE_SIZE token_in_page = token_idx % PAGE_SIZE - + # Read page ID from page record page_id = tl.load(page_record_ptr + page_slot) - + # Calculate source address in paged buffer src_token_idx = page_id * PAGE_SIZE + token_in_page - + # Copy token data (2D: hidden dimension) elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 num_iters = elements_per_thread + (1 if need_mask else 0) - + src_base = src_ptr + src_token_idx * HIDDEN_SIZE dst_base = dst_ptr + token_idx * HIDDEN_SIZE - + if need_mask: for iter in range(num_iters): hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE @@ -223,16 +233,16 @@ def _paged_stash_pop_kernel( hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE data = tl.load(src_base + hidden_offsets) tl.store(dst_base + hidden_offsets, data) - + # First token in page: release page back to free list if token_in_page == 0: # Write page ID back to free list at tail position (with wraparound) write_idx = (free_list_tail + page_slot) % free_list_capacity tl.store(free_list_ptr + write_idx, page_id) - + # Stride to next token for this block token_idx += num_blocks - + # Calculate and store new free list tail (only block 0) # We returned pages, so advance tail forward (unwrapped: no modulo) # Write to temporary tensor to avoid race conditions @@ -246,8 +256,17 @@ class PagedTensor: """ A paged tensor that stores data in pages within a paged stash buffer. """ - - def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64): + + def __init__( + self, + tensor, + num_tokens_tensor=None, + vp_stage=None, + schedule_layer_no=None, + layer_name=None, + max_tokens=None, + page_size=64, + ): """ Args: tensor: The tensor to store @@ -259,25 +278,31 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer """ self._tensor = tensor self._original_tensor = None - assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) and num_tokens_tensor.numel() == 1 + assert ( + num_tokens_tensor is not None + and isinstance(num_tokens_tensor, torch.Tensor) + and num_tokens_tensor.numel() == 1 + ) self.num_tokens_tensor = num_tokens_tensor.clone() self.vp_stage = vp_stage self.schedule_layer_no = schedule_layer_no self.layer_name = layer_name self.max_tokens = max_tokens self.page_size = page_size - + # Original tensor information self.original_shape = list(tensor.shape) self.max_num_tokens = self.original_shape[0] self.element_size = tensor.element_size() self.hidden_size = self.original_shape[1] - self.dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype + self.dtype = ( + tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype + ) self.device = tensor.device - + # Calculate number of pages needed self.max_num_pages = (self.max_num_tokens + page_size - 1) // page_size # Ceiling division - + # Page record: stores which pages are being used for this tensor self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) @@ -288,7 +313,7 @@ def schedule_layer(self): def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): """Offload the paged tensor to paged stash buffer. 
- + Args: paged_stash_buffer: The paged stash buffer to offload to max_blocks: Maximum number of blocks for Triton kernel @@ -296,21 +321,21 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 self._tensor = self._tensor.contiguous() if self.num_tokens_tensor.dim() == 0: self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) - + # Get 2D tensor if isinstance(self._tensor, MXFP8Tensor): tensor_to_copy = self._tensor._columnwise_data else: tensor_to_copy = self._tensor - + # Determine grid size BLOCK_SIZE = GLOBAL_BLOCK_SIZE num_blocks = min(self.max_num_tokens, max_blocks) grid = (num_blocks,) - + # Create temporary tensor for new head new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) - + # Launch paged stash copy kernel _paged_stash_copy_kernel[grid]( tensor_to_copy, @@ -327,17 +352,17 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 HIDDEN_SIZE=self.hidden_size, BLOCK_SIZE=BLOCK_SIZE, ) - + # Update free list head paged_stash_buffer.free_list_head.copy_(new_free_list_head) - + # Save reference to original tensor self._original_tensor = self._tensor self._tensor = None - + def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): """Reload the paged tensor from paged stash buffer. - + Args: paged_stash_buffer: The paged stash buffer to reload from max_blocks: Maximum number of blocks for Triton kernel @@ -359,15 +384,15 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 else: self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) tensor_to_reload = self._tensor - + # Determine grid size BLOCK_SIZE = GLOBAL_BLOCK_SIZE num_blocks = min(self.max_num_tokens, max_blocks) grid = (num_blocks,) - + # Create temporary tensor for new tail new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) - + # Launch paged stash pop kernel _paged_stash_pop_kernel[grid]( paged_stash_buffer.buffer, @@ -383,7 +408,7 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 HIDDEN_SIZE=self.hidden_size, BLOCK_SIZE=BLOCK_SIZE, ) - + # Update free list tail paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) @@ -394,7 +419,7 @@ class PP_PreScheduleFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, stash_manager): # after forward + def forward(ctx, tensor, stash_manager): # after forward # pylint: disable=missing-function-docstring ctx.stash_manager = stash_manager # Wait for stash to complete before starting the next layer @@ -402,35 +427,45 @@ def forward(ctx, tensor, stash_manager): # after forward return tensor @staticmethod - def backward(ctx, *grad_output): # before backward + def backward(ctx, *grad_output): # before backward # pylint: disable=missing-function-docstring # Initiate reload for next layer - if ctx.stash_manager.status == 'captured' and ctx.stash_manager.current_schedule_index < len(ctx.stash_manager._pp_schedule): - next_schedule_layer = ctx.stash_manager._pp_schedule[ctx.stash_manager.current_schedule_index] + if ( + ctx.stash_manager.status == 'captured' + and ctx.stash_manager.current_schedule_index < len(ctx.stash_manager._pp_schedule) + ): + next_schedule_layer = ctx.stash_manager._pp_schedule[ + ctx.stash_manager.current_schedule_index + ] if next_schedule_layer < 0: ctx.stash_manager.reload_paged_tensors(-next_schedule_layer) return grad_output + (None, None) + class PP_PostScheduleFunction(torch.autograd.Function): """ This 
function is used to update the pp schedule. """ @staticmethod - def forward(ctx, tensor, stash_manager): # after forward + def forward(ctx, tensor, stash_manager): # after forward # pylint: disable=missing-function-docstring ctx.stash_manager = stash_manager ctx.vp_stage = stash_manager.current_vp_stage if ctx.vp_stage is None: ctx.vp_stage = 0 - ctx.layer_no, ctx.microbatch_no = stash_manager.update_pp_schedule(ctx.vp_stage+1) + ctx.layer_no, ctx.microbatch_no = stash_manager.update_pp_schedule(ctx.vp_stage + 1) # Initiate stash for current layer and reload for next layer if stash_manager.status == 'captured': - current_schedule_layer = stash_manager.get_schedule_layer(ctx.vp_stage+1, ctx.layer_no, ctx.microbatch_no) - next_schedule_layer = ctx.stash_manager._pp_schedule[ctx.stash_manager.current_schedule_index+1] + current_schedule_layer = stash_manager.get_schedule_layer( + ctx.vp_stage + 1, ctx.layer_no, ctx.microbatch_no + ) + next_schedule_layer = ctx.stash_manager._pp_schedule[ + ctx.stash_manager.current_schedule_index + 1 + ] if current_schedule_layer != -next_schedule_layer: # Start stash for current layer ctx.stash_manager.stash_paged_tensors(current_schedule_layer) @@ -445,10 +480,12 @@ def forward(ctx, tensor, stash_manager): # after forward return tensor @staticmethod - def backward(ctx, *grad_output): # before backward + def backward(ctx, *grad_output): # before backward # pylint: disable=missing-function-docstring if ctx.vp_stage is not None: - ctx.stash_manager.update_pp_schedule(-(ctx.vp_stage+1), -ctx.layer_no, -ctx.microbatch_no) + ctx.stash_manager.update_pp_schedule( + -(ctx.vp_stage + 1), -ctx.layer_no, -ctx.microbatch_no + ) ctx.stash_manager.current_schedule_index += 1 current_stream = torch.cuda.current_stream() @@ -456,9 +493,10 @@ def backward(ctx, *grad_output): # before backward if ctx.stash_manager._unpack_stream_status == 'reloading': current_stream.wait_stream(ctx.stash_manager.unpack_stream) ctx.stash_manager._unpack_stream_status = 'idle' - + return grad_output + (None, None) + class PagedStashManager: """ Singleton manager for coordinating paged stashing across pipeline stages. 
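The schedule bookkeeping that follows encodes each (vp_stage, layer, microbatch) step as a single signed integer (see get_schedule_layer below): vp_stage * 1_000_000 + layer_no * 1_000 + microbatch_no, positive for the forward pass and negated for the matching backward pass. For example:

    vp_stage=2, layer_no=5, microbatch_no=3  ->  2*1_000_000 + 5*1_000 + 3 = 2_005_003  (forward)
                                                 recorded as -2_005_003 for its backward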
@@ -480,11 +518,11 @@ def __init__(self): # allocate streams and events for synchronization self.enabled = False self._pack_stream = torch.cuda.Stream() - # Currently paged stashing is not stream-safe, so use the same stream for packing + # Currently paged stashing is not stream-safe, so use the same stream for packing # and unpacking - self._unpack_stream = self._pack_stream + self._unpack_stream = self._pack_stream self._pack_stream_status = 'idle' # idle, stashing - self._unpack_stream_status = 'idle' # idle, reloading + self._unpack_stream_status = 'idle' # idle, reloading self.paged_tensors_to_stash = [] self.paged_tensors_stash_in_progress = [] self.paged_tensors_to_reload = {} @@ -494,12 +532,14 @@ def __init__(self): self.vp_size = None self.current_vp_stage = None self._last_layer = False - self.status = 'begin' # begin, capture, captured - self._pp_schedule = None # If element is +ve, it denotes forward pass of vp stage, if -ve, it denotes backward pass of vp stage + self.status = 'begin' # begin, capture, captured + # If element is +ve, it denotes forward pass of vp stage, + # if -ve, it denotes backward pass of vp stage + self._pp_schedule = None self.current_layer = None self.current_microbatch = None self.current_schedule_index = None - + # Track max tokens needed per vp_stage, dtype, and hidden_size self.max_tokens_per_vp_stage = None self.temp_tokens_per_vp_stage = None @@ -512,7 +552,7 @@ def __init__(self): self.stash_buffers = None self.overflow = None self.device = None - + # Page size for paged memory management self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page @@ -532,7 +572,7 @@ def set_current_layer_name(self, name): def get_schedule_layer(self, vp_stage, layer_no, microbatch_no): """Get the schedule layer.""" - return vp_stage*1000000 + layer_no*1000 + microbatch_no + return vp_stage * 1000000 + layer_no * 1000 + microbatch_no def add_paged_tensor_to_stash(self, paged_tensor): """Add a paged tensor to the stash list.""" @@ -546,7 +586,9 @@ def remove_paged_tensor_from_stash(self): if self.status == 'captured': while len(self.paged_tensors_to_stash) > 0: paged_tensor = self.paged_tensors_to_stash.pop(0) - assert len(self.paged_tensors_to_stash) == 0, f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" + assert ( + len(self.paged_tensors_to_stash) == 0 + ), f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" else: pass @@ -560,8 +602,11 @@ def stash_paged_tensors(self, pp_schedule_layer): self._pack_stream_status = 'stashing' if pp_schedule_layer not in self.paged_tensors_to_reload: self.paged_tensors_to_reload[pp_schedule_layer] = [] - assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, f"paged_tensors_to_reload {pp_schedule_layer} is not empty {self.paged_tensors_to_reload[pp_schedule_layer]}" - + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, ( + f"paged_tensors_to_reload {pp_schedule_layer} is not empty " + f"{self.paged_tensors_to_reload[pp_schedule_layer]}" + ) + while len(self.paged_tensors_to_stash) > 0: paged_tensor = self.paged_tensors_to_stash.pop(0) stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] @@ -570,7 +615,9 @@ def stash_paged_tensors(self, pp_schedule_layer): self.paged_tensors_stash_in_progress.append(paged_tensor) else: pass - assert len(self.paged_tensors_to_stash) == 0, f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" + assert ( + len(self.paged_tensors_to_stash) == 0 + ), 
f"paged_tensors_to_stash is not empty {self.paged_tensors_to_stash}" def wait_for_stash_to_complete(self): """Wait for stash to complete.""" @@ -589,7 +636,8 @@ def wait_for_stash_to_complete(self): def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): """Reload the paged tensors.""" - # Avoid waiting for main stream if reload is immediately after stash since stash is already waiting for main stream + # Avoid waiting for main stream if reload is immediately after stash + # since stash is already waiting for main stream if not no_wait or self.unpack_stream != self.pack_stream: current_stream = torch.cuda.current_stream() self.unpack_stream.wait_stream(current_stream) @@ -601,16 +649,18 @@ def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): for item in self.paged_tensors_to_reload: if len(self.paged_tensors_to_reload[item]) > 0: count += 1 - + while len(self.paged_tensors_to_reload[pp_schedule_layer]) > 0: paged_tensor = self.paged_tensors_to_reload[pp_schedule_layer].pop(0) stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] paged_tensor.reload_from_stash(stash_buffer) else: pass - assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, f"paged_tensors_to_reload {pp_schedule_layer} is not empty {self.paged_tensors_to_reload[pp_schedule_layer]}" + assert len(self.paged_tensors_to_reload[pp_schedule_layer]) == 0, ( + f"paged_tensors_to_reload {pp_schedule_layer} is not empty " + f"{self.paged_tensors_to_reload[pp_schedule_layer]}" + ) - def allocate_stash_buffers(self, stash_buffer_size_factor=1.10): """Allocate stash buffers organized by [dtype][hidden_size].""" self.stash_buffers = {} @@ -625,7 +675,7 @@ def allocate_stash_buffers(self, stash_buffer_size_factor=1.10): ) self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype - ) + ) def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): """Update the pp schedule.""" @@ -638,23 +688,24 @@ def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): assert self.vp_size is not None if layer_no is None: # forward pass - layer_no = self.current_layer[vp_stage-1] - self.current_layer[vp_stage-1] += 1 - microbatch_no = self.current_microbatch[vp_stage-1] + layer_no = self.current_layer[vp_stage - 1] + self.current_layer[vp_stage - 1] += 1 + microbatch_no = self.current_microbatch[vp_stage - 1] if self._last_layer: - self.current_layer[vp_stage-1] = 1 - self.current_microbatch[vp_stage-1] += 1 + self.current_layer[vp_stage - 1] = 1 + self.current_microbatch[vp_stage - 1] += 1 if self.status == 'capture': self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) num_tokens = self.num_tokens_tensor.item() - assert self._pp_schedule[self.current_schedule_index] == self.get_schedule_layer(vp_stage, layer_no, microbatch_no), f"schedule {self._pp_schedule[self.current_schedule_index]} != {self.get_schedule_layer(vp_stage, layer_no, microbatch_no)}" - - + expected = self.get_schedule_layer(vp_stage, layer_no, microbatch_no) + actual = self._pp_schedule[self.current_schedule_index] + assert actual == expected, f"schedule {actual} != {expected}" + return layer_no, microbatch_no - #self._pp_schedule.append(vp_size) - #self._pp_schedule.append(vp_stage) + # self._pp_schedule.append(vp_size) + # self._pp_schedule.append(vp_stage) def on_save_for_backward(self, tensor: torch.Tensor) -> Any: """ @@ -663,22 +714,36 @@ def on_save_for_backward(self, tensor: 
torch.Tensor) -> Any: """ # Handle 0-dim tensors (torch.Size([])) - they have no size(0) - if self.max_num_tokens is None or tensor.dim() == 0 or tensor.size(0) != self.max_num_tokens: + if ( + self.max_num_tokens is None + or tensor.dim() == 0 + or tensor.size(0) != self.max_num_tokens + ): return tensor.detach() if isinstance(tensor, MXFP8Tensor): - assert tensor._rowwise_data is None, f"rowwise_data is not None; Only columnwise data is supported for paged stashing" + assert ( + tensor._rowwise_data is None + ), f"rowwise_data is not None; Only columnwise data is supported for paged stashing" if self.status == 'capture': self.num_tokens = self.num_tokens_tensor.item() - dtype = tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype + dtype = ( + tensor.dtype + if not isinstance(tensor, MXFP8Tensor) + else tensor._columnwise_data.dtype + ) # Get hidden_size from tensor shape if isinstance(tensor, MXFP8Tensor): - hidden_size = tensor._columnwise_data.shape[1] if tensor._columnwise_data.ndim > 1 else tensor._columnwise_data.numel() + hidden_size = ( + tensor._columnwise_data.shape[1] + if tensor._columnwise_data.ndim > 1 + else tensor._columnwise_data.numel() + ) else: hidden_size = tensor.shape[1] if tensor.ndim > 1 else tensor.numel() - + if dtype not in self.temp_tokens_per_vp_stage[self.current_vp_stage]: self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} self.max_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} @@ -686,10 +751,12 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] += self.num_tokens + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][ + hidden_size + ] += self.num_tokens self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = max( self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] + self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], ) if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 @@ -698,28 +765,34 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: self.temp_tokens_across_vp_stages[dtype, hidden_size] += self.num_tokens self.max_tokens_across_vp_stages[dtype, hidden_size] = max( self.max_tokens_across_vp_stages[dtype, hidden_size], - self.temp_tokens_across_vp_stages[dtype, hidden_size] + self.temp_tokens_across_vp_stages[dtype, hidden_size], ) - # Since capture stage does not use CUDA graph, we can truncate the saved tensor to actual num_tokens - # Truncate the tensor to the actual number of tokens + # Since capture stage does not use CUDA graph, we can truncate + # the saved tensor to actual num_tokens new_size = (self.num_tokens, *tensor.shape[1:]) if isinstance(tensor, MXFP8Tensor): - tensor_truncated = torch.empty(new_size, dtype=tensor._columnwise_data.dtype, device=tensor.device) - tensor_truncated.copy_(tensor._columnwise_data[:self.num_tokens, ...]) + tensor_truncated = torch.empty( + new_size, dtype=tensor._columnwise_data.dtype, device=tensor.device + ) + tensor_truncated.copy_(tensor._columnwise_data[: self.num_tokens, ...]) tensor._columnwise_data = tensor_truncated else: tensor_truncated = torch.empty(new_size, dtype=tensor.dtype, 
device=tensor.device) - tensor_truncated.copy_(tensor[:self.num_tokens, ...]) + tensor_truncated.copy_(tensor[: self.num_tokens, ...]) tensor = tensor_truncated - paged_tensor = PagedTensor( - tensor, - num_tokens_tensor=self.num_tokens_tensor, + tensor, + num_tokens_tensor=self.num_tokens_tensor, vp_stage=self.current_vp_stage, - schedule_layer_no=self._pp_schedule[self.current_schedule_index] if self._pp_schedule is not None and self.current_schedule_index < len(self._pp_schedule) else None, - layer_name=self._current_layer_name, + schedule_layer_no=( + self._pp_schedule[self.current_schedule_index] + if self._pp_schedule is not None + and self.current_schedule_index < len(self._pp_schedule) + else None + ), + layer_name=self._current_layer_name, max_tokens=self.max_num_tokens, page_size=self.page_size, ) @@ -727,7 +800,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if self.status == 'captured': self.add_paged_tensor_to_stash(paged_tensor) return paged_tensor - + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: """ Hook called when autograd retrieves a saved tensor during backward pass. @@ -736,34 +809,42 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: if isinstance(saved_state, (PagedTensor)): if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() - self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][saved_state.hidden_size] -= num_tokens - self.temp_tokens_across_vp_stages[saved_state.dtype, saved_state.hidden_size] -= num_tokens + self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][ + saved_state.hidden_size + ] -= num_tokens + self.temp_tokens_across_vp_stages[ + saved_state.dtype, saved_state.hidden_size + ] -= num_tokens # Pad the tensor to the max number of tokens npad = self.max_num_tokens - num_tokens pad = () - for _ in range(saved_state._tensor.ndim-1): + for _ in range(saved_state._tensor.ndim - 1): pad = pad + (0, 0) pad = pad + (0, npad) if isinstance(saved_state._tensor, MXFP8Tensor): - saved_state._tensor._columnwise_data = torch.nn.functional.pad(saved_state._tensor._columnwise_data, pad) + saved_state._tensor._columnwise_data = torch.nn.functional.pad( + saved_state._tensor._columnwise_data, pad + ) else: saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad) - assert saved_state._tensor is not None, f"saved_state._tensor is None {saved_state._tensor}" + assert ( + saved_state._tensor is not None + ), f"saved_state._tensor is None {saved_state._tensor}" return saved_state._tensor return saved_state - + + class PagedStashContext: """Wrapper context manager that adds custom enter/exit behavior around saved_tensors_hooks.""" - + def __init__(self, stash_manager): self.stash_manager = stash_manager self.saved_tensors_context = torch.autograd.graph.saved_tensors_hooks( - stash_manager.on_save_for_backward, - stash_manager.on_get_saved_tensor + stash_manager.on_save_for_backward, stash_manager.on_get_saved_tensor ) - + def __enter__(self): from megatron.core.extensions.transformer_engine import cpu_offload @@ -771,10 +852,10 @@ def __enter__(self): cpu_offload.CPUOffloadEnabled = True # Call the underlying context manager's __enter__ result = self.saved_tensors_context.__enter__() - + # Add more custom logic after entering if needed return result - + def __exit__(self, *args: Any): # Call the underlying context manager's __exit__ result = self.saved_tensors_context.__exit__(*args) @@ -784,6 +865,7 @@ def __exit__(self, *args: Any): 
cpu_offload.CPUOffloadEnabled = False return result + def paged_stash_group_start(tensor, name=None): """Mark the start of a layer group and prepare for stash/reload.""" rank = torch.distributed.get_rank() @@ -792,6 +874,7 @@ def paged_stash_group_start(tensor, name=None): return tensor return PP_PreScheduleFunction.apply(tensor, stash_manager) + def get_paged_stash_context(name=None, max_num_tokens=None, num_tokens_tensor=None): """Get the paged stash context""" stash_manager = PagedStashManager.get_instance() @@ -804,6 +887,7 @@ def get_paged_stash_context(name=None, max_num_tokens=None, num_tokens_tensor=No pack_unpack_context = PagedStashContext(stash_manager) return pack_unpack_context + def paged_stash_group_commit(tensor, name=None): """Mark the end of a layer group and prepare for stash/reload.""" rank = torch.distributed.get_rank() @@ -813,6 +897,7 @@ def paged_stash_group_commit(tensor, name=None): return tensor return PP_PostScheduleFunction.apply(tensor, stash_manager) + def paged_stash_init_chunk_handler(vp_size, vp_stage): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" stash_manager = PagedStashManager.get_instance() @@ -830,6 +915,7 @@ def paged_stash_init_chunk_handler(vp_size, vp_stage): stash_manager.max_tokens_across_vp_stages = {} stash_manager.temp_tokens_across_vp_stages = {} + def paged_stash_set_last_layer(is_last_layer=False): """Set the last layer flag.""" stash_manager = PagedStashManager.get_instance() @@ -837,6 +923,7 @@ def paged_stash_set_last_layer(is_last_layer=False): return stash_manager._last_layer = is_last_layer + def paged_stash_reset(enabled=True): """Reset the chunk handler, called at the start of a training iteration.""" stash_manager = PagedStashManager.get_instance() @@ -868,8 +955,14 @@ def paged_stash_reset(enabled=True): stash_manager.overflow.zero_() stash_manager.current_layer = [1 for _ in range(stash_manager.vp_size)] stash_manager.current_microbatch = [1 for _ in range(stash_manager.vp_size)] - assert len(stash_manager.paged_tensors_to_stash) == 0, f"paged_tensors_to_stash is not empty {stash_manager.paged_tensors_to_stash}" - assert len(stash_manager.paged_tensors_stash_in_progress) == 0, f"paged_tensors_stash_in_progress is not empty {stash_manager.paged_tensors_stash_in_progress}" + assert ( + len(stash_manager.paged_tensors_to_stash) == 0 + ), f"paged_tensors_to_stash is not empty {stash_manager.paged_tensors_to_stash}" + assert len(stash_manager.paged_tensors_stash_in_progress) == 0, ( + f"paged_tensors_stash_in_progress is not empty " + f"{stash_manager.paged_tensors_stash_in_progress}" + ) + def check_paged_stash_overflow(): """Check if paged stash overflow""" diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index b742cbbbb3f..de629265168 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -10,7 +10,6 @@ from megatron.core.config import is_experimental_enabled from megatron.core.fusions.fused_indices_converter import fused_indices_to_multihot from megatron.core.fusions.fused_pad_routing_map import fused_pad_routing_map -from megatron.core.jit import jit_fuser from megatron.core.tensor_parallel import ( all_to_all, gather_from_sequence_parallel_region, @@ -1033,7 +1032,11 @@ def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor): if self.moe_expert_rank_capacity_factor is not None: pad_multiple = get_align_size_for_quantization(self.config) - 
budget = int(routing_map.shape[0] * self.config.moe_router_topk * self.moe_expert_rank_capacity_factor) + budget = int( + routing_map.shape[0] + * self.config.moe_router_topk + * self.moe_expert_rank_capacity_factor + ) budget += -budget % pad_multiple self.num_permuted_tokens = budget # Compute the capacity for each expert at the drop_and_pad mode @@ -1081,7 +1084,7 @@ def dispatch( ) ) if self.moe_expert_rank_capacity_factor is not None: - over_budget = self.handle[8] != 0 # this is overflow_flag + over_budget = self.handle[8] != 0 # this is overflow_flag self.over_budget |= over_budget if self.num_permuted_tokens is None: @@ -1471,7 +1474,7 @@ def dispatch_preprocess( routing_map, probs = self._initialize_metadata(routing_map, probs) self._comm_manager.setup_metadata(routing_map, probs) - + return hidden_states, self._comm_manager.token_probs def token_dispatch( @@ -1567,6 +1570,7 @@ def combine_postprocess(self, hidden_states: torch.Tensor): return hidden_states.view(self.hidden_shape) def check_over_budget(self): + """Check if the dispatcher has exceeded its budget.""" if hasattr(self._comm_manager, 'over_budget'): return self._comm_manager.over_budget else: diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 411a73e4681..b4e560868fa 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -26,6 +26,7 @@ from megatron.core.transformer.enums import CudaGraphScope, LayerType from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule +from megatron.core.transformer.moe.paged_stash import paged_stash_set_last_layer from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.torch_norm import LayerNormBuilder from megatron.core.transformer.transformer_config import TransformerConfig @@ -897,7 +898,7 @@ def forward( ) if self.config.moe_paged_stash: paged_stash_set_last_layer( - is_last_layer = (l_no == self.num_layers_per_pipeline_rank - 1) + is_last_layer=(l_no == self.num_layers_per_pipeline_rank - 1) ) with self.offload_context, inner_quantization_context: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 176ed2bb49a..6e2d30c6fbb 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -686,8 +686,8 @@ class TransformerConfig(ModelParallelConfig): """ moe_use_device_initiated_grouped_gemm: bool = False - """Use the cutlass grouped gemm kernel, which allows for the token_per_expert tensor on GPU. This can prevent the GPU-CPU synchronization during the grouped gemm.""" - + """Use the cutlass grouped gemm kernel, which allows for the token_per_expert tensor on GPU. + This can prevent the GPU-CPU synchronization during the grouped gemm.""" moe_use_legacy_grouped_gemm: bool = False """Use legacy GroupedMLP rather than TEGroupedMLP. @@ -769,21 +769,18 @@ class TransformerConfig(ModelParallelConfig): """Number of SMs to use for HybridEP. In pure NVL scenarios, 16 SMs can generally achieve good bandwidth.""" -<<<<<<< HEAD moe_mlp_glu_interleave_size: Optional[int] = None """When set, GLU activations in the MoE grouped MLP layer will use a block interleaved format. 
Instead of interpreting the input tensor as a concatenation of gates and linear units, it will be interpreted as alternating blocks of gates and linear units. - This data format is experimental and primarily intended to enable advanced fused kernels.""" -======= moe_expert_rank_capacity_factor: Optional[float] = None """moe_expert_rank_capacity_factor (float): The capacity factor for each expert, None means no token will be dropped. The default is None.""" ->>>>>>> f52bf7f51 (Update added arguments and add compatibility check) + ################## # Context Parallel ################## @@ -1267,8 +1264,8 @@ def __post_init__(self): if self.moe_expert_rank_capacity_factor is not None: if not self.moe_use_device_initiated_grouped_gemm: raise ValueError( - "moe_expert_rank_capacity_factor requires moe_use_device_initiated_grouped_gemm " - "to be enabled." + "moe_expert_rank_capacity_factor requires " + "moe_use_device_initiated_grouped_gemm to be enabled." ) if self.moe_flex_dispatcher_backend != "hybridep": raise ValueError( @@ -1481,18 +1478,18 @@ def __post_init__(self): assert ( not self.cpu_offloading and not self.fine_grained_activation_offloading ), "paged_stash cannot be enabled with cpu_offloading." - assert self.stash_modules is not None and len(self.stash_modules) > 0, ( - "stash_modules must be specified when moe_paged_stash is enabled." - ) + assert ( + self.stash_modules is not None and len(self.stash_modules) > 0 + ), "stash_modules must be specified when moe_paged_stash is enabled." allowed_modules = {"expert_fc1", "expert_fc2", "moe_act"} invalid_modules = set(self.stash_modules) - allowed_modules assert not invalid_modules, ( f'Invalid choices for stash_modules: {invalid_modules}. ' f'Allowed modules are: {allowed_modules}' ) - assert self.moe_expert_rank_capacity_factor is not None, ( - "moe_expert_rank_capacity_factor must be set when moe_paged_stash is enabled." - ) + assert ( + self.moe_expert_rank_capacity_factor is not None + ), "moe_expert_rank_capacity_factor must be set when moe_paged_stash is enabled." 
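        # Example of a combination that satisfies these checks (values are illustrative only):
        #   moe_paged_stash=True, stash_modules=["expert_fc1", "moe_act"],
        #   moe_expert_rank_capacity_factor=1.2, moe_use_device_initiated_grouped_gemm=True,
        #   moe_flex_dispatcher_backend="hybridep", cpu_offloading=False,
        #   fine_grained_activation_offloading=False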
# Check that no module is both stashed and offloaded if self.stash_modules and self.offload_modules: From 3186b200f2410d848e65a79e2e44b17a0a959f77 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Wed, 7 Jan 2026 18:31:13 -0800 Subject: [PATCH 34/57] Minor refactor --- megatron/core/pipeline_parallel/schedules.py | 2 + megatron/core/transformer/moe/experts.py | 4 +- megatron/core/transformer/moe/paged_stash.py | 45 +++++-------------- .../core/transformer/moe/token_dispatcher.py | 1 - 4 files changed, 14 insertions(+), 38 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 976c8e6018f..04b22bfc297 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -1550,8 +1550,10 @@ def forward_backward_helper_wrapper( send_next_wait_handle = None send_prev_wait_handle = None recv_next_wait_handles = [] + for k in range(num_warmup_microbatches): cur_model_chunk_id = get_model_chunk_id(k, forward=True) + if config.overlap_p2p_comm_warmup_flush: if ( not ( diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index c309fcd84f3..311d1951edf 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -1017,9 +1017,7 @@ def forward( self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1" ) as permuted_local_hidden_states: if self.config.moe_paged_stash: - permuted_local_hidden_states = paged_stash_group_start( - permuted_local_hidden_states, name="expert_fc1" - ) + permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states) if self.moe_paged_stash_expert_fc1: offload_context = get_paged_stash_context( name="expert_fc1", diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index daf3065bfd4..24c04d4eaee 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -540,9 +540,6 @@ def __init__(self): self.current_microbatch = None self.current_schedule_index = None - # Track max tokens needed per vp_stage, dtype, and hidden_size - self.max_tokens_per_vp_stage = None - self.temp_tokens_per_vp_stage = None # Track max tokens needed across all vp_stages grouped by dtype and hidden_size self.max_tokens_across_vp_stages = None self.temp_tokens_across_vp_stages = None @@ -688,12 +685,13 @@ def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): assert self.vp_size is not None if layer_no is None: # forward pass - layer_no = self.current_layer[vp_stage - 1] - self.current_layer[vp_stage - 1] += 1 - microbatch_no = self.current_microbatch[vp_stage - 1] + vp_stage_index = vp_stage - 1 + layer_no = self.current_layer[vp_stage_index] + self.current_layer[vp_stage_index] += 1 + microbatch_no = self.current_microbatch[vp_stage_index] if self._last_layer: - self.current_layer[vp_stage - 1] = 1 - self.current_microbatch[vp_stage - 1] += 1 + self.current_layer[vp_stage_index] = 1 + self.current_microbatch[vp_stage_index] += 1 if self.status == 'capture': self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) @@ -704,8 +702,6 @@ def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): assert actual == expected, f"schedule {actual} != {expected}" return layer_no, microbatch_no - # self._pp_schedule.append(vp_size) - # self._pp_schedule.append(vp_stage) def on_save_for_backward(self, tensor: torch.Tensor) 
-> Any: """ @@ -744,20 +740,6 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: else: hidden_size = tensor.shape[1] if tensor.ndim > 1 else tensor.numel() - if dtype not in self.temp_tokens_per_vp_stage[self.current_vp_stage]: - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype] = {} - if hidden_size not in self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype]: - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = 0 - - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][ - hidden_size - ] += self.num_tokens - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size] = max( - self.max_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], - self.temp_tokens_per_vp_stage[self.current_vp_stage][dtype][hidden_size], - ) if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 @@ -809,15 +791,13 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: if isinstance(saved_state, (PagedTensor)): if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() - self.temp_tokens_per_vp_stage[saved_state.vp_stage][saved_state.dtype][ - saved_state.hidden_size - ] -= num_tokens - self.temp_tokens_across_vp_stages[ - saved_state.dtype, saved_state.hidden_size - ] -= num_tokens # Pad the tensor to the max number of tokens npad = self.max_num_tokens - num_tokens pad = () + # check if the tensor is 2D + assert ( + saved_state._tensor.ndim == 2 + ), f"saved_state._tensor.ndim is not 2 {saved_state._tensor.ndim}" for _ in range(saved_state._tensor.ndim - 1): pad = pad + (0, 0) pad = pad + (0, npad) @@ -866,7 +846,7 @@ def __exit__(self, *args: Any): return result -def paged_stash_group_start(tensor, name=None): +def paged_stash_group_start(tensor): """Mark the start of a layer group and prepare for stash/reload.""" rank = torch.distributed.get_rank() stash_manager = PagedStashManager.get_instance() @@ -908,9 +888,6 @@ def paged_stash_init_chunk_handler(vp_size, vp_stage): stash_manager.vp_size = vp_size else: stash_manager.vp_size = 1 - if stash_manager.max_tokens_per_vp_stage is None: - stash_manager.max_tokens_per_vp_stage = [{} for _ in range(stash_manager.vp_size)] - stash_manager.temp_tokens_per_vp_stage = [{} for _ in range(stash_manager.vp_size)] if stash_manager.max_tokens_across_vp_stages is None: stash_manager.max_tokens_across_vp_stages = {} stash_manager.temp_tokens_across_vp_stages = {} diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index de629265168..284217d356c 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1474,7 +1474,6 @@ def dispatch_preprocess( routing_map, probs = self._initialize_metadata(routing_map, probs) self._comm_manager.setup_metadata(routing_map, probs) - return hidden_states, self._comm_manager.token_probs def token_dispatch( From 1ab150b6b43d29a3782d47bcee1d85faeef36633 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Thu, 8 Jan 2026 18:39:12 -0800 Subject: [PATCH 35/57] Add unit test for Paged Stashing --- .../transformer/moe/test_paged_stashing.py | 247 ++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 
tests/unit_tests/transformer/moe/test_paged_stashing.py diff --git a/tests/unit_tests/transformer/moe/test_paged_stashing.py b/tests/unit_tests/transformer/moe/test_paged_stashing.py new file mode 100644 index 00000000000..757c602a7ec --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_paged_stashing.py @@ -0,0 +1,247 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import copy +import dataclasses + +import pytest +import torch + +from megatron.core import config, parallel_state +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.moe.paged_stash import ( + paged_stash_init_chunk_handler, + paged_stash_reset, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import is_te_min_version +from megatron.training.initialize import _set_random_seed +from tests.unit_tests.test_utilities import Utils + + +def token_permutation(token_dispatcher, hidden_states, probs, indices): + residual = hidden_states + hidden_states, probs = token_dispatcher.dispatch_preprocess(hidden_states, indices, probs) + hidden_states, probs = token_dispatcher.token_dispatch(hidden_states, probs) + return hidden_states, probs, residual + + +def token_unpermutation(token_dispatcher, hidden_states): + hidden_states = token_dispatcher.token_combine(hidden_states) + hidden_states = token_dispatcher.combine_postprocess(hidden_states) + return hidden_states, None + + +class MoEModelTestContainer: + def __init__( + self, + tp_size, + ep_size, + pp_size, + cp_size=1, + moe_tp_size=None, + data_parallel_random_init=False, + num_moe_experts=8, + num_layers=1, + moe_router_topk=2, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="alltoall", + moe_expert_capacity_factor=None, + moe_pad_expert_input_to_capacity=False, + moe_aux_loss_coeff=0.1, + test_dtype=torch.float32, + **kwargs, + ): + self.num_local_experts = num_moe_experts // ep_size + self.num_layers = num_layers + self.test_dtype = test_dtype + if moe_tp_size is None: + moe_tp_size = tp_size + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + expert_model_parallel_size=ep_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + ) + _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) + local_expert_indices_offset = ( + parallel_state.get_expert_model_parallel_rank() * self.num_local_experts + ) + self.local_expert_indices = [ + local_expert_indices_offset + i for i in range(self.num_local_experts) + ] + self.config = TransformerConfig( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + pipeline_model_parallel_size=pp_size, + context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, + fp8='e4m3', + fp8_recipe='mxfp8', + fp8_wgrad=True, + fp8_amax_compute_algo='most_recent', + fp8_amax_history_len=1, + fp8_interval=1, + fp8_margin=0, + moe_router_topk=moe_router_topk, + num_moe_experts=num_moe_experts, + moe_router_load_balancing_type=moe_router_load_balancing_type, + moe_token_dispatcher_type=moe_token_dispatcher_type, + moe_expert_capacity_factor=moe_expert_capacity_factor, + moe_pad_expert_input_to_capacity=moe_pad_expert_input_to_capacity, + moe_aux_loss_coeff=moe_aux_loss_coeff, + num_layers=num_layers, + 
moe_router_dtype="fp32", + hidden_size=kwargs.get("hidden_size", 16), + num_attention_heads=kwargs.get("num_attention_heads", 8), + use_cpu_initialization=kwargs.get("use_cpu_initialization", True), + sequence_parallel=tp_size > 1, + add_bias_linear=kwargs.get("add_bias_linear", False), + moe_permute_fusion=kwargs.get("moe_permute_fusion", False), + moe_flex_dispatcher_backend=kwargs.get("moe_flex_dispatcher_backend", None), + moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False), + moe_use_device_initiated_grouped_gemm=kwargs.get( + "moe_use_device_initiated_grouped_gemm", False + ), + moe_use_legacy_grouped_gemm=kwargs.get("moe_use_legacy_grouped_gemm", False), + moe_paged_stash=kwargs.get("moe_paged_stash", False), + stash_modules=kwargs.get("stash_modules", None), + moe_expert_rank_capacity_factor=kwargs.get("moe_expert_rank_capacity_factor", None), + moe_router_padding_for_fp8=kwargs.get("moe_router_padding_for_fp8", True), + ) + # init moe layers + self.moe_layers = [self.new_moe_layer(layer_number=i) for i in range(num_layers)] + + def new_moe_layer(self, **kargs): + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=self.config.num_moe_experts, moe_grouped_gemm=True + ) + layer_number = kargs.get("layer_number", 0) + quantization_context = get_fp8_context(self.config, layer_number, is_init=True) + with quantization_context: + moe_layer = ( + MoELayer(self.config, transformer_layer_spec.submodules.mlp.submodules) + .cuda() + .to(dtype=self.test_dtype) + ) + moe_layer.set_layer_number(layer_number) + return moe_layer + + def zero_grad(self): + for moe_layer in self.moe_layers: + moe_layer.zero_grad() + + def __del__(self): + torch.distributed.barrier() + torch.cuda.synchronize() + Utils.destroy_model_parallel() + + @pytest.mark.internal + def dispatcher_dropless_test(self, inp_hidden_states=None): + moe_layers = self.moe_layers + + inp_hidden_states = inp_hidden_states.cuda() + # Permute and then unpermute data are supposed to restore original data + inp_hidden_states.requires_grad = True + hidden_states = inp_hidden_states + for i, moe_layer in enumerate(moe_layers): + quantization_context = get_fp8_context(self.config) + with quantization_context: + probs, indices = moe_layer.router(hidden_states) + probs = torch.ones_like(probs) / moe_layer.router.topk + + (dispatched_input, probs, residual) = token_permutation( + moe_layer.token_dispatcher, hidden_states, probs, indices + ) + output, _ = moe_layer.routed_experts_compute(dispatched_input, probs, residual) + output, _ = token_unpermutation(moe_layer.token_dispatcher, output) + hidden_states = output + torch.autograd.backward(output, inp_hidden_states) + return output, inp_hidden_states.grad + + def set_params(self): + # TODO: Set consistent parameters for various parallelisms. 
+        raise NotImplementedError
+
+    def destroy(self):
+        Utils.destroy_model_parallel()
+
+
+permute_fusion_params = [False]
+if is_te_min_version("2.1.0"):
+    permute_fusion_params.append(True)
+
+
+def is_deep_ep_available():
+    from megatron.core.transformer.moe.fused_a2a import HAVE_DEEP_EP
+    return HAVE_DEEP_EP
+
+
+def is_hybrid_ep_available():
+    from megatron.core.transformer.moe.fused_a2a import HAVE_HYBRIDEP
+    return HAVE_HYBRIDEP
+
+
+@pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP is not available")
+class TestFlexDispatcher:
+    def setup_method(self, method):
+        pass
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
+    def test_forward_backward(self):
+        if not is_hybrid_ep_available():
+            pytest.skip("Hybrid EP is not available")
+
+        config.ENABLE_EXPERIMENTAL = True
+
+        container = MoEModelTestContainer(
+            tp_size=1,
+            ep_size=4,
+            pp_size=1,
+            num_moe_experts=8,
+            num_layers=2,
+            moe_router_topk=2,
+            moe_router_load_balancing_type="aux_loss",
+            moe_token_dispatcher_type="flex",
+            moe_permute_fusion=True,
+            hidden_size=1024,
+            moe_flex_dispatcher_backend="hybridep",
+            test_dtype=torch.bfloat16,
+            moe_grouped_gemm=True,
+            moe_use_device_initiated_grouped_gemm=True,
+            moe_use_legacy_grouped_gemm=False,
+            moe_paged_stash=True,
+            stash_modules=["expert_fc1", "moe_act", "expert_fc2"],
+            moe_expert_rank_capacity_factor=1.5,
+        )
+        bs = 32
+        seql = 8
+
+        inp_hidden_states = torch.randn(
+            (bs, seql, container.moe_layers[0].config.hidden_size), dtype=torch.bfloat16
+        )
+        # First iteration to capture schedule, calculate capacity, etc.
+        paged_stash_reset(True)
+        paged_stash_init_chunk_handler(1, 0)
+        output_ref, inp_hidden_states_grad_ref = container.dispatcher_dropless_test(
+            inp_hidden_states
+        )
+
+        container.zero_grad()
+
+        # Second iteration to run with paged stash.
+        paged_stash_reset(True)
+        paged_stash_init_chunk_handler(1, 0)
+        output, inp_hidden_states_grad = container.dispatcher_dropless_test(inp_hidden_states)
+
+        # Verify that the output and input gradient match those from the first iteration.
+ torch.testing.assert_close(output, output_ref, atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + inp_hidden_states_grad, inp_hidden_states_grad_ref, atol=1e-4, rtol=1e-4 + ) From 4ecacac1b59f0b68ab2e615854617f8f94629d8d Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Mon, 12 Jan 2026 14:38:51 +0800 Subject: [PATCH 36/57] Initial check in of a) force load imbalance b) log overload factors --- megatron/core/transformer/moe/moe_utils.py | 240 ++++++++++++++++++ megatron/core/transformer/moe/router.py | 16 ++ .../core/transformer/transformer_config.py | 5 + megatron/training/training.py | 9 + 4 files changed, 270 insertions(+) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index dbcc25a905c..84a18853669 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -44,6 +44,9 @@ HAVE_TE = False +# MOE logging +_MOE_OVERLOAD_FACTOR_TRACKER = {} + def switch_load_balancing_loss_func( probs: torch.Tensor, tokens_per_expert: torch.Tensor, @@ -953,6 +956,221 @@ def clear_aux_losses_tracker() -> None: get_moe_metrics_tracker().clear() +def get_overload_factor_tracker(): + """Return the overload factor tracker.""" + global _MOE_OVERLOAD_FACTOR_TRACKER + return _MOE_OVERLOAD_FACTOR_TRACKER + + +class SaveOverloadFactorFunction(torch.autograd.Function): + """Autograd function to save overload factor data for forward and backward passes.""" + + @staticmethod + def forward(ctx, tensor, routing_map, layer_number, num_local_experts): + """Forward pass: save overload factor data with 'fwd' label. + + Args: + tensor: A tensor in the autograd graph (e.g., probs) to pass through. + routing_map: The routing map tensor [num_tokens, num_experts]. + layer_number: Layer index (1-based). + num_local_experts: Number of experts per EP rank. + + Returns: + tensor unchanged (pass-through). + """ + if layer_number is None: + return tensor + + # Compute local tokens per expert + local_tokens_per_expert = routing_map.sum(dim=0).detach().float() + + # Group by EP rank: reshape [num_experts] -> [ep_size, num_local_experts] and sum + num_experts = local_tokens_per_expert.shape[0] + ep_size = num_experts // num_local_experts + local_tokens_per_ep_rank = local_tokens_per_expert.view(ep_size, num_local_experts).sum(dim=1) + + # Save to tracker + tracker = get_overload_factor_tracker() + if "fwd" not in tracker: + tracker["fwd"] = {} + if "fwd_bwd" not in tracker: + tracker["fwd_bwd"] = [] + + layer_idx = layer_number - 1 # Convert to 0-based index + if layer_idx not in tracker["fwd"]: + tracker["fwd"][layer_idx] = [] + + tracker["fwd"][layer_idx].append(local_tokens_per_ep_rank) + tracker["fwd_bwd"].append(local_tokens_per_ep_rank) + + # Save for backward + ctx.save_for_backward(local_tokens_per_ep_rank) + + return tensor + + @staticmethod + def backward(ctx, grad_output): + """Backward pass: append negated tokens to fwd_bwd tracker.""" + if ctx.saved_tensors: + (local_tokens_per_ep_rank,) = ctx.saved_tensors + tracker = get_overload_factor_tracker() + if "fwd_bwd" in tracker: + tracker["fwd_bwd"].append(-local_tokens_per_ep_rank) + return grad_output, None, None, None + + +def save_overload_factor_to_tracker( + tensor: torch.Tensor, + routing_map: torch.Tensor, + layer_number: int, + num_local_experts: int, + tp_ep_group: torch.distributed.ProcessGroup, + dp_group: torch.distributed.ProcessGroup, +): + """Save local tokens per EP rank and track token counts for activation memory analysis. 
+
+    This function wraps SaveOverloadFactorFunction to:
+    1. Store data for overload factor computation (done in get_overload_factors_for_logging())
+    2. Track tokens in forward/backward for activation memory analysis
+
+    Args:
+        tensor: A tensor in the autograd graph (e.g., probs) - passed through unchanged.
+        routing_map: The routing map tensor [num_tokens, num_experts].
+        layer_number: Layer index (1-based).
+        num_local_experts: Number of experts per EP rank.
+        tp_ep_group: The TP x EP group for all-reducing.
+        dp_group: The DP group for max/avg reduction.
+
+    Returns:
+        tensor unchanged.
+    """
+    # Set comm groups in tracker (outside autograd function)
+    tracker = get_overload_factor_tracker()
+    if "tp_ep_group" not in tracker:
+        tracker["tp_ep_group"] = tp_ep_group
+        tracker["dp_group"] = dp_group
+
+    return SaveOverloadFactorFunction.apply(tensor, routing_map, layer_number, num_local_experts)
+
+
+def get_overload_factors_for_logging() -> dict:
+    """Compute overload factors from stored data and return organized metrics for logging.
+
+    This function performs:
+    1. All-reduce over TP x EP to get tokens per EP rank within each DP rank
+    2. MAX reduction across DP ranks to get max tokens per EP rank
+    3. AVG reduction across DP ranks to get avg tokens per EP rank
+    4. Computes overload_factor = max_tokens / avg_tokens
+
+    Should be called outside the critical path (e.g., during logging).
+
+    Returns:
+        dict: A dictionary with structure:
+            {
+                "avg_overload_factor": float,
+                "max_overload_factor": float,
+                "max_cum_overload_factor": float (peak cumulative-token overload factor from the fwd/bwd cumsum),
+            }
+    """
+    tracker = get_overload_factor_tracker()
+    if "fwd" not in tracker or not tracker["fwd"]:
+        return {}
+    tp_ep_group = tracker.get("tp_ep_group")
+    dp_group = tracker.get("dp_group")
+
+    # Collect fwd tensors for overload factor calculation
+    fwd_tensors = []
+    layer_indices = []  # layer_idx for each fwd tensor
+
+    for layer_idx in sorted(tracker["fwd"].keys()):
+        for local_tokens_per_ep_rank in tracker["fwd"][layer_idx]:
+            fwd_tensors.append(local_tokens_per_ep_rank)
+            layer_indices.append(layer_idx)
+
+    if not fwd_tensors:
+        return {}
+
+    # Stack fwd_bwd tensors (already has fwd positive, bwd negative)
+    fwd_bwd_tensors = tracker.get("fwd_bwd", [])
+    fwd_bwd_tensors_stacked = torch.stack(fwd_bwd_tensors, dim=0) if fwd_bwd_tensors else None
+
+    # All-reduce across tp_ep_group, cumsum, and find max
+    max_cum_overload_factor = None
+    if fwd_bwd_tensors_stacked is not None:
+        if tp_ep_group is not None:
+            torch.distributed.all_reduce(fwd_bwd_tensors_stacked, group=tp_ep_group)
+        cumsum_tokens = fwd_bwd_tensors_stacked.cumsum(dim=0)
+        max_cumsum_tokens = cumsum_tokens.max().item()
+
+        # Calculate max_cum_overload_factor = max_cumsum_tokens / cumsum_tokens.mean(dim=1).max()
+        mean_cumsum_max = cumsum_tokens.mean(dim=1).max()
+        local_max_cum_overload_factor = max_cumsum_tokens / (mean_cumsum_max.item() + 1e-8)
+
+        # Max all-reduce to find global max across DP ranks
+        if dp_group is not None:
+            cum_overload_tensor = torch.tensor(
+                [local_max_cum_overload_factor], device=fwd_bwd_tensors_stacked.device
+            )
+            torch.distributed.all_reduce(cum_overload_tensor, group=dp_group, op=torch.distributed.ReduceOp.MAX)
+            max_cum_overload_factor = cum_overload_tensor.item()
+        else:
+            max_cum_overload_factor = local_max_cum_overload_factor
+    all_tensors = fwd_tensors
+    # Stack all tensors and do all-reduce over TP x EP
+    # Shape: [num_entries, ep_size]
+    stacked_tensors = torch.stack(all_tensors, dim=0)
+    if tp_ep_group is not None:
+
torch.distributed.all_reduce(stacked_tensors, group=tp_ep_group) + + # Now reduce across DP ranks: get both MAX and AVG + if dp_group is not None: + # Clone for max reduction + max_tokens_per_ep_rank = stacked_tensors.clone() + torch.distributed.all_reduce( + max_tokens_per_ep_rank, group=dp_group, op=torch.distributed.ReduceOp.MAX + ) + # AVG reduction for mean tokens + avg_tokens_per_ep_rank = stacked_tensors.clone() + torch.distributed.all_reduce( + avg_tokens_per_ep_rank, group=dp_group, op=torch.distributed.ReduceOp.AVG + ) + else: + max_tokens_per_ep_rank = stacked_tensors + avg_tokens_per_ep_rank = stacked_tensors + + # Compute two overload factors for each entry: + # 1. avg_overload_factor = max(avg_tokens) / mean(avg_tokens) - based on AVG across DP + # 2. max_overload_factor = max(max_tokens) / mean(max_tokens) - based on MAX across DP + avg_max_tokens = avg_tokens_per_ep_rank.max(dim=1).values # [num_entries] + avg_mean_tokens = avg_tokens_per_ep_rank.float().mean(dim=1) # [num_entries] + avg_overload_factors = (avg_max_tokens / (avg_mean_tokens + 1e-8)) # [num_entries] + + max_max_tokens = max_tokens_per_ep_rank.max(dim=1).values # [num_entries] + max_mean_tokens = max_tokens_per_ep_rank.float().mean(dim=1) # [num_entries] + max_overload_factors = (max_max_tokens / (max_mean_tokens + 1e-8)) # [num_entries] + + # Compute overall statistics + # avg_overload_factor uses mean, max_overload_factor uses max + result = { + "avg_overload_factor": avg_overload_factors.mean().item(), + "max_overload_factor": max_overload_factors.max().item(), + "max_cum_overload_factor": max_cum_overload_factor, + } + + return result + + +def clear_overload_factor_tracker(): + """Clear the overload factor tracker.""" + tracker = get_overload_factor_tracker() + if "fwd" in tracker: + tracker["fwd"].clear() + if "fwd_bwd" in tracker: + tracker["fwd_bwd"].clear() + tracker.pop("tp_ep_group", None) + tracker.pop("dp_group", None) + + @deprecated( version="0.16", removal_version="0.18", alternative="get_moe_metrics_tracker()._sync_metrics()" ) @@ -995,6 +1213,7 @@ def track_moe_metrics( writer: Optional["SummaryWriter"] = None, wandb_writer: Optional["wandb.Run"] = None, total_loss_dict: Optional[dict[str, torch.Tensor]] = None, + overload_dict=None, per_layer_logging: bool = False, force_initialize: bool = False, track_names: Optional[List[str]] = None, @@ -1007,6 +1226,27 @@ def track_moe_metrics( Deprecated: Use get_moe_metrics_tracker().report() directly. 
""" + # Log overload factor metrics + overload_metrics = get_overload_factors_for_logging() + if overload_metrics: + if overload_dict is not None: + overload_dict.update(overload_metrics) + if writer is not None: + if "avg_overload_factor" in overload_metrics: + writer.add_scalar("moe/avg_overload_factor", overload_metrics["avg_overload_factor"], iteration) + if "max_overload_factor" in overload_metrics: + writer.add_scalar("moe/max_overload_factor", overload_metrics["max_overload_factor"], iteration) + if "max_cum_overload_factor" in overload_metrics and overload_metrics["max_cum_overload_factor"] is not None: + writer.add_scalar("moe/max_cum_overload_factor", overload_metrics["max_cum_overload_factor"], iteration) + if wandb_writer: + if "avg_overload_factor" in overload_metrics: + wandb_writer.log({"moe/avg_overload_factor": overload_metrics["avg_overload_factor"]}, iteration) + if "max_overload_factor" in overload_metrics: + wandb_writer.log({"moe/max_overload_factor": overload_metrics["max_overload_factor"]}, iteration) + if "max_cum_overload_factor" in overload_metrics and overload_metrics["max_cum_overload_factor"] is not None: + wandb_writer.log({"moe/max_cum_overload_factor": overload_metrics["max_cum_overload_factor"]}, iteration) + clear_overload_factor_tracker() + return get_moe_metrics_tracker().report( loss_scale=loss_scale, iteration=iteration, diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index c9a2a469531..df8fe018b45 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -17,6 +17,7 @@ compute_routing_scores_for_aux_loss, get_tokens_per_expert_and_token_count, router_gating_linear, + save_overload_factor_to_tracker, sinkhorn, switch_load_balancing_loss_func, topk_routing_with_score_function, @@ -53,6 +54,8 @@ def __init__( self.cp_group = pg_collection.cp self.tp_cp_group = pg_collection.tp_cp self.tp_dp_cp_group = pg_collection.tp_dp_cp + self.tp_ep_group = pg_collection.tp_ep + self.dp_group = pg_collection.dp # Initialize the gate weights. # TODO: Add support for GPU initialization, which requires updating the golden values. @@ -704,6 +707,19 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No ) probs, routing_map = self.routing(logits, padding_mask=padding_mask) + # Log overload factor if enabled + if self.config.log_overload_factor: + # Compute num_local_experts from config and EP size + ep_size = self.tp_ep_group.size() // self.tp_group.size() + num_local_experts = self.config.num_moe_experts // ep_size + probs = save_overload_factor_to_tracker( + tensor=probs, + routing_map=routing_map, + layer_number=self.layer_number, + num_local_experts=num_local_experts, + tp_ep_group=self.tp_ep_group, + dp_group=self.dp_group, + ) return probs, routing_map diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 6e2d30c6fbb..19fc7844175 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -679,6 +679,11 @@ class TransformerConfig(ModelParallelConfig): If negative, generates bias once per layer and reuses it (abs value is std). 
This is an experimental feature for benchmarking purposes.""" + log_overload_factor: bool = False + """If True, log MoE router overload factors: avg_overload_factor, max_overload_factor + (load imbalance across EP ranks), and max_cum_overload_factor (peak cumulative tokens + ratio for forward/backward memory analysis).""" + moe_grouped_gemm: bool = False """When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped diff --git a/megatron/training/training.py b/megatron/training/training.py index 454b84d274a..9ba8ac4d295 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2044,6 +2044,7 @@ def training_log( wandb_writer.log({'max_attention_logit': max_attention_logit}, iteration) # Log MoE metrics. moe_log_string = "" + overload_dict = {} if args.num_experts is not None: moe_loss_scale = 1 / get_num_microbatches() track_names = [] @@ -2066,6 +2067,7 @@ def training_log( iteration=iteration, writer=writer, wandb_writer=wandb_writer, + overload_dict=overload_dict, per_layer_logging=args.moe_per_layer_logging, force_initialize=True, track_names=track_names, @@ -2175,6 +2177,13 @@ def training_log( total_loss_dict[advanced_iters_key] = 0 total_loss_dict[skipped_iters_key] = 0 total_loss_dict[nan_iters_key] = 0 + if overload_dict: + if "avg_overload_factor" in overload_dict: + log_string += f' avg overload factor: {overload_dict["avg_overload_factor"]:.3f} |' + if "max_overload_factor" in overload_dict: + log_string += f' max overload factor: {overload_dict["max_overload_factor"]:.3f} |' + if "max_cum_overload_factor" in overload_dict and overload_dict["max_cum_overload_factor"] is not None: + log_string += f' max cum overload factor: {overload_dict["max_cum_overload_factor"]:.3f} |' print_rank_last(log_string) reported_memory_in_this_iteration = False if report_memory_flag: From e650d6084c9daf16fccd48f576a830f36e386b55 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Mon, 19 Jan 2026 15:38:58 +0800 Subject: [PATCH 37/57] make overload factor logging work for cuda graph --- megatron/core/transformer/moe/moe_utils.py | 17 ++++++++++------- megatron/core/transformer/moe/router.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 84a18853669..edefe67356c 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -1046,6 +1046,15 @@ def save_overload_factor_to_tracker( """ # Set comm groups in tracker (outside autograd function) tracker = get_overload_factor_tracker() + if "to_clear" in tracker and tracker["to_clear"]: + if "fwd" in tracker: + tracker["fwd"].clear() + if "fwd_bwd" in tracker: + tracker["fwd_bwd"].clear() + tracker.pop("tp_ep_group", None) + tracker.pop("dp_group", None) + tracker.pop("to_clear", None) + if "tp_ep_group" not in tracker: tracker["tp_ep_group"] = tp_ep_group tracker["dp_group"] = dp_group @@ -1093,7 +1102,6 @@ def get_overload_factors_for_logging() -> dict: # Stack fwd_bwd tensors (already has fwd positive, bwd negative) fwd_bwd_tensors = tracker.get("fwd_bwd", []) fwd_bwd_tensors_stacked = torch.stack(fwd_bwd_tensors, dim=0) if fwd_bwd_tensors else None - # All-reduce across tp_ep_group, cumsum, and find max max_cum_overload_factor = None if fwd_bwd_tensors_stacked is not None: @@ -1163,12 +1171,7 @@ def get_overload_factors_for_logging() -> dict: def 
clear_overload_factor_tracker(): """Clear the overload factor tracker.""" tracker = get_overload_factor_tracker() - if "fwd" in tracker: - tracker["fwd"].clear() - if "fwd_bwd" in tracker: - tracker["fwd_bwd"].clear() - tracker.pop("tp_ep_group", None) - tracker.pop("dp_group", None) + tracker["to_clear"] = True @deprecated( diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index df8fe018b45..af8e08d755f 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -706,7 +706,7 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No logits, self.config.moe_router_force_biased, self.layer_number ) - probs, routing_map = self.routing(logits, padding_mask=padding_mask) + probs, routing_map = self.routing(logits) # Log overload factor if enabled if self.config.log_overload_factor: # Compute num_local_experts from config and EP size From a792a435d33c3cad560a1ab4a9406d38c1f8bf83 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Thu, 22 Jan 2026 13:35:23 +0800 Subject: [PATCH 38/57] 1. allocate stashing buffer based on avg token count if STASH_BUFFER_SIZE_FACTOR is positive. 2. fix int32 overflow in some triton kernels when token count is large 3. fix a problem where restored activation might get deallocate prematurely --- megatron/core/fusions/fused_bias_swiglu.py | 28 ++++--- megatron/core/transformer/moe/experts.py | 28 ++++++- megatron/core/transformer/moe/paged_stash.py | 83 ++++++++++++++++++-- 3 files changed, 117 insertions(+), 22 deletions(-) diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py index b15081343f9..d50164ea59b 100644 --- a/megatron/core/fusions/fused_bias_swiglu.py +++ b/megatron/core/fusions/fused_bias_swiglu.py @@ -311,16 +311,17 @@ def _weighted_swiglu_fwd_kernel( # Strided access: each block handles tokens [pid, pid+num_blocks, ...] 
token_idx = pid while token_idx < num_tokens: + token_idx_i64 = token_idx.to(tl.int64) # Load weight for this token - weight = tl.load(weights_ptr + token_idx) + weight = tl.load(weights_ptr + token_idx_i64) # Process hidden dimension for h_offset in range(0, hidden_size, BLOCK_SIZE): h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size # Load input chunks (gate and value) - input_offset_1 = token_idx * (hidden_size * 2) + h_offset - input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset + input_offset_1 = token_idx_i64 * (hidden_size * 2) + h_offset + input_offset_2 = token_idx_i64 * (hidden_size * 2) + hidden_size + h_offset y1 = tl.load( input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 @@ -341,7 +342,7 @@ def _weighted_swiglu_fwd_kernel( result = silu_y1 * y2_fp32 * weight_fp32 # Store output (cast back to original dtype) - output_offset = token_idx * hidden_size + h_offset + output_offset = token_idx_i64 * hidden_size + h_offset tl.store( output_ptr + output_offset + tl.arange(0, BLOCK_SIZE), result.to(y1.dtype), @@ -376,8 +377,9 @@ def _weighted_swiglu_bwd_kernel( # Strided access token_idx = pid while token_idx < num_tokens: + token_idx_i64 = token_idx.to(tl.int64) # Load weight for this token - weight = tl.load(weights_ptr + token_idx) + weight = tl.load(weights_ptr + token_idx_i64) # Accumulator for weight gradient (fp32 for precision) weight_grad_acc = 0.0 @@ -387,14 +389,14 @@ def _weighted_swiglu_bwd_kernel( h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size # Load grad_output - grad_out_offset = token_idx * hidden_size + h_offset + grad_out_offset = token_idx_i64 * hidden_size + h_offset grad_out = tl.load( grad_output_ptr + grad_out_offset + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 ) # Load input chunks - input_offset_1 = token_idx * (hidden_size * 2) + h_offset - input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset + input_offset_1 = token_idx_i64 * (hidden_size * 2) + h_offset + input_offset_2 = token_idx_i64 * (hidden_size * 2) + hidden_size + h_offset y1 = tl.load( input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 @@ -439,7 +441,7 @@ def _weighted_swiglu_bwd_kernel( weight_grad_acc += tl.sum(weight_grad_contribution) # Store weight gradient after processing all chunks - tl.store(grad_weights_ptr + token_idx, weight_grad_acc) + tl.store(grad_weights_ptr + token_idx_i64, weight_grad_acc) # Stride to next token token_idx += num_blocks @@ -471,9 +473,13 @@ def weighted_swiglu_triton(input, weights, num_tokens_tensor): grid = (num_blocks,) _weighted_swiglu_fwd_kernel[grid]( - input, weights, output, num_tokens_tensor, hidden_size=hidden_size, BLOCK_SIZE=BLOCK_SIZE + input, + weights, + output, + num_tokens_tensor, + hidden_size=hidden_size, + BLOCK_SIZE=BLOCK_SIZE, ) - return output diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 311d1951edf..c317407514f 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -1019,10 +1019,18 @@ def forward( if self.config.moe_paged_stash: permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states) if self.moe_paged_stash_expert_fc1: + max_num_tokens = permuted_local_hidden_states.shape[0] + # Average/expected tokens is a pre-padding estimate used by paged stashing heuristics. + # moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled. 
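+                # Hypothetical illustration (numbers not from this change): with a capacity
+                # factor of 1.5 and a padded budget of 12288 tokens, avg_num_tokens comes out
+                # to 8192, i.e. the stash is sized for the expected load, not the worst case.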
+ cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None + ) offload_context = get_paged_stash_context( name="expert_fc1", - max_num_tokens=permuted_local_hidden_states.shape[0], + max_num_tokens=max_num_tokens, num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, ) else: offload_context = nullcontext() @@ -1137,10 +1145,18 @@ def glu(x): else: with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: if self.moe_paged_stash_moe_act: + max_num_tokens = fc1_output.shape[0] + cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) + if cap_factor is not None and cap_factor > 0 + else None + ) offload_context = get_paged_stash_context( name="moe_act", - max_num_tokens=fc1_output.shape[0], + max_num_tokens=max_num_tokens, num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, ) else: offload_context = nullcontext() @@ -1152,10 +1168,16 @@ def glu(x): ) if self.moe_paged_stash_expert_fc2: + max_num_tokens = bias_act_output.shape[0] + cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None + ) offload_context = get_paged_stash_context( name="expert_fc2", - max_num_tokens=bias_act_output.shape[0], + max_num_tokens=max_num_tokens, num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, ) else: offload_context = nullcontext() diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 24c04d4eaee..8f00b5ecf5c 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -145,8 +145,11 @@ def _paged_stash_copy_kernel( need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 num_iters = elements_per_thread + (1 if need_mask else 0) - src_base = src_ptr + token_idx * HIDDEN_SIZE - dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE + # Use int64 for address math to avoid int32 overflow when indices get large. + token_idx_i64 = token_idx.to(tl.int64) + dst_token_idx_i64 = dst_token_idx.to(tl.int64) + src_base = src_ptr + token_idx_i64 * HIDDEN_SIZE + dst_base = dst_ptr + dst_token_idx_i64 * HIDDEN_SIZE if need_mask: for iter in range(num_iters): @@ -219,8 +222,11 @@ def _paged_stash_pop_kernel( need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 num_iters = elements_per_thread + (1 if need_mask else 0) - src_base = src_ptr + src_token_idx * HIDDEN_SIZE - dst_base = dst_ptr + token_idx * HIDDEN_SIZE + # Use int64 for address math to avoid int32 overflow when indices get large. 
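+    # (A signed 32-bit offset overflows once token_index * HIDDEN_SIZE exceeds 2**31 - 1,
+    # roughly 2.1e9 elements, which large token budgets times the hidden size can reach.)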
+ src_token_idx_i64 = src_token_idx.to(tl.int64) + token_idx_i64 = token_idx.to(tl.int64) + src_base = src_ptr + src_token_idx_i64 * HIDDEN_SIZE + dst_base = dst_ptr + token_idx_i64 * HIDDEN_SIZE if need_mask: for iter in range(num_iters): @@ -261,6 +267,7 @@ def __init__( self, tensor, num_tokens_tensor=None, + avg_num_tokens: int = None, vp_stage=None, schedule_layer_no=None, layer_name=None, @@ -284,6 +291,7 @@ def __init__( and num_tokens_tensor.numel() == 1 ) self.num_tokens_tensor = num_tokens_tensor.clone() + self.avg_num_tokens = avg_num_tokens self.vp_stage = vp_stage self.schedule_layer_no = schedule_layer_no self.layer_name = layer_name @@ -517,7 +525,7 @@ def __init__(self): """Initialize the manager with queues and dedicated CUDA streams.""" # allocate streams and events for synchronization self.enabled = False - self._pack_stream = torch.cuda.Stream() + self._pack_stream = torch.cuda.current_stream()#torch.cuda.Stream() # Currently paged stashing is not stream-safe, so use the same stream for packing # and unpacking self._unpack_stream = self._pack_stream @@ -543,9 +551,14 @@ def __init__(self): # Track max tokens needed across all vp_stages grouped by dtype and hidden_size self.max_tokens_across_vp_stages = None self.temp_tokens_across_vp_stages = None + # Track max tokens computed from avg_num_tokens (heuristic) across all vp_stages + self.max_avg_tokens_across_vp_stages = None + self.temp_avg_tokens_across_vp_stages = None self.num_tokens_tensor = None self.max_num_tokens = None + # Optional hint: expected/average number of tokens (e.g., pre-padding estimate) + self.avg_num_tokens = None self.stash_buffers = None self.overflow = None self.device = None @@ -663,12 +676,28 @@ def allocate_stash_buffers(self, stash_buffer_size_factor=1.10): self.stash_buffers = {} self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) - for dtype, hidden_size in self.max_tokens_across_vp_stages: + # stash_buffer_size_factor controls both which sizing signal to use and how much headroom + # to allocate: + # - positive: size based on avg_num_tokens-derived maxima + # - negative: size based on actual num_tokens-derived maxima (legacy behavior) + # In both cases we scale by abs(stash_buffer_size_factor). + if stash_buffer_size_factor >= 0: + max_tokens_dict = self.max_avg_tokens_across_vp_stages + scale = stash_buffer_size_factor + else: + max_tokens_dict = self.max_tokens_across_vp_stages + scale = -stash_buffer_size_factor + + # Fallback safety: if avg-based dict is not available/populated yet, use actual-max dict. 
+ if not max_tokens_dict: + max_tokens_dict = self.max_tokens_across_vp_stages + + for dtype, hidden_size in max_tokens_dict: if dtype not in self.stash_buffers: self.stash_buffers[dtype] = {} assert hidden_size not in self.stash_buffers[dtype] num_tokens = int( - self.max_tokens_across_vp_stages[dtype, hidden_size] * stash_buffer_size_factor + max_tokens_dict[dtype, hidden_size] * scale ) self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype @@ -721,9 +750,13 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: tensor._rowwise_data is None ), f"rowwise_data is not None; Only columnwise data is supported for paged stashing" + avg_num_tokens = None if self.status == 'capture': self.num_tokens = self.num_tokens_tensor.item() + avg_num_tokens = ( + int(self.avg_num_tokens) if self.avg_num_tokens is not None else None + ) dtype = ( tensor.dtype @@ -743,12 +776,22 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 + self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 self.temp_tokens_across_vp_stages[dtype, hidden_size] += self.num_tokens self.max_tokens_across_vp_stages[dtype, hidden_size] = max( self.max_tokens_across_vp_stages[dtype, hidden_size], self.temp_tokens_across_vp_stages[dtype, hidden_size], ) + + # Track avg tokens across vp stages (if provided) using the same accumulation model. + if avg_num_tokens is not None: + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] += avg_num_tokens + self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = max( + self.max_avg_tokens_across_vp_stages[dtype, hidden_size], + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size], + ) # Since capture stage does not use CUDA graph, we can truncate # the saved tensor to actual num_tokens new_size = (self.num_tokens, *tensor.shape[1:]) @@ -767,6 +810,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: paged_tensor = PagedTensor( tensor, num_tokens_tensor=self.num_tokens_tensor, + avg_num_tokens=avg_num_tokens, vp_stage=self.current_vp_stage, schedule_layer_no=( self._pp_schedule[self.current_schedule_index] @@ -791,6 +835,14 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: if isinstance(saved_state, (PagedTensor)): if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() + key = (saved_state.dtype, saved_state.hidden_size) + if key in self.temp_tokens_across_vp_stages: + self.temp_tokens_across_vp_stages[key] -= num_tokens + if ( + saved_state.avg_num_tokens is not None + and key in self.temp_avg_tokens_across_vp_stages + ): + self.temp_avg_tokens_across_vp_stages[key] -= int(saved_state.avg_num_tokens) # Pad the tensor to the max number of tokens npad = self.max_num_tokens - num_tokens pad = () @@ -811,6 +863,13 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: assert ( saved_state._tensor is not None ), f"saved_state._tensor is None {saved_state._tensor}" + + # Record cross-stream usage (important when tensor was produced on another stream). 
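+            # Without record_stream, freeing this tensor could let the caching allocator hand
+            # its memory to another allocation while the current stream is still reading it.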
+ if isinstance(saved_state._tensor, MXFP8Tensor): + saved_state._tensor._columnwise_data.record_stream(torch.cuda.current_stream()) + elif isinstance(saved_state._tensor, torch.Tensor) and saved_state._tensor.is_cuda: + saved_state._tensor.record_stream(torch.cuda.current_stream()) + return saved_state._tensor return saved_state @@ -855,12 +914,18 @@ def paged_stash_group_start(tensor): return PP_PreScheduleFunction.apply(tensor, stash_manager) -def get_paged_stash_context(name=None, max_num_tokens=None, num_tokens_tensor=None): +def get_paged_stash_context( + name=None, + max_num_tokens=None, + num_tokens_tensor=None, + avg_num_tokens=None, +): """Get the paged stash context""" stash_manager = PagedStashManager.get_instance() if not stash_manager.enabled: return nullcontext() stash_manager.max_num_tokens = max_num_tokens + stash_manager.avg_num_tokens = avg_num_tokens assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor) stash_manager.num_tokens_tensor = num_tokens_tensor stash_manager.set_current_layer_name(name) if name is not None else None @@ -891,6 +956,8 @@ def paged_stash_init_chunk_handler(vp_size, vp_stage): if stash_manager.max_tokens_across_vp_stages is None: stash_manager.max_tokens_across_vp_stages = {} stash_manager.temp_tokens_across_vp_stages = {} + stash_manager.max_avg_tokens_across_vp_stages = {} + stash_manager.temp_avg_tokens_across_vp_stages = {} def paged_stash_set_last_layer(is_last_layer=False): From 57b97143f08f00173ac83e85ec33b21058a3d59f Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Fri, 23 Jan 2026 11:25:41 +0800 Subject: [PATCH 39/57] Reenable overlapping of stashing kernels --- megatron/core/transformer/moe/paged_stash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 8f00b5ecf5c..54887d73336 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -525,7 +525,7 @@ def __init__(self): """Initialize the manager with queues and dedicated CUDA streams.""" # allocate streams and events for synchronization self.enabled = False - self._pack_stream = torch.cuda.current_stream()#torch.cuda.Stream() + self._pack_stream = torch.cuda.Stream() # Currently paged stashing is not stream-safe, so use the same stream for packing # and unpacking self._unpack_stream = self._pack_stream From 639509dac6c8b94a0ae45e6384bce1281f02834b Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Tue, 3 Feb 2026 19:30:35 +0800 Subject: [PATCH 40/57] Remove a buggy/redundant reset --- megatron/core/full_cuda_graph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 28836b10b2a..ab72ac2d472 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -189,7 +189,6 @@ def __call__(self, *args, **kwargs): torch.cuda.synchronize() torch.distributed.barrier() logger.info(f'CUDA graph capture done for {training_str}!!!') - paged_stash_reset(enabled=self.moe_paged_stash and training) if FullCudaGraphWrapper.cuda_graph[training_str] is None: FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: From 5a0267f01217d56f45de12c07d5ba2c6db239844 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Mon, 9 Feb 2026 14:27:32 -0800 Subject: [PATCH 41/57] Cleanup moe-expert-rank-capacity-factor argument. 
--- megatron/training/arguments.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5627c491a0e..db86e9099ae 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1646,8 +1646,6 @@ def _add_inference_args(parser): group.add_argument('--use-legacy-static-engine', action='store_true', default=False, help='Use legacy static engine. (Current static engine uses dynamic engine under the hood)', dest='use_legacy_static_engine') - group.add_argument('--moe-expert-rank-capacity-factor', type=float, default=None, - help='The capacity factor for each EP rank when packed offloading is enabled.') group.add_argument('--inference-max-requests', type=int, default=8, help='Maximum number of requests for inference.', dest='inference_max_requests') From b8ee0e778c4c226c1617f83cc3e31d448a39f022 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Fri, 20 Feb 2026 20:14:32 -0800 Subject: [PATCH 42/57] Update moe_use_device_initiated_grouped_gemm check for paged stashing to use_transformer_engine_op_fuser Enforce Router padding for paged stashing Initial commit to enable paged stashing for TE fused op Enable stashing for 1D shape, colwise_scale_inv tensors Use moe_paged_stash to enable/disable stashing with fused op Use use_transformer_engine_op_fuser to enable/disable fused op Dynamic-shape no-stashing fallback for non-CG Dynamic-shape no-stashing fallback + Full CG Eliminate sync in mtp loss cal enable 1f1b overlap Add overflow check back temporarily before changes for PagedStashRunner is ready nanz/megatron-lm!1 - Paged stashing fallback --- megatron/core/full_cuda_graph.py | 24 +- megatron/core/transformer/moe/experts.py | 78 +++- megatron/core/transformer/moe/paged_stash.py | 386 ++++++++++++------ .../core/transformer/moe/token_dispatcher.py | 2 + .../transformer/multi_token_prediction.py | 5 +- .../core/transformer/transformer_block.py | 9 - .../core/transformer/transformer_config.py | 37 +- 7 files changed, 373 insertions(+), 168 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index ab72ac2d472..06144ed6e44 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -4,10 +4,10 @@ import logging +import gc import torch from megatron.core.tensor_parallel.random import get_all_rng_states -from megatron.core.transformer.moe.paged_stash import check_paged_stash_overflow, paged_stash_reset logger = logging.getLogger(__name__) @@ -190,11 +190,11 @@ def __call__(self, *args, **kwargs): torch.distributed.barrier() logger.info(f'CUDA graph capture done for {training_str}!!!') if FullCudaGraphWrapper.cuda_graph[training_str] is None: - FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) + # do it in a side stream + with torch.cuda.stream(torch.cuda.Stream()): + FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: FullCudaGraphWrapper.cuda_graph[training_str].replay() - check_paged_stash_overflow() - self.speculative_cuda_graph_check(model) self.next_iter(training_str) return FullCudaGraphWrapper.result[training_str] @@ -220,3 +220,19 @@ def curr_iter(self, stage): def next_iter(self, stage): """Increment current training/validation iteration.""" FullCudaGraphWrapper.curr_iteration[stage] += 1 + + def reset_cuda_graph(self, stage=None): + """Reset CUDA graph.""" + if stage is None or stage == 'training': + if FullCudaGraphWrapper.cuda_graph['training'] is not 
None: + del FullCudaGraphWrapper.cuda_graph['training'] + FullCudaGraphWrapper.cuda_graph['training'] = None + FullCudaGraphWrapper.result['training'] = None + FullCudaGraphWrapper.curr_iteration['training'] = 0 + if stage is None or stage == 'validation': + if FullCudaGraphWrapper.cuda_graph['validation'] is not None: + del FullCudaGraphWrapper.cuda_graph['validation'] + FullCudaGraphWrapper.cuda_graph['validation'] = None + FullCudaGraphWrapper.result['validation'] = None + FullCudaGraphWrapper.curr_iteration['validation'] = 0 + gc.collect() \ No newline at end of file diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index c317407514f..7d7d67a575e 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -540,6 +540,36 @@ def backward_dw(self): """ pass +def _pad_unpad(inp, pad): + if pad: + if inp.ndim == 2: + result = torch.nn.functional.pad(inp, (0, 0, 0, 256)) + elif inp.ndim == 1: + result = torch.nn.functional.pad(inp, (0, 256)) + else: + raise ValueError(f"Input dimension {inp.ndim} not supported") + else: + if inp.ndim == 2: + result = inp[:-256, :] + elif inp.ndim == 1: + result = inp[:-256] + else: + raise ValueError(f"Input dimension {inp.ndim} not supported") + assert result.shape[0] == 0 + return result + + +class PadUnpadFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp, pad): + result = _pad_unpad(inp, pad) + ctx.pad = not pad + return result + + @staticmethod + def backward(ctx, grad_out): + return _pad_unpad(grad_out, ctx.pad), None + class GroupedLinearFc1Interface(Protocol): """Interface for linear_fc1 module in TEGroupedMLP.""" @@ -806,6 +836,8 @@ def _is_fused_impl_supported(self) -> bool: if self.activation_func != F.silu or not self.config.gated_linear_unit: return False # Expected SwiGLU activation + if not self.config.use_transformer_engine_op_fuser: + return False return True def _make_fused_ops(self) -> torch.nn.Module: @@ -919,6 +951,8 @@ def _fused_forward( ) -> torch.Tensor: """Forward pass using Transformer Engine operation fuser API.""" + if self.config.moe_expert_rank_capacity_factor is not None: + assert self.config.moe_router_padding_for_quantization, "moe_expert_rank_capacity_factor requires moe_router_padding_for_quantization" # Construct fused impl if needed # Note: We initialize during the first forward pass in case # the params are modified after the constructor. @@ -946,19 +980,45 @@ def _fused_forward( tokens_per_expert = torch.tensor( tokens_per_expert, dtype=torch.int, device=permuted_probs.device ) + # if the number of tokens is 0, pad the hidden states to 256 + apply_pad_unpad = False + if permuted_local_hidden_states.shape[0] == 0 and not torch.cuda.is_current_stream_capturing(): + apply_pad_unpad = True + permuted_local_hidden_states = PadUnpadFunction.apply(permuted_local_hidden_states, True) + permuted_probs = PadUnpadFunction.apply(permuted_probs, True) - # Call fused impl - output = ops( - permuted_local_hidden_states, - tokens_per_expert, # FC1 - permuted_probs, # Scaled SwiGLU - tokens_per_expert, # FC2 - ) - + if self.config.moe_paged_stash: + permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states) + max_num_tokens = permuted_local_hidden_states.shape[0] + # Average/expected tokens is a pre-padding estimate used by paged stashing heuristics. + # moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled. 
+ cap_factor = self.config.moe_expert_rank_capacity_factor + avg_num_tokens = ( + int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None + ) + offload_context = get_paged_stash_context( + name="expert_fc1_fused", + max_num_tokens=max_num_tokens, + num_tokens_tensor=tokens_per_expert.sum(), + avg_num_tokens=avg_num_tokens, + ) + else: + offload_context = nullcontext() + with offload_context: + # Call fused impl + output = ops( + permuted_local_hidden_states, + tokens_per_expert, # FC1 + permuted_probs, # Scaled SwiGLU + tokens_per_expert, # FC2 + ) + if apply_pad_unpad: + output = PadUnpadFunction.apply(output, False) # Remove padding if needed if unpadded_tokens_per_expert is not None: output = self.quantization_unpadding(output, unpadded_tokens_per_expert) - + if self.config.moe_paged_stash: + output = paged_stash_group_commit(output, name="expert_fc1_fused") return output def forward( diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 54887d73336..483ebf7e16b 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -7,10 +7,12 @@ import torch import triton import triton.language as tl -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -GLOBAL_BLOCK_SIZE = 1024 +from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer +from megatron.core.full_cuda_graph import FullCudaGraphWrapper +GLOBAL_BLOCK_SIZE = 1024 +SCALE_INV_BLOCK_SIZE = 32 class PagedStashBuffer: """ @@ -269,9 +271,11 @@ def __init__( num_tokens_tensor=None, avg_num_tokens: int = None, vp_stage=None, + original_shape=None, schedule_layer_no=None, layer_name=None, - max_tokens=None, + max_num_tokens=None, + hidden_size=None, page_size=64, ): """ @@ -280,7 +284,8 @@ def __init__( num_tokens_tensor: Scalar tensor containing actual number of tokens vp_stage: Virtual pipeline stage layer_name: Name of the layer - max_tokens: Maximum number of tokens + max_num_tokens: Maximum number of tokens + hidden_size: Hidden size page_size: Number of tokens per page """ self._tensor = tensor @@ -295,17 +300,14 @@ def __init__( self.vp_stage = vp_stage self.schedule_layer_no = schedule_layer_no self.layer_name = layer_name - self.max_tokens = max_tokens + self.max_num_tokens = max_num_tokens + self.hidden_size = hidden_size self.page_size = page_size # Original tensor information - self.original_shape = list(tensor.shape) - self.max_num_tokens = self.original_shape[0] + self.original_shape = list(tensor.shape) if original_shape is None else original_shape self.element_size = tensor.element_size() - self.hidden_size = self.original_shape[1] - self.dtype = ( - tensor.dtype if not isinstance(tensor, MXFP8Tensor) else tensor._columnwise_data.dtype - ) + self.dtype = tensor.dtype self.device = tensor.device # Calculate number of pages needed @@ -329,16 +331,19 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 self._tensor = self._tensor.contiguous() if self.num_tokens_tensor.dim() == 0: self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) - - # Get 2D tensor - if isinstance(self._tensor, MXFP8Tensor): - tensor_to_copy = self._tensor._columnwise_data + if 'columnwise_scale_inv' in self.layer_name: + num_tokens_tensor = self.num_tokens_tensor // SCALE_INV_BLOCK_SIZE + max_num_tokens = self.max_num_tokens // SCALE_INV_BLOCK_SIZE else: - tensor_to_copy = self._tensor + num_tokens_tensor = self.num_tokens_tensor + max_num_tokens = 
self.max_num_tokens + + # Get 1D tensor + tensor_to_copy = self._tensor # Determine grid size BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(self.max_num_tokens, max_blocks) + num_blocks = min(max_num_tokens, max_blocks) grid = (num_blocks,) # Create temporary tensor for new head @@ -346,9 +351,9 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 # Launch paged stash copy kernel _paged_stash_copy_kernel[grid]( - tensor_to_copy, + tensor_to_copy.view(paged_stash_buffer.buffer.dtype), paged_stash_buffer.buffer, - self.num_tokens_tensor, + num_tokens_tensor, paged_stash_buffer.free_list, paged_stash_buffer.free_list_head, paged_stash_buffer.free_list_tail, @@ -376,26 +381,18 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 max_blocks: Maximum number of blocks for Triton kernel """ # Allocate output tensor - if isinstance(self._original_tensor, MXFP8Tensor): - columnwise_data = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) - self._tensor = MXFP8Tensor( - shape=self._original_tensor.shape, - dtype=self._original_tensor.dtype, - fp8_dtype=self._original_tensor._fp8_dtype, - rowwise_data=self._original_tensor._rowwise_data, - rowwise_scale_inv=self._original_tensor._rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=self._original_tensor._columnwise_scale_inv, - quantizer=self._original_tensor._quantizer, - ) - tensor_to_reload = self._tensor._columnwise_data - else: - self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) - tensor_to_reload = self._tensor + self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) + tensor_to_reload = self._tensor + if 'columnwise_scale_inv' in self.layer_name: + num_tokens_tensor = self.num_tokens_tensor // SCALE_INV_BLOCK_SIZE + max_num_tokens = self.max_num_tokens // SCALE_INV_BLOCK_SIZE + else: + num_tokens_tensor = self.num_tokens_tensor + max_num_tokens = self.max_num_tokens # Determine grid size BLOCK_SIZE = GLOBAL_BLOCK_SIZE - num_blocks = min(self.max_num_tokens, max_blocks) + num_blocks = min(max_num_tokens, max_blocks) grid = (num_blocks,) # Create temporary tensor for new tail @@ -404,8 +401,8 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 # Launch paged stash pop kernel _paged_stash_pop_kernel[grid]( paged_stash_buffer.buffer, - tensor_to_reload, - self.num_tokens_tensor, + tensor_to_reload.view(paged_stash_buffer.buffer.dtype), + num_tokens_tensor, self.page_record, # Triton kernel will read from page_record paged_stash_buffer.free_list, paged_stash_buffer.free_list_head, @@ -459,7 +456,6 @@ class PP_PostScheduleFunction(torch.autograd.Function): @staticmethod def forward(ctx, tensor, stash_manager): # after forward # pylint: disable=missing-function-docstring - ctx.stash_manager = stash_manager ctx.vp_stage = stash_manager.current_vp_stage if ctx.vp_stage is None: @@ -539,7 +535,6 @@ def __init__(self): self._current_layer_name = None self.vp_size = None self.current_vp_stage = None - self._last_layer = False self.status = 'begin' # begin, capture, captured # If element is +ve, it denotes forward pass of vp stage, # if -ve, it denotes backward pass of vp stage @@ -616,7 +611,6 @@ def stash_paged_tensors(self, pp_schedule_layer): f"paged_tensors_to_reload {pp_schedule_layer} is not empty " f"{self.paged_tensors_to_reload[pp_schedule_layer]}" ) - while len(self.paged_tensors_to_stash) > 0: paged_tensor = 
self.paged_tensors_to_stash.pop(0) stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] @@ -639,10 +633,7 @@ def wait_for_stash_to_complete(self): # Deallocate original tensor after stash is complete while len(self.paged_tensors_stash_in_progress) > 0: paged_tensor = self.paged_tensors_stash_in_progress.pop(0) - if isinstance(paged_tensor._original_tensor, MXFP8Tensor): - paged_tensor._original_tensor._columnwise_data = None - else: - paged_tensor._original_tensor = None + paged_tensor._original_tensor = None def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): """Reload the paged tensors.""" @@ -659,7 +650,6 @@ def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): for item in self.paged_tensors_to_reload: if len(self.paged_tensors_to_reload[item]) > 0: count += 1 - while len(self.paged_tensors_to_reload[pp_schedule_layer]) > 0: paged_tensor = self.paged_tensors_to_reload[pp_schedule_layer].pop(0) stash_buffer = self.stash_buffers[paged_tensor.dtype][paged_tensor.hidden_size] @@ -700,16 +690,14 @@ def allocate_stash_buffers(self, stash_buffer_size_factor=1.10): max_tokens_dict[dtype, hidden_size] * scale ) self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( - num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype + num_tokens, hidden_size, self.page_size, self.device, self.overflow, torch.uint8 if dtype in [torch.float8_e4m3fn, torch.float8_e8m0fnu] else dtype ) + print (f'allocate_stash_buffers num_tokens: {self.stash_buffers[dtype][hidden_size].buffer.shape}-{self.stash_buffers[dtype][hidden_size].dtype} ({dtype})') def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): """Update the pp schedule.""" if self._pp_schedule is None: self._pp_schedule = [] - # current layer and microbatch for each vp stage for forward pass - self.current_layer = [1 for _ in range(self.vp_size)] - self.current_microbatch = [1 for _ in range(self.vp_size)] assert self.vp_size is not None if layer_no is None: @@ -718,9 +706,6 @@ def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): layer_no = self.current_layer[vp_stage_index] self.current_layer[vp_stage_index] += 1 microbatch_no = self.current_microbatch[vp_stage_index] - if self._last_layer: - self.current_layer[vp_stage_index] = 1 - self.current_microbatch[vp_stage_index] += 1 if self.status == 'capture': self._pp_schedule.append(self.get_schedule_layer(vp_stage, layer_no, microbatch_no)) @@ -732,54 +717,66 @@ def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): return layer_no, microbatch_no + + def update_model_chunk(self, vp_stage_index): + """Update layer=1, increment microbatch of new vp vp_stage.""" + if self.current_layer is None: + # current layer and microbatch for each vp stage for forward pass + self.current_layer = [1 for _ in range(self.vp_size)] + self.current_microbatch = [0 for _ in range(self.vp_size)] + self.current_layer[vp_stage_index] = 1 + self.current_microbatch[vp_stage_index] += 1 + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: """ Hook called when autograd saves a tensor for backward pass. Returns a tag to identify the tensor later. 
""" - # Handle 0-dim tensors (torch.Size([])) - they have no size(0) if ( self.max_num_tokens is None or tensor.dim() == 0 - or tensor.size(0) != self.max_num_tokens + or not hasattr(tensor, 'grouped_name') + or (tensor.size(0) != self.max_num_tokens and (tensor.logical_shape is None or tensor.logical_shape[0] != self.max_num_tokens)) ): return tensor.detach() - if isinstance(tensor, MXFP8Tensor): - assert ( - tensor._rowwise_data is None - ), f"rowwise_data is not None; Only columnwise data is supported for paged stashing" + + assert isinstance(tensor, torch.Tensor), f"tensor is not a torch.Tensor {type(tensor)}" + #if hasattr(tensor, 'grouped_name'): + # print (f'on_save_for_backward {self.status} tensor: num_tokens: {self.num_tokens_tensor.item()}-{type(tensor)}-{tensor.shape}-{tensor.dtype}-{hex(tensor.data_ptr())}-grouped: name: {tensor.grouped_name} element_size: {tensor.element_size()} logical_shape: {tensor.logical_shape if hasattr(tensor, 'logical_shape') else None}') + #else: + # print (f'on_save_for_backward {self.status} tensor: num_tokens: {self.num_tokens_tensor.item()}-{type(tensor)}-{tensor.shape}-{tensor.dtype}-{hex(tensor.data_ptr())} element_size: {tensor.element_size()} logical_shape: {tensor.logical_shape if hasattr(tensor, 'logical_shape') else None}') + + original_shape = tensor.shape + grouped_name = tensor.grouped_name + tensor = tensor.flatten() + dtype = tensor.dtype + columnwise_scale_inv = 'columnwise_scale_inv' in grouped_name + hidden_size = tensor.numel() // (self.max_num_tokens if not columnwise_scale_inv else self.max_num_tokens // SCALE_INV_BLOCK_SIZE) + + if self.max_tokens_across_vp_stages is None: + self.max_tokens_across_vp_stages = {} + self.temp_tokens_across_vp_stages = {} + self.max_avg_tokens_across_vp_stages = {} + self.temp_avg_tokens_across_vp_stages = {} avg_num_tokens = None if self.status == 'capture': self.num_tokens = self.num_tokens_tensor.item() + actual_num_tokens = self.num_tokens // SCALE_INV_BLOCK_SIZE if columnwise_scale_inv else self.num_tokens + avg_num_tokens = ( int(self.avg_num_tokens) if self.avg_num_tokens is not None else None ) - dtype = ( - tensor.dtype - if not isinstance(tensor, MXFP8Tensor) - else tensor._columnwise_data.dtype - ) - # Get hidden_size from tensor shape - if isinstance(tensor, MXFP8Tensor): - hidden_size = ( - tensor._columnwise_data.shape[1] - if tensor._columnwise_data.ndim > 1 - else tensor._columnwise_data.numel() - ) - else: - hidden_size = tensor.shape[1] if tensor.ndim > 1 else tensor.numel() - if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages: self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0 self.max_tokens_across_vp_stages[dtype, hidden_size] = 0 self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = 0 - self.temp_tokens_across_vp_stages[dtype, hidden_size] += self.num_tokens + self.temp_tokens_across_vp_stages[dtype, hidden_size] += actual_num_tokens self.max_tokens_across_vp_stages[dtype, hidden_size] = max( self.max_tokens_across_vp_stages[dtype, hidden_size], self.temp_tokens_across_vp_stages[dtype, hidden_size], @@ -787,39 +784,36 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any: # Track avg tokens across vp stages (if provided) using the same accumulation model. 
if avg_num_tokens is not None: - self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] += avg_num_tokens + self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] += (avg_num_tokens if not columnwise_scale_inv else avg_num_tokens // SCALE_INV_BLOCK_SIZE) self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = max( self.max_avg_tokens_across_vp_stages[dtype, hidden_size], self.temp_avg_tokens_across_vp_stages[dtype, hidden_size], ) + # Since capture stage does not use CUDA graph, we can truncate # the saved tensor to actual num_tokens - new_size = (self.num_tokens, *tensor.shape[1:]) + new_size = (actual_num_tokens * hidden_size,) - if isinstance(tensor, MXFP8Tensor): - tensor_truncated = torch.empty( - new_size, dtype=tensor._columnwise_data.dtype, device=tensor.device - ) - tensor_truncated.copy_(tensor._columnwise_data[: self.num_tokens, ...]) - tensor._columnwise_data = tensor_truncated - else: - tensor_truncated = torch.empty(new_size, dtype=tensor.dtype, device=tensor.device) - tensor_truncated.copy_(tensor[: self.num_tokens, ...]) - tensor = tensor_truncated + tensor_truncated = torch.empty(new_size, dtype=dtype, device=tensor.device) + tensor_truncated.copy_(tensor[: actual_num_tokens * hidden_size]) + tensor = tensor_truncated + tensor.grouped_name = grouped_name paged_tensor = PagedTensor( tensor, num_tokens_tensor=self.num_tokens_tensor, avg_num_tokens=avg_num_tokens, vp_stage=self.current_vp_stage, + original_shape=original_shape, schedule_layer_no=( self._pp_schedule[self.current_schedule_index] if self._pp_schedule is not None and self.current_schedule_index < len(self._pp_schedule) else None ), - layer_name=self._current_layer_name, - max_tokens=self.max_num_tokens, + layer_name=tensor.grouped_name, + max_num_tokens=self.max_num_tokens, + hidden_size=hidden_size, page_size=self.page_size, ) @@ -833,44 +827,41 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: Returns the actual tensor (potentially reloading from CPU). 
""" if isinstance(saved_state, (PagedTensor)): + columnwise_scale_inv = 'columnwise_scale_inv' in saved_state.layer_name if self.status == 'capture': num_tokens = saved_state.num_tokens_tensor.item() key = (saved_state.dtype, saved_state.hidden_size) if key in self.temp_tokens_across_vp_stages: - self.temp_tokens_across_vp_stages[key] -= num_tokens + self.temp_tokens_across_vp_stages[key] -= (num_tokens if not columnwise_scale_inv else num_tokens // SCALE_INV_BLOCK_SIZE) if ( saved_state.avg_num_tokens is not None and key in self.temp_avg_tokens_across_vp_stages ): - self.temp_avg_tokens_across_vp_stages[key] -= int(saved_state.avg_num_tokens) + self.temp_avg_tokens_across_vp_stages[key] -= (int(saved_state.avg_num_tokens) if not columnwise_scale_inv else int(saved_state.avg_num_tokens) // SCALE_INV_BLOCK_SIZE) + + # Handle 1-byte tensors (torch.uint8) + dtype = saved_state._tensor.dtype + if saved_state._tensor.element_size() == 1: + saved_state._tensor = saved_state._tensor.view(torch.uint8) + # Pad the tensor to the max number of tokens - npad = self.max_num_tokens - num_tokens - pad = () - # check if the tensor is 2D - assert ( - saved_state._tensor.ndim == 2 - ), f"saved_state._tensor.ndim is not 2 {saved_state._tensor.ndim}" - for _ in range(saved_state._tensor.ndim - 1): - pad = pad + (0, 0) - pad = pad + (0, npad) - if isinstance(saved_state._tensor, MXFP8Tensor): - saved_state._tensor._columnwise_data = torch.nn.functional.pad( - saved_state._tensor._columnwise_data, pad - ) - else: - saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad) + # check if the tensor is 1D + assert saved_state._tensor.ndim == 1, f"saved_state._tensor.ndim is not 1 {saved_state._tensor.ndim}" + npad = (self.max_num_tokens - num_tokens) * saved_state.hidden_size + if columnwise_scale_inv: + npad = npad // SCALE_INV_BLOCK_SIZE + pad = (0, npad) + saved_state._tensor = torch.nn.functional.pad(saved_state._tensor, pad).view(dtype) assert ( saved_state._tensor is not None ), f"saved_state._tensor is None {saved_state._tensor}" # Record cross-stream usage (important when tensor was produced on another stream). 
- if isinstance(saved_state._tensor, MXFP8Tensor): - saved_state._tensor._columnwise_data.record_stream(torch.cuda.current_stream()) - elif isinstance(saved_state._tensor, torch.Tensor) and saved_state._tensor.is_cuda: + if isinstance(saved_state._tensor, torch.Tensor) and saved_state._tensor.is_cuda: saved_state._tensor.record_stream(torch.cuda.current_stream()) - return saved_state._tensor + return saved_state._tensor.view(saved_state.original_shape) return saved_state @@ -946,19 +937,9 @@ def paged_stash_group_commit(tensor, name=None): def paged_stash_init_chunk_handler(vp_size, vp_stage): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" stash_manager = PagedStashManager.get_instance() - if not stash_manager.enabled: - return + stash_manager.vp_size = vp_size if vp_size is not None else 1 stash_manager.current_vp_stage = vp_stage if vp_stage is not None else 0 - if vp_size is not None: - stash_manager.vp_size = vp_size - else: - stash_manager.vp_size = 1 - if stash_manager.max_tokens_across_vp_stages is None: - stash_manager.max_tokens_across_vp_stages = {} - stash_manager.temp_tokens_across_vp_stages = {} - stash_manager.max_avg_tokens_across_vp_stages = {} - stash_manager.temp_avg_tokens_across_vp_stages = {} - + stash_manager.update_model_chunk(stash_manager.current_vp_stage) def paged_stash_set_last_layer(is_last_layer=False): """Set the last layer flag.""" @@ -967,7 +948,6 @@ def paged_stash_set_last_layer(is_last_layer=False): return stash_manager._last_layer = is_last_layer - def paged_stash_reset(enabled=True): """Reset the chunk handler, called at the start of a training iteration.""" stash_manager = PagedStashManager.get_instance() @@ -983,6 +963,7 @@ def paged_stash_reset(enabled=True): stash_manager.status = 'capture' elif stash_manager.status == 'capture': stash_manager.status = 'captured' + print (f'schedule {stash_manager._pp_schedule}') stash_buffer_size_factor = float(os.getenv('STASH_BUFFER_SIZE_FACTOR', '1.10')) stash_manager.allocate_stash_buffers(stash_buffer_size_factor=stash_buffer_size_factor) elif stash_manager.status == 'captured': @@ -998,7 +979,7 @@ def paged_stash_reset(enabled=True): stash_manager.stash_buffers[dtype][hidden_size].reset() stash_manager.overflow.zero_() stash_manager.current_layer = [1 for _ in range(stash_manager.vp_size)] - stash_manager.current_microbatch = [1 for _ in range(stash_manager.vp_size)] + stash_manager.current_microbatch = [0 for _ in range(stash_manager.vp_size)] assert ( len(stash_manager.paged_tensors_to_stash) == 0 ), f"paged_tensors_to_stash is not empty {stash_manager.paged_tensors_to_stash}" @@ -1007,12 +988,159 @@ def paged_stash_reset(enabled=True): f"{stash_manager.paged_tensors_stash_in_progress}" ) - def check_paged_stash_overflow(): """Check if paged stash overflow""" stash_manager = PagedStashManager.get_instance() if not stash_manager.enabled or stash_manager.overflow is None: - return - overflow = stash_manager.overflow.item() - if overflow != 0: - raise RuntimeError("PagedStashManager overflow!!!") + return torch.zeros(1, dtype=torch.bool, device='cuda') + overflow = stash_manager.overflow.ne(0) + return overflow + +class PagedStashRunner: + """Runner for paged stash""" + + def __init__(self, config, copy_main_params, model, optimizer, forward_backward_func): + self.stash_manager = PagedStashManager.get_instance() + self.config = config + self.copy_main_params = copy_main_params + self.model = model + self.optimizer = optimizer + self.forward_backward_func = 
forward_backward_func + self.moe_layers = [] + for model_chunk in self.model: + for layer in model_chunk.module.module.decoder.layers: + mlp = layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher, 'check_over_budget' + ): + self.moe_layers.append(mlp) + if model_chunk.module.module.mtp_process: + for layer in model_chunk.module.module.mtp.layers: + mlp = layer.mtp_model_layer.mlp + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher, 'check_over_budget' + ): + self.moe_layers.append(mlp) + print (f"PagedStashRunner: Moe layers {len(self.moe_layers)}!!!") + + def data_read(self, data_iterator, model, training, num_microbatches): + """Read all microbatch inputs from Dataloader and copy to static buffers.""" + data_iterator_saved = [] + if not isinstance(model, list) or len(model) == 1: + assert not isinstance(data_iterator, list) or len(data_iterator) == 1 + iterator0 = data_iterator if not isinstance(data_iterator, list) else data_iterator[0] + data_list = [] + if iterator0 is not None: + for b in range(num_microbatches): + data_list.append(next(iterator0)) + data_iterator_saved.extend(data_list) + data_list = [iter(data_list)] + else: + data_list.append(None) + else: + assert isinstance(data_iterator, list) and len(data_iterator) == len(model) + data_list = [] + for i in range(len(model)): + if data_iterator[i] is not None: + data_list_i = [] + for b in range(num_microbatches): + data_list_i.append(next(data_iterator[i])) + data_iterator_saved.extend(data_list_i) + data_list.append(iter(data_list_i)) + else: + data_list.append(None) + return iter(data_iterator_saved), data_list + + def check_moe_overflow(self): + # check for paged stash overflow + overflow = check_paged_stash_overflow() + # check for token dispatcher overflow + for mlp in self.moe_layers: + overbudget = mlp.token_dispatcher.check_over_budget() + overflow |= overbudget + + overflow_int = overflow.to(torch.int32) + torch.distributed.all_reduce(overflow_int, op=torch.distributed.ReduceOp.SUM) + overflow_int = overflow_int.item() + return overflow_int + + def prepare_for_rerun(self, is_training=True): + """Prepare for rerun""" + print (f"!!!!!!!! Attempting to run without expert_rank_capacity_factor padding") + # check for token dispatcher overflow + for mlp in self.moe_layers: + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher._comm_manager, 'moe_expert_rank_capacity_factor' + ): + mlp.token_dispatcher._comm_manager.moe_expert_rank_capacity_factor = None + mlp.token_dispatcher._comm_manager.over_budget.fill_(0) + self.stash_manager.overflow.zero_() + self.config.moe_paged_stash = False + + # Set grad to zero. 
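+        # The failed attempt's gradients (and any captured CUDA graph, reset below) are
+        # discarded before the eager-mode retry without paged stashing.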
+ for model_chunk in self.model: + model_chunk.zero_grad_buffer() + self.optimizer.zero_grad() + + #_handle_mxfp8_param_buffer_copy + if self.copy_main_params: + def _try_copy_main_params(opt): + if isinstance(opt, DistributedOptimizer) and hasattr(opt, 'shard_fp32_from_float16_groups'): + opt._copy_main_params_to_param_buffer() + # Handle both ChainedOptimizer and direct DistributedOptimizer cases + # Note: FSDP's DistributedOptimizer doesn't have shard_fp32_from_float16_groups, + # so we check for this attribute before calling _copy_main_params_to_param_buffer + if hasattr(self.optimizer, 'chained_optimizers'): + for optim_instance in self.optimizer.chained_optimizers: + _try_copy_main_params(optim_instance) + else: + _try_copy_main_params(self.optimizer) + + # Delete the CUDA graph + if isinstance(self.forward_backward_func, FullCudaGraphWrapper): + self.forward_backward_func.reset_cuda_graph(stage='training' if is_training else 'validation') + + def __call__(self, *args, **kwargs): + """Run the paged stash""" + assert len(args) == 0, 'forward_backward_func does not accept positional args' + assert all( + [ + kwarg in kwargs + for kwarg in [ + 'model', + 'data_iterator', + 'num_microbatches', + 'seq_length', + 'forward_only', + ] + ] + ) + model = kwargs['model'] + num_microbatches = kwargs['num_microbatches'] + + training = not kwargs['forward_only'] + data_iterator = kwargs['data_iterator'] + saved_moe_paged_stash = self.config.moe_paged_stash + num_tries = 0 + while True: + assert num_tries < 2, f"PagedStashRunner: num_tries {num_tries} exceeded max attempts!!!" + num_tries += 1 + data_iterator, data_list = self.data_read(data_iterator, model, training, num_microbatches) + kwargs['data_iterator'] = data_list + result = self.forward_backward_func(*args, **kwargs) + + overflow_int = self.check_moe_overflow() + # if no overflow, set the expert_rank_capacity_factor to the original value + if overflow_int == 0: + for mlp in self.moe_layers: + if hasattr(mlp, 'token_dispatcher') and hasattr( + mlp.token_dispatcher._comm_manager, 'moe_expert_rank_capacity_factor' + ): + mlp.token_dispatcher._comm_manager.moe_expert_rank_capacity_factor = mlp.token_dispatcher.config.moe_expert_rank_capacity_factor + self.config.moe_paged_stash = saved_moe_paged_stash + break + + # if overflow, set the expert_rank_capacity_factor to None + print(f"MBridge train(): PagedStashManager or Token Dispatcher overflow detected across ranks {overflow_int} config.moe_paged_stash: {self.config.moe_paged_stash}!!!") + self.prepare_for_rerun(is_training=training) + return result diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 284217d356c..d0a46606645 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1086,6 +1086,8 @@ def dispatch( if self.moe_expert_rank_capacity_factor is not None: over_budget = self.handle[8] != 0 # this is overflow_flag self.over_budget |= over_budget + if self.config.cuda_graph_impl == 'none': + assert not self.over_budget.item(), 'Over budget' if self.num_permuted_tokens is None: self.tokens_per_expert = tokens_per_expert.to(torch.int64) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 7b8a764b813..c71c9377542 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -685,9 +685,10 @@ def process_mtp_loss( mtp_loss 
= compute_language_model_loss(mtp_labels, mtp_logits) mtp_loss = loss_mask * mtp_loss if is_training: + # Safe divide without sync: mask numerator when num_tokens==0, divide by clamp(min=1) mtp_loss_for_log = ( - torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0) - ) + torch.sum(mtp_loss) * (num_tokens > 0).to(mtp_loss.dtype) + ) / num_tokens.clamp(min=1) MTPLossLoggingHelper.save_loss_to_tracker( mtp_loss_for_log, mtp_layer_number, diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index b4e560868fa..2af2eeb9a8f 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -17,16 +17,12 @@ from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer.moe.paged_stash import ( - paged_stash_set_last_layer, -) from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import CheckpointManager from megatron.core.transformer.enums import CudaGraphScope, LayerType from megatron.core.transformer.hyper_connection import HyperConnectionModule from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule -from megatron.core.transformer.moe.paged_stash import paged_stash_set_last_layer from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.torch_norm import LayerNormBuilder from megatron.core.transformer.transformer_config import TransformerConfig @@ -896,11 +892,6 @@ def forward( mhc_manager.is_last_layer_in_recompute_block = ( mhc_is_last_in_recompute_block[l_no] ) - if self.config.moe_paged_stash: - paged_stash_set_last_layer( - is_last_layer=(l_no == self.num_layers_per_pipeline_rank - 1) - ) - with self.offload_context, inner_quantization_context: hidden_states, context = layer( hidden_states=hidden_states, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 19fc7844175..2a280b8ee19 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1267,10 +1267,10 @@ def __post_init__(self): ) if self.moe_expert_rank_capacity_factor is not None: - if not self.moe_use_device_initiated_grouped_gemm: + if not self.use_transformer_engine_op_fuser: raise ValueError( "moe_expert_rank_capacity_factor requires " - "moe_use_device_initiated_grouped_gemm to be enabled." + "use_transformer_engine_op_fuser to be enabled." ) if self.moe_flex_dispatcher_backend != "hybridep": raise ValueError( @@ -1479,10 +1479,7 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - if self.moe_paged_stash: - assert ( - not self.cpu_offloading and not self.fine_grained_activation_offloading - ), "paged_stash cannot be enabled with cpu_offloading." + if self.moe_paged_stash:# vasu assert ( self.stash_modules is not None and len(self.stash_modules) > 0 ), "stash_modules must be specified when moe_paged_stash is enabled." @@ -1503,7 +1500,16 @@ def __post_init__(self): f"A module cannot be stashed and offloaded at the same time. 
" f"Found overlapping modules: {overlap}" ) - + # Check that Full/Selective recompute for MOE not enabled when paged stash is enabled + if self.moe_paged_stash: + if self.recompute_granularity == "full": + raise ValueError( + "Full recompute is not supported when paged stash is enabled." + ) + if self.recompute_granularity == "selective" and "moe" in self.recompute_modules: + raise ValueError( + "Selective recompute for MOE is not supported when paged stash is enabled." + ) if ( self.num_layers_in_first_pipeline_stage is not None or self.num_layers_in_last_pipeline_stage is not None @@ -2112,14 +2118,15 @@ def __post_init__(self): ) if self.cuda_graph_impl != "none": - assert ( - self.cuda_graph_impl == "transformer_engine" - and CudaGraphScope.moe not in self.cuda_graph_scope - and CudaGraphScope.mlp not in self.cuda_graph_scope - ), ( - 'CUDA graph scope on moe and mlp is not ' - 'supported with overlap_moe_expert_parallel_comm' - ) + if self.cuda_graph_impl == "transformer_engine": + assert ( + self.cuda_graph_impl == "transformer_engine" + and CudaGraphScope.moe not in self.cuda_graph_scope + and CudaGraphScope.mlp not in self.cuda_graph_scope + ), ( + 'CUDA graph scope on moe and mlp is not ' + 'supported with overlap_moe_expert_parallel_comm' + ) # Check delay_wgrad_compute compatibility if self.delay_wgrad_compute: From 2c143d40646f2ba9bb208efd4123fbefe65ad982 Mon Sep 17 00:00:00 2001 From: Pingtian Li Date: Wed, 11 Mar 2026 20:28:25 -0700 Subject: [PATCH 43/57] support use-dynamic-comp-stream --- megatron/core/model_parallel_config.py | 8 +++ .../common/model_chunk_schedule_plan.py | 19 ++++--- .../core/models/gpt/fine_grained_callables.py | 4 +- .../core/pipeline_parallel/combined_1f1b.py | 13 +++-- megatron/core/pipeline_parallel/utils.py | 55 +++++++++++++++---- .../core/transformer/transformer_config.py | 6 ++ .../test_cuda_graphed_schedule_chunk_1f1b.py | 3 +- .../a2a_overlap/test_schedule_chunk_1f1b.py | 3 +- .../a2a_overlap/test_schedule_layer_1f1b.py | 21 ++++--- ...test_fine_grained_activation_offloading.py | 3 +- 10 files changed, 98 insertions(+), 37 deletions(-) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 075aa75c76a..d8e66c0e52a 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -275,6 +275,14 @@ class ModelParallelConfig: in 1f1b phase of pipelining or non-pipelining schedule. """ + use_dynamic_comp_stream: bool = False + """Use dynamic computation stream selection instead of binding to the default stream. + When enabled, get_comp_stream() returns torch.cuda.current_stream() at call time, + allowing CUDA graph capture and replay on non-default streams. This is required for + full-iteration CUDA graph with 1f1b EP overlap where the capture stream differs + from the default stream. + """ + delay_wgrad_compute: bool = False """Delay the weight gradient computation to improve batch-level communication overlapping""" diff --git a/megatron/core/models/common/model_chunk_schedule_plan.py b/megatron/core/models/common/model_chunk_schedule_plan.py index c5cf05a8f6e..506f290e7d8 100644 --- a/megatron/core/models/common/model_chunk_schedule_plan.py +++ b/megatron/core/models/common/model_chunk_schedule_plan.py @@ -64,8 +64,8 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar event (torch.cuda.Event): record CUDA event across multiple nodes on different streams for synchronization. 
chunk_state (ModelChunkState): model state shared in the model chunk. - comp_stream (torch.cuda.Stream): CUDA stream for computation. - comm_stream (torch.cuda.Stream): CUDA stream for communication. + comp_stream (Callable): Func that returns CUDA stream for computation. + comm_stream (Callable): Func that returns CUDA stream for communication. extra_args (dict): extra arguments for the layer. The event and chunk_state are binded to the TransformerModelChunkSchedulePlan @@ -318,9 +318,6 @@ def __init__( self.post_process = None self.vp_stage = model.vp_stage - comp_stream = get_comp_stream() - comm_stream = get_comm_stream() - # save the inputs of model.forward() to ModelChunkState self._model_chunk_state.input_ids = input_ids self._model_chunk_state.position_ids = position_ids @@ -339,18 +336,22 @@ def __init__( self._model_chunk_state.attention_bias = None # build preprocess - self.pre_process = PreProcessNode(model, self._model_chunk_state, self._event, comp_stream) + self.pre_process = PreProcessNode( + model, self._model_chunk_state, self._event, get_comp_stream + ) # build layer schedule plan for each layer. # The methods to obtain layers are different for MTP so we need the other build plan for # MTP. Also, this can help annotate MTP layer so that it can know where MTP is. - self._build_layer_schedule_plan(model.decoder, comp_stream, comm_stream) - self._build_layer_schedule_plan(getattr(model, "mtp", None), comp_stream, comm_stream) + self._build_layer_schedule_plan(model.decoder, get_comp_stream, get_comm_stream) + self._build_layer_schedule_plan( + getattr(model, "mtp", None), get_comp_stream, get_comm_stream + ) # build post process if model.post_process: self.post_process = PostProcessNode( - model, self._model_chunk_state, self._event, comp_stream + model, self._model_chunk_state, self._event, get_comp_stream ) def _build_layer_schedule_plan(self, module, comp_stream, comm_stream): diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 6658b6363ea..8d1036b5bae 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -3,7 +3,7 @@ import weakref from contextlib import nullcontext from functools import partial -from typing import Optional +from typing import Callable, Optional import torch from torch import Tensor @@ -330,6 +330,8 @@ def backward_dw(self): """Computes the weight gradients for the transformer layer node.""" if not self.delay_wgrad_compute: return + if isinstance(self.stream, Callable): + self.stream = self.stream() with torch.cuda.stream(self.stream): torch.cuda.nvtx.range_push(f"{self.name} wgrad") for module in self.bwd_dw_callables: diff --git a/megatron/core/pipeline_parallel/combined_1f1b.py b/megatron/core/pipeline_parallel/combined_1f1b.py index 232d9c8cd70..892832059d7 100644 --- a/megatron/core/pipeline_parallel/combined_1f1b.py +++ b/megatron/core/pipeline_parallel/combined_1f1b.py @@ -8,7 +8,12 @@ from megatron.core.enums import Fp8Recipe from megatron.core.fp8_utils import get_fp8_context -from megatron.core.pipeline_parallel.utils import AbstractSchedulePlan, ScheduleNode, set_streams +from megatron.core.pipeline_parallel.utils import ( + AbstractSchedulePlan, + ScheduleNode, + get_comp_stream, + set_streams, +) from megatron.core.utils import get_attr_wrapped_model # Types @@ -47,7 +52,7 @@ def combined_1f1b_schedule_for_no_pipelining( Phases 4: 4th microbatch backward """ - set_streams() + 
set_streams(use_dynamic_comp_stream=config.use_dynamic_comp_stream) # The forward step for the first microbatch is executed alone, no a2a overlapping output_tensor, num_tokens, _ = combined_forward_backward_step( forward_step_func, @@ -173,7 +178,7 @@ def combined_1f1b_schedule_for_interleaved_pipelining(): # backward_step_helper_postprocess() """ - set_streams() + set_streams(use_dynamic_comp_stream=config.use_dynamic_comp_stream) # forward prepare f_model_chunk_id = None f_microbatch_id = None @@ -405,7 +410,7 @@ def forward_backward_step(): from megatron.core.pipeline_parallel.schedules import forward_step_calc_loss loss_node = ScheduleNode( - loss_func, torch.cuda.current_stream(), f_schedule_plan.event, name="loss_func" + loss_func, get_comp_stream, f_schedule_plan.event, name="loss_func" ) loss_func = loss_node.forward output_tensor, num_tokens = forward_step_calc_loss( diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index 8f6b25eec32..08da68971ec 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -154,7 +154,7 @@ def __init__( Args: forward_func (callable): Function to execute during the forward pass. - stream (torch.cuda.Stream): The CUDA stream for this node's computation. + stream (Callable): Func that returns CUDA stream for computation. This can be either a 'compute' stream or a 'communicate' stream. - 'compute' stream: Used for computational nodes like attention and experts. - 'communicate' stream: Used for nodes that handle token communication, @@ -198,6 +198,9 @@ def forward(self, inputs=()): return self._forward(*inputs) def _forward(self, *inputs): + # Lazy initialization of stream + if isinstance(self.stream, Callable): + self.stream = self.stream() with self.stream_acquire_context(f"{self.name} forward"): self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs] for i, input in enumerate(self.inputs): @@ -235,6 +238,9 @@ def backward(self, output_grad): return self._backward(*output_grad) def _backward(self, *output_grad): + # Lazy initialization of stream + if isinstance(self.stream, Callable): + self.stream = self.stream() with self.stream_acquire_context(f"{self.name} backward"): outputs = self.output if not isinstance(outputs, tuple): @@ -323,31 +329,56 @@ def run( ... +_USE_DYNAMIC_COMP_STREAM = None _COMP_STREAM = None _COMM_STREAM = None -def set_streams(comp_stream=None, comm_stream=None): - """Set the streams for communication and computation""" +def set_streams(comp_stream=None, comm_stream=None, use_dynamic_comp_stream=False): + """Set the streams for communication and computation. + + When use_dynamic_comp_stream is True, get_comp_stream() will return + torch.cuda.current_stream() at call time instead of a cached stream, + which is required for full-iteration CUDA graph capture/replay where + the capture stream differs from the default stream. 
+ """ global _COMP_STREAM global _COMM_STREAM - if _COMP_STREAM is not None and _COMM_STREAM is not None: + global _USE_DYNAMIC_COMP_STREAM + + if _USE_DYNAMIC_COMP_STREAM is None: + _USE_DYNAMIC_COMP_STREAM = use_dynamic_comp_stream + + # Set communication stream + if _COMM_STREAM is None: + if comm_stream is None: + comm_stream = torch.cuda.Stream(device="cuda") + _COMM_STREAM = comm_stream + + # In dynamic mode, comp stream is resolved at call time via current_stream() + if _USE_DYNAMIC_COMP_STREAM: + _COMP_STREAM = None return + if _COMP_STREAM is None: + if comp_stream is None: + comp_stream = torch.cuda.current_stream() + _COMP_STREAM = comp_stream - if comp_stream is None: - comp_stream = torch.cuda.current_stream() - if comm_stream is None: - comm_stream = torch.cuda.Stream(device="cuda") - assert _COMP_STREAM is None - assert _COMM_STREAM is None - _COMP_STREAM = comp_stream - _COMM_STREAM = comm_stream +def reset_streams(): + """Reset all stream state. Intended for testing or reinitialisation.""" + global _COMP_STREAM, _COMM_STREAM, _USE_DYNAMIC_COMP_STREAM + _USE_DYNAMIC_COMP_STREAM = None + _COMP_STREAM = None + _COMM_STREAM = None def get_comp_stream(): """Get the stream for computation""" global _COMP_STREAM + global _USE_DYNAMIC_COMP_STREAM + if _USE_DYNAMIC_COMP_STREAM: + return torch.cuda.current_stream() return _COMP_STREAM diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 2a280b8ee19..b57cd4974a0 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2148,6 +2148,12 @@ def __post_init__(self): 'ep_overlap_early_attn_memory_release' ) + if self.use_dynamic_comp_stream: + assert self.overlap_moe_expert_parallel_comm, ( + 'overlap_moe_expert_parallel_comm must be enabled when enabling ' + 'use_dynamic_comp_stream' + ) + if self.context_parallel_size > 1 and self.cp_comm_type is not None: if isinstance(self.cp_comm_type, list): assert len(self.cp_comm_type) == self.num_layers, ( diff --git a/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py b/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py index 85586095bd7..cc56b170ba9 100644 --- a/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py +++ b/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py @@ -13,7 +13,7 @@ ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator -from megatron.core.pipeline_parallel.utils import set_streams +from megatron.core.pipeline_parallel.utils import reset_streams, set_streams from megatron.core.tensor_parallel.random import HAVE_TE, model_parallel_cuda_manual_seed from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.module import float16_to_fp32 @@ -68,6 +68,7 @@ def teardown_method(self, method): os.environ.pop(key, None) else: os.environ[key] = value + reset_streams() Utils.destroy_model_parallel() destroy_global_vars() destroy_num_microbatches_calculator() diff --git a/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py b/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py index 6c59dd3f9e3..0e950946898 100644 --- a/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py +++ b/tests/unit_tests/a2a_overlap/test_schedule_chunk_1f1b.py @@ -10,7 +10,7 @@ get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel -from 
megatron.core.pipeline_parallel.utils import set_streams +from megatron.core.pipeline_parallel.utils import reset_streams, set_streams from megatron.core.transformer.module import float16_to_fp32 from megatron.core.utils import is_te_min_version from tests.unit_tests.a2a_overlap.utils import ( @@ -80,6 +80,7 @@ def setup_method(self, method): set_streams() def teardown_method(self, method): + reset_streams() Utils.destroy_model_parallel() @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") diff --git a/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py b/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py index c6c4a75af99..69926dbf466 100644 --- a/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py +++ b/tests/unit_tests/a2a_overlap/test_schedule_layer_1f1b.py @@ -12,6 +12,12 @@ get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.pipeline_parallel.utils import ( + get_comm_stream, + get_comp_stream, + reset_streams, + set_streams, +) from megatron.core.utils import is_te_min_version from tests.unit_tests.a2a_overlap.utils import ( DummyState, @@ -68,9 +74,8 @@ def run_transformer_layer_a2a_overlap_with_capture(model, input_tensors, microba for i in range(len(input_tensors)): input_tensors[i] = input_tensors[i].clone() + set_streams() event = torch.cuda.Event() - comp_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream(device="cuda") state = DummyState() state.is_mtp = False state.model = model @@ -79,8 +84,8 @@ def run_transformer_layer_a2a_overlap_with_capture(model, input_tensors, microba transformer_layer, event, state, - comp_stream, - comm_stream, + get_comp_stream, + get_comm_stream, extra_args={"is_moe": True, "enable_deepep": False}, ) for _ in range(microbatches) @@ -183,8 +188,7 @@ def run_mtp_layer_a2a_overlap_with_capture( for i in range(len(hidden_states)): hidden_states[i] = hidden_states[i].clone() - comp_stream = torch.cuda.current_stream() - comm_stream = torch.cuda.Stream(device="cuda") + set_streams() layers = [] for _ in range(microbatches): state = DummyState() @@ -203,8 +207,8 @@ def run_mtp_layer_a2a_overlap_with_capture( model.mtp.layers[0], event, state, - comp_stream, - comm_stream, + get_comp_stream, + get_comm_stream, extra_args={ "is_moe": True, "enable_deepep": False, @@ -255,6 +259,7 @@ def setup_method(self, method): ) def teardown_method(self, method): + reset_streams() Utils.destroy_model_parallel() @pytest.mark.skipif(not is_te_min_version("1.9.0.dev0"), reason="Requires TE >= 1.9.0.dev0") diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py index 558c6934a0c..d39a2ed4f0f 100644 --- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py +++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py @@ -346,7 +346,7 @@ def test_fine_grained_activation_offload_with_ep_a2a_overlap_compatibility( from megatron.core.models.common.model_chunk_schedule_plan import ( TransformerModelChunkSchedulePlan, ) - from megatron.core.pipeline_parallel.utils import set_streams + from megatron.core.pipeline_parallel.utils import reset_streams, set_streams from tests.unit_tests.a2a_overlap.utils import deterministic_mode # EP overlap requires distributed initialization with EP groups. 
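A minimal usage sketch of the dynamic compute-stream selection introduced in this patch (it assumes the patched megatron/core/pipeline_parallel/utils.py API shown above; the side-stream context and assertion are illustrative only, not part of the change):

import torch
from megatron.core.pipeline_parallel.utils import get_comp_stream, reset_streams, set_streams

reset_streams()
set_streams(use_dynamic_comp_stream=True)  # compute stream is resolved at call time

side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    # With dynamic selection, schedule nodes asking for the compute stream see whatever
    # stream is current, e.g. the capture stream of a full-iteration CUDA graph.
    assert get_comp_stream() == torch.cuda.current_stream()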
@@ -570,4 +570,5 @@ def _run_schedule_1f1b_two_microbatches( f"(rel_err={rel_err:.2f}, abs_err={abs_err:.2f})" ) finally: + reset_streams() Utils.destroy_model_parallel() From 0c2da521f1a328c5342005c3e81627f00dcab5b4 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Tue, 17 Mar 2026 18:18:51 +0800 Subject: [PATCH 44/57] Revert "remove encoder_and_decoder from enums (#3406)" This reverts commit 7c7c9e12f2266f97c8712f92ab486166432998ef. --- megatron/core/enums.py | 8 ++++++++ megatron/core/pipeline_parallel/schedules.py | 9 +++++++++ megatron/core/transformer/enums.py | 9 +++++++++ 3 files changed, 26 insertions(+) diff --git a/megatron/core/enums.py b/megatron/core/enums.py index cb378d88e0f..9b76bc52a87 100644 --- a/megatron/core/enums.py +++ b/megatron/core/enums.py @@ -8,6 +8,14 @@ class ModelType(enum.Enum): encoder_or_decoder = 1 + @property + def encoder_and_decoder(self): + """Deprecated property - use encoder_or_decoder instead.""" + raise ValueError( + "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder " + "instead." + ) + class Fp8Recipe(str, enum.Enum): """FP8 recipe names: delayed, tensorwise, mxfp8, blockwise, custom.""" diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 04b22bfc297..7dd860cf2ad 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -9,6 +9,7 @@ from torch.autograd.variable import Variable from megatron.core import parallel_state +from megatron.core.enums import ModelType from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) @@ -1017,6 +1018,10 @@ def forward_backward_pipelining_with_interleaving( elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model[0]) + assert model_type != ModelType.encoder_and_decoder, ( + "encoder PP stages not yet supported when passing custom process groups. " + "support coming soon!" + ) assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" assert hasattr(pg_collection, 'tp'), "pg_collection must have a tp_group" assert hasattr(pg_collection, 'cp'), "pg_collection must have a cp_group" @@ -2201,6 +2206,10 @@ def forward_backward_pipelining_without_interleaving( ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model) + assert model_type != ModelType.encoder_and_decoder, ( + "encoder PP stages not yet supported when passing custom process groups. " + "support coming soon!" + ) assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" assert hasattr(pg_collection, 'tp'), "pg_collection must have tp_group" assert hasattr(pg_collection, 'cp'), "pg_collection must have cp_group" diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py index 1bf16095908..d57e24887ab 100644 --- a/megatron/core/transformer/enums.py +++ b/megatron/core/transformer/enums.py @@ -9,10 +9,19 @@ class ModelType(enum.Enum): """Model Type encoder_or_decoder for bert, gpt etc + encoder_and_decoder for multimodal , T5 etc """ encoder_or_decoder = 1 + @property + def encoder_and_decoder(self): + """Deprecated property - use encoder_or_decoder instead.""" + raise ValueError( + "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder " + "instead." 
+ ) + class LayerType(enum.Enum): """Layer type From 5e83aa0c562de4bb21142fde4076480bb7a0f5a5 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Tue, 17 Mar 2026 18:31:36 +0800 Subject: [PATCH 45/57] Remove the WAR of running warmup on a side stream --- megatron/core/full_cuda_graph.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/core/full_cuda_graph.py b/megatron/core/full_cuda_graph.py index 06144ed6e44..faff0c4ab59 100644 --- a/megatron/core/full_cuda_graph.py +++ b/megatron/core/full_cuda_graph.py @@ -190,9 +190,7 @@ def __call__(self, *args, **kwargs): torch.distributed.barrier() logger.info(f'CUDA graph capture done for {training_str}!!!') if FullCudaGraphWrapper.cuda_graph[training_str] is None: - # do it in a side stream - with torch.cuda.stream(torch.cuda.Stream()): - FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) + FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs) else: FullCudaGraphWrapper.cuda_graph[training_str].replay() self.next_iter(training_str) From d053991cb6b330c229bf71a3517097bead3e983c Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Tue, 17 Mar 2026 19:49:28 +0800 Subject: [PATCH 46/57] Reapply "remove encoder_and_decoder from enums (#3406)" This reverts commit d71009fc47854ff3440b0fc50c389bfa164c60e7. --- megatron/core/enums.py | 8 -------- megatron/core/pipeline_parallel/schedules.py | 9 --------- megatron/core/transformer/enums.py | 9 --------- 3 files changed, 26 deletions(-) diff --git a/megatron/core/enums.py b/megatron/core/enums.py index 9b76bc52a87..cb378d88e0f 100644 --- a/megatron/core/enums.py +++ b/megatron/core/enums.py @@ -8,14 +8,6 @@ class ModelType(enum.Enum): encoder_or_decoder = 1 - @property - def encoder_and_decoder(self): - """Deprecated property - use encoder_or_decoder instead.""" - raise ValueError( - "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder " - "instead." - ) - class Fp8Recipe(str, enum.Enum): """FP8 recipe names: delayed, tensorwise, mxfp8, blockwise, custom.""" diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 7dd860cf2ad..04b22bfc297 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -9,7 +9,6 @@ from torch.autograd.variable import Variable from megatron.core import parallel_state -from megatron.core.enums import ModelType from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) @@ -1018,10 +1017,6 @@ def forward_backward_pipelining_with_interleaving( elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model[0]) - assert model_type != ModelType.encoder_and_decoder, ( - "encoder PP stages not yet supported when passing custom process groups. " - "support coming soon!" - ) assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" assert hasattr(pg_collection, 'tp'), "pg_collection must have a tp_group" assert hasattr(pg_collection, 'cp'), "pg_collection must have a cp_group" @@ -2206,10 +2201,6 @@ def forward_backward_pipelining_without_interleaving( ) elif p2p_communicator is not None and pg_collection is not None: model_type = get_model_type(model) - assert model_type != ModelType.encoder_and_decoder, ( - "encoder PP stages not yet supported when passing custom process groups. " - "support coming soon!" 
- ) assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" assert hasattr(pg_collection, 'tp'), "pg_collection must have tp_group" assert hasattr(pg_collection, 'cp'), "pg_collection must have cp_group" diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py index d57e24887ab..1bf16095908 100644 --- a/megatron/core/transformer/enums.py +++ b/megatron/core/transformer/enums.py @@ -9,19 +9,10 @@ class ModelType(enum.Enum): """Model Type encoder_or_decoder for bert, gpt etc - encoder_and_decoder for multimodal , T5 etc """ encoder_or_decoder = 1 - @property - def encoder_and_decoder(self): - """Deprecated property - use encoder_or_decoder instead.""" - raise ValueError( - "ModelType.encoder_and_decoder is deprecated. Please use ModelType.encoder_or_decoder " - "instead." - ) - class LayerType(enum.Enum): """Layer type From c62a8650c5c1724c0f0f021613f802ba4db87d03 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Tue, 17 Mar 2026 17:32:19 -0700 Subject: [PATCH 47/57] Fix for data_iterator type check in Paged Stashing fallback --- megatron/core/transformer/moe/paged_stash.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 483ebf7e16b..66111805cc9 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -1033,7 +1033,7 @@ def data_read(self, data_iterator, model, training, num_microbatches): if iterator0 is not None: for b in range(num_microbatches): data_list.append(next(iterator0)) - data_iterator_saved.extend(data_list) + data_iterator_saved.append(data_list) data_list = [iter(data_list)] else: data_list.append(None) @@ -1045,11 +1045,11 @@ def data_read(self, data_iterator, model, training, num_microbatches): data_list_i = [] for b in range(num_microbatches): data_list_i.append(next(data_iterator[i])) - data_iterator_saved.extend(data_list_i) + data_iterator_saved.append(iter(data_list_i)) data_list.append(iter(data_list_i)) else: data_list.append(None) - return iter(data_iterator_saved), data_list + return data_iterator_saved, data_list def check_moe_overflow(self): # check for paged stash overflow From 5ee817c60cb635ab7387c8124e907de21779f356 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Tue, 17 Mar 2026 19:43:12 -0700 Subject: [PATCH 48/57] Change to support eager-mode fallback for validation --- megatron/core/transformer/moe/paged_stash.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 66111805cc9..ccae6ca9c12 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -1080,7 +1080,8 @@ def prepare_for_rerun(self, is_training=True): # Set grad to zero. 
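         # Optimizer may be None for validation / eager-mode fallback runs, so the
         # optimizer-side cleanup below is guarded accordingly.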
for model_chunk in self.model: model_chunk.zero_grad_buffer() - self.optimizer.zero_grad() + if self.optimizer is not None: + self.optimizer.zero_grad() #_handle_mxfp8_param_buffer_copy if self.copy_main_params: @@ -1090,11 +1091,12 @@ def _try_copy_main_params(opt): # Handle both ChainedOptimizer and direct DistributedOptimizer cases # Note: FSDP's DistributedOptimizer doesn't have shard_fp32_from_float16_groups, # so we check for this attribute before calling _copy_main_params_to_param_buffer - if hasattr(self.optimizer, 'chained_optimizers'): - for optim_instance in self.optimizer.chained_optimizers: - _try_copy_main_params(optim_instance) - else: - _try_copy_main_params(self.optimizer) + if self.optimizer is not None: + if hasattr(self.optimizer, 'chained_optimizers'): + for optim_instance in self.optimizer.chained_optimizers: + _try_copy_main_params(optim_instance) + else: + _try_copy_main_params(self.optimizer) # Delete the CUDA graph if isinstance(self.forward_backward_func, FullCudaGraphWrapper): From 485dd7e45707a95804f01a9cb3b604c1f0e034ab Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 18 Mar 2026 13:58:17 +0800 Subject: [PATCH 49/57] Revert "Check in dynamic-shape-aware SwiGLU triton kernel" This reverts commit be3eec12bde1992c7e45d81bfa51f12cf45a6fc4. --- megatron/core/fusions/fused_bias_swiglu.py | 286 +-------------------- megatron/core/transformer/moe/experts.py | 8 - 2 files changed, 10 insertions(+), 284 deletions(-) diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py index d50164ea59b..1161c832d79 100644 --- a/megatron/core/fusions/fused_bias_swiglu.py +++ b/megatron/core/fusions/fused_bias_swiglu.py @@ -5,8 +5,6 @@ import torch import torch.nn.functional as F -import triton -import triton.language as tl from megatron.core.jit import jit_fuser from megatron.core.utils import nvtx_decorator @@ -192,51 +190,20 @@ def backward(ctx, grad_output): class WeightedSwiGLUFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weights, fp8_input_store, num_tokens_tensor=None): - """Forward pass for weighted SwiGLU. 
- - Args: - input: [total_tokens, hidden_size * 2] - weights: [total_tokens, 1] - fp8_input_store: Whether to store in FP8 - num_tokens_tensor: Optional scalar tensor with actual token count - (uses Triton if provided) - """ - # Convert input for backward pass + # bias is an optional argument + def forward(ctx, input, weights, fp8_input_store): input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input - - # Use Triton implementation if num_tokens_tensor provided and available - if num_tokens_tensor is not None and input.dim() == 2: - output = weighted_swiglu_triton(input, weights, num_tokens_tensor) - ctx.save_for_backward(input_for_backward, weights, num_tokens_tensor) - ctx.use_triton = True - else: - # Fallback to JIT fused implementation - output = weighted_swiglu(input, weights) - ctx.save_for_backward(input_for_backward, weights) - ctx.use_triton = False - + ctx.save_for_backward(input_for_backward, weights) ctx.ori_input_dtype = input.dtype ctx.fp8_input_store = fp8_input_store - return output + return weighted_swiglu(input, weights) @staticmethod def backward(ctx, grad_output): - """Backward pass for weighted SwiGLU.""" - if ctx.use_triton: - # Triton backward path - input, weights, num_tokens_tensor = ctx.saved_tensors - input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input - grad_input, grad_weights = weighted_swiglu_triton_back( - grad_output, input, weights, num_tokens_tensor - ) - return grad_input, grad_weights, None, None - else: - # JIT fused backward path - input, weights = ctx.saved_tensors - input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input - tmp, wgrad = weighted_swiglu_back(grad_output, input, weights) - return tmp, wgrad, None, None + input, weights = ctx.saved_tensors + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp, wgrad = weighted_swiglu_back(grad_output, input, weights) + return tmp, wgrad, None def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False): @@ -269,7 +236,7 @@ def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) -def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, num_tokens_tensor=None): +def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False): """ Token-wise-weighted bias swiglu fusion. """ @@ -279,7 +246,7 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, num_t if bias is not None: raise NotImplementedError("Bias is not supported for weighted swiglu fusion") else: - output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store, num_tokens_tensor) + output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store) return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) @@ -287,236 +254,3 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, num_t # bias_swiglu_impl = BiasSwiGLUFunction.apply # swiglu_impl = SwiGLUFunction.apply - -@triton.jit -def _weighted_swiglu_fwd_kernel( - input_ptr, - weights_ptr, - output_ptr, - num_tokens_ptr, - hidden_size: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """Triton kernel for weighted SwiGLU forward pass. - - Processes tokens in strided pattern, only operating on valid tokens. 
- Formula: output = SiLU(input[:, :H]) * input[:, H:] * weights - """ - pid = tl.program_id(axis=0) - num_blocks = tl.num_programs(axis=0) - - # Load actual number of tokens - num_tokens = tl.load(num_tokens_ptr) - - # Strided access: each block handles tokens [pid, pid+num_blocks, ...] - token_idx = pid - while token_idx < num_tokens: - token_idx_i64 = token_idx.to(tl.int64) - # Load weight for this token - weight = tl.load(weights_ptr + token_idx_i64) - - # Process hidden dimension - for h_offset in range(0, hidden_size, BLOCK_SIZE): - h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size - - # Load input chunks (gate and value) - input_offset_1 = token_idx_i64 * (hidden_size * 2) + h_offset - input_offset_2 = token_idx_i64 * (hidden_size * 2) + hidden_size + h_offset - - y1 = tl.load( - input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 - ) - y2 = tl.load( - input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 - ) - - # SwiGLU: SiLU(y1) * y2 * weight - # SiLU(x) = x * sigmoid(x) - # Cast to fp32 for sigmoid computation (required by Triton) - y1_fp32 = y1.to(tl.float32) - y2_fp32 = y2.to(tl.float32) - weight_fp32 = weight.to(tl.float32) - - sigmoid_y1 = tl.sigmoid(y1_fp32) - silu_y1 = y1_fp32 * sigmoid_y1 - result = silu_y1 * y2_fp32 * weight_fp32 - - # Store output (cast back to original dtype) - output_offset = token_idx_i64 * hidden_size + h_offset - tl.store( - output_ptr + output_offset + tl.arange(0, BLOCK_SIZE), - result.to(y1.dtype), - mask=h_mask, - ) - - # Stride to next token - token_idx += num_blocks - - -@triton.jit -def _weighted_swiglu_bwd_kernel( - grad_output_ptr, - input_ptr, - weights_ptr, - grad_input_ptr, - grad_weights_ptr, - num_tokens_ptr, - hidden_size: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """Triton kernel for weighted SwiGLU backward pass. - - Computes gradients with respect to input and weights for valid tokens only. 
- """ - pid = tl.program_id(axis=0) - num_blocks = tl.num_programs(axis=0) - - # Load actual number of tokens - num_tokens = tl.load(num_tokens_ptr) - - # Strided access - token_idx = pid - while token_idx < num_tokens: - token_idx_i64 = token_idx.to(tl.int64) - # Load weight for this token - weight = tl.load(weights_ptr + token_idx_i64) - - # Accumulator for weight gradient (fp32 for precision) - weight_grad_acc = 0.0 - - # Process hidden dimension - for h_offset in range(0, hidden_size, BLOCK_SIZE): - h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size - - # Load grad_output - grad_out_offset = token_idx_i64 * hidden_size + h_offset - grad_out = tl.load( - grad_output_ptr + grad_out_offset + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 - ) - - # Load input chunks - input_offset_1 = token_idx_i64 * (hidden_size * 2) + h_offset - input_offset_2 = token_idx_i64 * (hidden_size * 2) + hidden_size + h_offset - - y1 = tl.load( - input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 - ) - y2 = tl.load( - input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0 - ) - - # Cast to fp32 for sigmoid computation (required by Triton) - y1_fp32 = y1.to(tl.float32) - y2_fp32 = y2.to(tl.float32) - grad_out_fp32 = grad_out.to(tl.float32) - weight_fp32 = weight.to(tl.float32) - - # Forward calculations - sigmoid_y1 = tl.sigmoid(y1_fp32) - silu_y1 = y1_fp32 * sigmoid_y1 - - # Gradient for y1 (gate): d(SiLU(y1))/dy1 * y2 * weight * grad_out - # d(SiLU(y1))/dy1 = sigmoid(y1) * (1 + y1 * (1 - sigmoid(y1))) - dsilu_dy1 = sigmoid_y1 * (1.0 + y1_fp32 * (1.0 - sigmoid_y1)) - grad_y1 = grad_out_fp32 * weight_fp32 * dsilu_dy1 * y2_fp32 - - # Gradient for y2 (value): SiLU(y1) * weight * grad_out - grad_y2 = grad_out_fp32 * weight_fp32 * silu_y1 - - # Store input gradients (cast back to original dtype) - tl.store( - grad_input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), - grad_y1.to(y1.dtype), - mask=h_mask, - ) - tl.store( - grad_input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), - grad_y2.to(y2.dtype), - mask=h_mask, - ) - - # Accumulate weight gradient: swiglu(y) * grad_out - # swiglu(y) = silu_y1 * y2 - weight_grad_contribution = silu_y1 * y2_fp32 * grad_out_fp32 - weight_grad_acc += tl.sum(weight_grad_contribution) - - # Store weight gradient after processing all chunks - tl.store(grad_weights_ptr + token_idx_i64, weight_grad_acc) - - # Stride to next token - token_idx += num_blocks - - -def weighted_swiglu_triton(input, weights, num_tokens_tensor): - """Triton implementation of weighted SwiGLU forward pass. 
- - Args: - input: [total_tokens, hidden_size * 2] - weights: [total_tokens, 1] - num_tokens_tensor: Scalar tensor with actual token count - - Returns: - output: [total_tokens, hidden_size] - """ - assert input.dim() == 2, "Input must be 2D [total_tokens, hidden_size*2]" - assert weights.dim() == 2 and weights.size(1) == 1, "Weights must be [total_tokens, 1]" - - total_tokens, hidden_size_2 = input.shape - hidden_size = hidden_size_2 // 2 - - # Allocate output - output = torch.empty((total_tokens, hidden_size), dtype=input.dtype, device=input.device) - - # Launch kernel - BLOCK_SIZE = 128 - num_blocks = min(total_tokens, 4096) - grid = (num_blocks,) - - _weighted_swiglu_fwd_kernel[grid]( - input, - weights, - output, - num_tokens_tensor, - hidden_size=hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - return output - - -def weighted_swiglu_triton_back(grad_output, input, weights, num_tokens_tensor): - """Triton implementation of weighted SwiGLU backward pass. - - Args: - grad_output: [total_tokens, hidden_size] - input: [total_tokens, hidden_size * 2] - weights: [total_tokens, 1] - num_tokens_tensor: Scalar tensor with actual token count - - Returns: - grad_input: [total_tokens, hidden_size * 2] - grad_weights: [total_tokens, 1] - """ - total_tokens, hidden_size_2 = input.shape - hidden_size = hidden_size_2 // 2 - - # Allocate gradients - grad_input = torch.empty_like(input) - grad_weights = torch.empty_like(weights) - - # Launch kernel - BLOCK_SIZE = 128 - num_blocks = min(total_tokens, 4096) - grid = (num_blocks,) - - _weighted_swiglu_bwd_kernel[grid]( - grad_output, - input, - weights, - grad_input, - grad_weights, - num_tokens_tensor, - hidden_size=hidden_size, - BLOCK_SIZE=BLOCK_SIZE, - ) - - return grad_input, grad_weights diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 7d7d67a575e..12d7b2998fc 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -1145,14 +1145,6 @@ def remove_glu_interleaving(x: torch.Tensor) -> torch.Tensor: bias_parallel, permuted_probs, self.config.activation_func_fp8_input_store, - ( - tokens_per_expert.sum() - if ( - isinstance(tokens_per_expert, torch.Tensor) - and tokens_per_expert.is_cuda - ) - else None - ), ) elif self.activation_func == quick_gelu and self.config.gated_linear_unit: intermediate_parallel = weighted_bias_quick_geglu_impl( From 4ed48536d4fda04597e4bf0a38d11693ee0330c8 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 18 Mar 2026 14:31:35 +0800 Subject: [PATCH 50/57] Fixed some minor issues --- megatron/core/transformer/moe/fused_a2a.py | 1 - megatron/core/transformer/moe/token_dispatcher.py | 2 ++ megatron/training/utils.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/moe/fused_a2a.py b/megatron/core/transformer/moe/fused_a2a.py index aa13b9b5b5b..39f50a4a670 100644 --- a/megatron/core/transformer/moe/fused_a2a.py +++ b/megatron/core/transformer/moe/fused_a2a.py @@ -329,7 +329,6 @@ def reset_hybrid_ep_buffer(): _hybrid_ep_buffer = None -@internal_api class HybridEPDispatch(torch.autograd.Function): ''' Fused dispatch operation for permute + dispatch a2a + permute using the HybridEP backend diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index d0a46606645..4750fcd9b61 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -10,6 +10,7 @@ from megatron.core.config 
import is_experimental_enabled from megatron.core.fusions.fused_indices_converter import fused_indices_to_multihot from megatron.core.fusions.fused_pad_routing_map import fused_pad_routing_map +from megatron.core.jit import jit_fuser from megatron.core.tensor_parallel import ( all_to_all, gather_from_sequence_parallel_region, @@ -1452,6 +1453,7 @@ def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) - return routing_map, probs + @jit_fuser def dispatch_preprocess( self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor ): diff --git a/megatron/training/utils.py b/megatron/training/utils.py index d083f07e7ba..53802726cf8 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -562,8 +562,7 @@ def _broadcast(item): def _broadcast_cu_seqlens(cu_seqlens): dev = torch.cuda.current_device() n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) - n_tensor = torch.ones(1, dtype=torch.int64, device=dev) * n - # n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + n_tensor = torch.empty(1, dtype=torch.int64, device=dev).fill_(n) _broadcast(n_tensor) if n == 0: From a6875f13adaf1b5567c497f03815d953e0c20e95 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 18 Mar 2026 19:09:49 +0800 Subject: [PATCH 51/57] Fix the unit test --- megatron/core/transformer/moe/router.py | 4 +- .../core/transformer/moe/token_dispatcher.py | 2 - .../transformer/moe/test_paged_stashing.py | 235 ++++++++++++------ 3 files changed, 158 insertions(+), 83 deletions(-) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index af8e08d755f..e2dde33c9a3 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -55,7 +55,7 @@ def __init__( self.tp_cp_group = pg_collection.tp_cp self.tp_dp_cp_group = pg_collection.tp_dp_cp self.tp_ep_group = pg_collection.tp_ep - self.dp_group = pg_collection.dp + self.expt_dp_group = pg_collection.expt_dp # Initialize the gate weights. # TODO: Add support for GPU initialization, which requires updating the golden values. @@ -718,7 +718,7 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No layer_number=self.layer_number, num_local_experts=num_local_experts, tp_ep_group=self.tp_ep_group, - dp_group=self.dp_group, + dp_group=self.expt_dp_group, ) return probs, routing_map diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 4750fcd9b61..67f742212f4 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -1087,8 +1087,6 @@ def dispatch( if self.moe_expert_rank_capacity_factor is not None: over_budget = self.handle[8] != 0 # this is overflow_flag self.over_budget |= over_budget - if self.config.cuda_graph_impl == 'none': - assert not self.over_budget.item(), 'Over budget' if self.num_permuted_tokens is None: self.tokens_per_expert = tokens_per_expert.to(torch.int64) diff --git a/tests/unit_tests/transformer/moe/test_paged_stashing.py b/tests/unit_tests/transformer/moe/test_paged_stashing.py index 757c602a7ec..1a759bd55c5 100644 --- a/tests/unit_tests/transformer/moe/test_paged_stashing.py +++ b/tests/unit_tests/transformer/moe/test_paged_stashing.py @@ -1,38 +1,24 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-import copy -import dataclasses - import pytest import torch +import torch.nn.functional as F -from megatron.core import config, parallel_state +from megatron.core import config from megatron.core.fp8_utils import get_fp8_context from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.moe.moe_utils import get_align_size_for_quantization +from megatron.core.transformer.moe.experts import TEGroupedMLP from megatron.core.transformer.moe.paged_stash import ( paged_stash_init_chunk_handler, paged_stash_reset, ) from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version from megatron.training.initialize import _set_random_seed from tests.unit_tests.test_utilities import Utils -def token_permutation(token_dispatcher, hidden_states, probs, indices): - residual = hidden_states - hidden_states, probs = token_dispatcher.dispatch_preprocess(hidden_states, indices, probs) - hidden_states, probs = token_dispatcher.token_dispatch(hidden_states, probs) - return hidden_states, probs, residual - - -def token_unpermutation(token_dispatcher, hidden_states): - hidden_states = token_dispatcher.token_combine(hidden_states) - hidden_states = token_dispatcher.combine_postprocess(hidden_states) - return hidden_states, None - - class MoEModelTestContainer: def __init__( self, @@ -66,12 +52,6 @@ def __init__( expert_tensor_parallel_size=moe_tp_size, ) _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) - local_expert_indices_offset = ( - parallel_state.get_expert_model_parallel_rank() * self.num_local_experts - ) - self.local_expert_indices = [ - local_expert_indices_offset + i for i in range(self.num_local_experts) - ] self.config = TransformerConfig( tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size, @@ -110,15 +90,19 @@ def __init__( stash_modules=kwargs.get("stash_modules", None), moe_expert_rank_capacity_factor=kwargs.get("moe_expert_rank_capacity_factor", None), moe_router_padding_for_fp8=kwargs.get("moe_router_padding_for_fp8", True), + use_transformer_engine_op_fuser=kwargs.get("use_transformer_engine_op_fuser", False), + moe_mlp_glu_interleave_size=kwargs.get("moe_mlp_glu_interleave_size", None), + moe_router_padding_for_quantization=kwargs.get("moe_router_padding_for_quantization", False), + gated_linear_unit=kwargs.get("gated_linear_unit", False), + activation_func=kwargs.get("activation_func", F.gelu), + moe_router_force_biased=kwargs.get("moe_router_force_biased", None), ) - # init moe layers - self.moe_layers = [self.new_moe_layer(layer_number=i) for i in range(num_layers)] + self.moe_layer = self._create_moe_layer(layer_number=0) - def new_moe_layer(self, **kargs): + def _create_moe_layer(self, layer_number=0): transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=self.config.num_moe_experts, moe_grouped_gemm=True ) - layer_number = kargs.get("layer_number", 0) quantization_context = get_fp8_context(self.config, layer_number, is_init=True) with quantization_context: moe_layer = ( @@ -130,62 +114,50 @@ def new_moe_layer(self, **kargs): return moe_layer def zero_grad(self): - for moe_layer in self.moe_layers: - moe_layer.zero_grad() + self.moe_layer.zero_grad() def __del__(self): torch.distributed.barrier() torch.cuda.synchronize() Utils.destroy_model_parallel() - @pytest.mark.internal - def dispatcher_dropless_test(self, 
inp_hidden_states=None): - moe_layers = self.moe_layers - - inp_hidden_states = inp_hidden_states.cuda() - # Permute and then unpermute data are supposed to restore original data - inp_hidden_states.requires_grad = True - hidden_states = inp_hidden_states - for i, moe_layer in enumerate(moe_layers): - quantization_context = get_fp8_context(self.config) - with quantization_context: - probs, indices = moe_layer.router(hidden_states) - probs = torch.ones_like(probs) / moe_layer.router.topk - - (dispatched_input, probs, residual) = token_permutation( - moe_layer.token_dispatcher, hidden_states, probs, indices - ) - output, _ = moe_layer.routed_experts_compute(dispatched_input, probs, residual) - output, _ = token_unpermutation(moe_layer.token_dispatcher, output) - hidden_states = output - torch.autograd.backward(output, inp_hidden_states) - return output, inp_hidden_states.grad - - def set_params(self): - # TODO: Set consistent parameters for various parallelisms. - raise NotImplementedError + def forward_backward(self, hidden_states): + """Run one forward and backward pass through the MoE layer. + + Returns: + output: MoE layer output (detached). + hidden_states_grad: Gradient w.r.t. hidden_states. + routing_map: Token-to-expert routing map from the dispatcher (after forward). + tokens_per_expert: Number of tokens per local expert on this EP rank (after forward). + """ + hidden_states = hidden_states.cuda().requires_grad_(True) + quantization_context = get_fp8_context(self.config) + with quantization_context: + output, _ = self.moe_layer(hidden_states) + # Capture routing_map and tokens_per_expert after forward (before backward) + comm = getattr(self.moe_layer.token_dispatcher, "_comm_manager", None) + routing_map = getattr(comm, "routing_map", None) + tokens_per_expert = ( + comm.get_number_of_tokens_per_expert() + if comm is not None and hasattr(comm, "get_number_of_tokens_per_expert") + else None + ) + # Use contiguous gradient to avoid non-contiguous grad in HybridEP combine backward + # (output.sum().backward() produces a broadcast gradient that is non-contiguous) + output.backward(torch.ones_like(output)) + return output.detach(), hidden_states.grad, routing_map, tokens_per_expert def destroy(self): Utils.destroy_model_parallel() -permute_fusion_params = [False] -if is_te_min_version("2.1.0"): - permute_fusion_params.append(True) - - -def is_deep_ep_available(): - from megatron.core.transformer.moe.fused_a2a import HAVE_DEEP_EP - return HAVE_DEEP_EP - - def is_hybrid_ep_available(): from megatron.core.transformer.moe.fused_a2a import HAVE_HYBRIDEP return HAVE_HYBRIDEP @pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available") -class TestFlexDispatcher: +class TestPagedStashing: def setup_method(self, method): pass @@ -219,29 +191,134 @@ def test_forward_backward(self): moe_paged_stash=True, stash_modules=["expert_fc1", "moe_act", "expert_fc2"], moe_expert_rank_capacity_factor=1.5, + use_transformer_engine_op_fuser=True, + moe_mlp_glu_interleave_size=32, + moe_router_padding_for_quantization=True, + gated_linear_unit=True, + activation_func=F.silu, ) - bs = 32 - seql = 8 - - inp_hidden_states = torch.randn( - (bs, seql, container.moe_layers[0].config.hidden_size), dtype=torch.bfloat16 + if not isinstance(container.moe_layer.experts, TEGroupedMLP) or not container.moe_layer.experts._is_fused_impl_supported(): + container.destroy() + pytest.skip("TEGroupedMLP fused impl not supported") + + # [sequence_length, batch_size, hidden_size] for MoELayer.forward + 
seq_length = 1024 + batch_size = 1 + hidden_size = container.config.hidden_size + hidden_states = torch.randn( + (seq_length, batch_size, hidden_size), dtype=torch.bfloat16 ) - # First iteration to capture schedule, calculate capacity, etc. + + # First iteration: capture schedule, capacity, etc. paged_stash_reset(True) paged_stash_init_chunk_handler(1, 0) - output_ref, inp_hidden_states_grad_ref = container.dispatcher_dropless_test( - inp_hidden_states + output_ref, hidden_states_grad_ref, routing_map_ref, tokens_per_expert_ref = ( + container.forward_backward(hidden_states) ) container.zero_grad() - # Second iteration to run with paged stash. + # Second iteration: run with paged stash. paged_stash_reset(True) paged_stash_init_chunk_handler(1, 0) - output, inp_hidden_states_grad = container.dispatcher_dropless_test(inp_hidden_states) + output, hidden_states_grad, routing_map, tokens_per_expert = container.forward_backward( + hidden_states + ) - # verify output and input gradient are the same as the first iteration. + # Verify output and input gradient match the first iteration. torch.testing.assert_close(output, output_ref, atol=1e-4, rtol=1e-4) torch.testing.assert_close( - inp_hidden_states_grad, inp_hidden_states_grad_ref, atol=1e-4, rtol=1e-4 + hidden_states_grad, hidden_states_grad_ref, atol=1e-4, rtol=1e-4 + ) + # Routing and token counts available after forward (e.g. for debugging or further checks) + if routing_map is not None and tokens_per_expert is not None: + num_tokens_per_ep_rank = tokens_per_expert.sum().item() + assert num_tokens_per_ep_rank > 0 + assert routing_map_ref is not None and tokens_per_expert_ref is not None + torch.testing.assert_close(tokens_per_expert, tokens_per_expert_ref) + + +@pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available") +class TestPagedStashingOverBudget: + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + def test_overload_factor_and_over_budget(self): + """Test budget computation (same as token_dispatcher lines 1017-1025) and assert + over_budget flag is set when tokens_per_ep_rank exceeds budget.""" + if not is_hybrid_ep_available(): + pytest.skip("Hybrid EP is not available") + + config.ENABLE_EXPERIMENTAL = True + + container = MoEModelTestContainer( + tp_size=1, + ep_size=4, + pp_size=1, + num_moe_experts=8, + num_layers=1, + moe_router_topk=4, + moe_router_load_balancing_type="aux_loss", + moe_token_dispatcher_type="flex", + moe_permute_fusion=True, + hidden_size=1024, + moe_flex_dispatcher_backend="hybridep", + test_dtype=torch.bfloat16, + moe_grouped_gemm=True, + moe_use_device_initiated_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, + moe_paged_stash=True, + stash_modules=["expert_fc1", "moe_act", "expert_fc2"], + moe_expert_rank_capacity_factor=1.0, + use_transformer_engine_op_fuser=True, + moe_mlp_glu_interleave_size=32, + moe_router_padding_for_quantization=True, + gated_linear_unit=True, + activation_func=F.silu, + moe_router_force_biased=1, ) + if not isinstance(container.moe_layer.experts, TEGroupedMLP) or not container.moe_layer.experts._is_fused_impl_supported(): + container.destroy() + pytest.skip("TEGroupedMLP fused impl not supported") + + seq_length = 4096 + batch_size = 1 + topk = container.config.moe_router_topk + capacity_factor = container.config.moe_expert_rank_capacity_factor + hidden_size = 
container.config.hidden_size + hidden_states = torch.randn( + (seq_length, batch_size, hidden_size), dtype=torch.bfloat16 + ) + + # Budget computed like token_dispatcher._HybridEPManager.setup_metadata (lines 1017-1025) + num_tokens = seq_length * batch_size + pad_multiple = get_align_size_for_quantization(container.config) + budget = int(num_tokens * topk * capacity_factor) + budget += -budget % pad_multiple + + paged_stash_reset(True) + paged_stash_init_chunk_handler(1, 0) + _, _, _, tokens_per_expert = container.forward_backward(hidden_states) + + assert tokens_per_expert is not None + tokens_per_ep_rank = tokens_per_expert.sum().item() + overload_factor = tokens_per_ep_rank / (seq_length * topk) + over_budget_tensor = container.moe_layer.token_dispatcher.check_over_budget() + over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False + + # When tokens_per_ep_rank > budget, over_budget flag must be raised + if tokens_per_ep_rank >= budget: + assert over_budget, ( + f"tokens_per_ep_rank ({tokens_per_ep_rank}) > budget ({budget}), " + "but over_budget flag was not set" + ) + else: + assert not over_budget, ( + f"tokens_per_ep_rank ({tokens_per_ep_rank}) <= budget ({budget}), " + "but over_budget flag was set" + ) From acda6d16d33a7ad7b3e4a59fd60ac8148f1db40f Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Sat, 14 Mar 2026 09:19:29 +0800 Subject: [PATCH 52/57] Initial commit for spill to cpu feature --- megatron/core/transformer/moe/paged_stash.py | 381 +++++++++++-------- 1 file changed, 222 insertions(+), 159 deletions(-) diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index ccae6ca9c12..45769b2b39c 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -17,137 +17,185 @@ class PagedStashBuffer: """ A paged stash buffer with page-level memory management. + Supports both CUDA and optional pinned host buffer for overflow fallback. - The buffer is organized as [num_pages, page_size, hidden_size]. - Uses a free list (circular buffer) to track available pages. + Buffers are organized as [num_pages, page_size, hidden_size]. + Uses per-buffer free lists (circular buffer) tracked as two-element state: [0]=CUDA, [1]=host. """ - def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype): + def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype, num_tokens_host=0): """ Args: - num_tokens: Maximum number of tokens the buffer can hold + num_tokens: Maximum number of tokens the CUDA buffer can hold hidden_size: Hidden dimension size page_size: Number of tokens per page device: Device for the buffer overflow: Overflow flag tensor (shared across all buffers) dtype: Data type + num_tokens_host: If > 0, allocate pinned host buffer with this many tokens for spillover. """ self.hidden_size = hidden_size self.page_size = page_size - self.num_pages = (num_tokens + page_size - 1) // page_size # Ceiling division - self.total_tokens = self.num_pages * page_size - - # Create 2D buffer [total_tokens, hidden_size] - # Organized as pages: [page_0_tokens, page_1_tokens, ...] 
- if os.getenv('PAGED_STASH_TO_CPU', '0') == '1': - self.buffer = torch.empty( - (self.total_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True + self.device = device + self.dtype = dtype + self.overflow = overflow # GPU flag (shared) + + # CUDA buffer + self.num_cuda_pages = (num_tokens + page_size - 1) // page_size + self.total_cuda_tokens = self.num_cuda_pages * page_size + self.cuda_buffer = torch.empty( + (self.total_cuda_tokens, hidden_size), dtype=dtype, device=device + ) + + # Host buffer (pinned), optional + self.num_host_pages = (num_tokens_host + page_size - 1) // page_size if num_tokens_host > 0 else 0 + self.total_host_tokens = self.num_host_pages * page_size if self.num_host_pages > 0 else 0 + if self.num_host_pages > 0: + self.host_buffer = torch.empty( + (self.total_host_tokens, hidden_size), dtype=dtype, device='cpu', pin_memory=True ) else: - self.buffer = torch.empty((self.total_tokens, hidden_size), dtype=dtype, device=device) + self.host_buffer = None - self.overflow = overflow # GPU flag (shared) - self.device = device - self.dtype = dtype + # Free list state: shape (2,) index 0 = CUDA, 1 = host (all in device memory for kernel) + self.free_list_head = torch.zeros(2, dtype=torch.int64, device=device) + self.free_list_tail = torch.tensor( + [self.num_cuda_pages, self.num_host_pages], dtype=torch.int64, device=device + ) + self.free_list_capacity = torch.tensor( + [self.num_cuda_pages, self.num_host_pages], dtype=torch.int64, device=device + ) - # Free list as circular buffer: stores available page IDs - self.free_list = torch.arange(self.num_pages, dtype=torch.int64, device=device) + # Free list arrays (device memory): page IDs for each buffer + self.free_list_cuda = torch.arange(self.num_cuda_pages, dtype=torch.int64, device=device) + if self.num_host_pages > 0: + self.free_list_host = torch.arange(self.num_host_pages, dtype=torch.int64, device=device) + else: + self.free_list_host = torch.empty(0, dtype=torch.int64, device=device) - # Head and tail pointers for free_list circular buffer - self.free_list_head = torch.zeros( - 1, dtype=torch.int64, device=device - ) # Read pointer (allocation) - self.free_list_tail = self.num_pages * torch.ones( - 1, dtype=torch.int64, device=device - ) # Write pointer (deallocation) + # Pre-allocated reset values (CUDA graph safe: no allocation in reset()) + self._reset_tail = torch.tensor( + [self.num_cuda_pages, self.num_host_pages], + dtype=torch.int64, + device=device, + ) + self._reset_free_list_cuda = torch.arange( + self.num_cuda_pages, dtype=torch.int64, device=device + ) + if self.num_host_pages > 0: + self._reset_free_list_host = torch.arange( + self.num_host_pages, dtype=torch.int64, device=device + ) + else: + self._reset_free_list_host = None - # Capacity of free list - self.free_list_capacity = self.num_pages * torch.ones(1, dtype=torch.int64, device=device) + # Legacy single-buffer API (used by kernels when host disabled): same as cuda + self.free_list = self.free_list_cuda def reset(self): - """Reset the paged buffer - reinitialize free list.""" - self.free_list.copy_(torch.arange(self.num_pages, dtype=torch.int64, device=self.device)) + """Reset both CUDA and host free lists (CUDA graph safe: no new allocations).""" + self.free_list_cuda.copy_(self._reset_free_list_cuda) self.free_list_head.zero_() - self.free_list_tail.fill_(self.num_pages) + self.free_list_tail.copy_(self._reset_tail) + if self._reset_free_list_host is not None: + self.free_list_host.copy_(self._reset_free_list_host) def 
__repr__(self): return ( - f"PagedStashBuffer(num_pages={self.num_pages}, page_size={self.page_size}, " - f"hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" + f"PagedStashBuffer(num_cuda_pages={self.num_cuda_pages}, num_host_pages={self.num_host_pages}, " + f"page_size={self.page_size}, hidden_size={self.hidden_size}, device={self.device}, dtype={self.dtype})" ) @triton.jit def _paged_stash_copy_kernel( src_ptr, - dst_ptr, + cuda_dst_ptr, + host_dst_ptr, num_tokens_ptr, - free_list_ptr, - free_list_head_ptr, # Read-only: current head position - free_list_tail_ptr, # Read-only: current tail position (for overflow check) + free_list_cuda_ptr, + free_list_host_ptr, + free_list_head_ptr, # shape (2,): [cuda_head, host_head] + free_list_tail_ptr, # shape (2,) free_list_capacity_ptr, - page_record_ptr, # Output: records which pages were used + page_record_ptr, overflow_ptr, - new_free_list_head_ptr, # Output: new head position (written by kernel) + spilled_to_host_ptr, # Output: 0 = stored in CUDA, 1 = stored in host or overflow + new_free_list_head_ptr, # Output: shape (2,) updated heads PAGE_SIZE: tl.constexpr, HIDDEN_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, + HAS_HOST_BUFFER: tl.constexpr, ): - """Triton kernel to copy tokens to paged stash buffer. - - Allocates pages from free list (reads from head, advances head). - Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. - Grid: (num_blocks,) where blocks process tokens in a strided pattern. - Writes new head to temporary tensor to avoid race conditions. - """ + """Copy tokens to paged stash: try CUDA first (fast path), then host if CUDA full.""" pid = tl.program_id(axis=0) num_blocks = tl.num_programs(axis=0) - # Load parameters - num_tokens = tl.load(num_tokens_ptr) - free_list_head = tl.load(free_list_head_ptr) - free_list_tail = tl.load(free_list_tail_ptr) - free_list_capacity = tl.load(free_list_capacity_ptr) - - # Check available pages (unwrapped indices: simple subtraction, no modulo needed) - avail_pages = free_list_tail - free_list_head + # Load overflow first (get in flight early); branch on it only before any write + overflow = tl.load(overflow_ptr) - # Calculate required pages + num_tokens = tl.load(num_tokens_ptr) required_pages = tl.cdiv(num_tokens, PAGE_SIZE) - overflow_detected = avail_pages < required_pages - - # Only block 0 writes overflow flag - if pid == 0 and overflow_detected: - tl.store(overflow_ptr, 1) - # All blocks return early if overflow - if overflow_detected: + # Common case: load only CUDA state (and head_host for output when use_cuda) + head_cuda = tl.load(free_list_head_ptr) + head_host = tl.load(free_list_head_ptr + 1) + tail_cuda = tl.load(free_list_tail_ptr) + cap_cuda = tl.load(free_list_capacity_ptr) + + avail_cuda = tail_cuda - head_cuda + use_cuda = avail_cuda >= required_pages + + # Assume CUDA path: set everything for GPU stash + spill = 0 + dst_ptr = cuda_dst_ptr + free_list_ptr = free_list_cuda_ptr + head = head_cuda + cap = cap_cuda + new_head_cuda = head_cuda + required_pages + new_head_host = head_host + + if overflow == 1: return - # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] 
+ # Only when CUDA is full: load host state and maybe switch to host + if not use_cuda: + tail_host = tl.load(free_list_tail_ptr + 1) + cap_host = tl.load(free_list_capacity_ptr + 1) + use_host = HAS_HOST_BUFFER == 1 and (tail_host - head_host) >= required_pages + if use_host: + spill = 1 + dst_ptr = host_dst_ptr + free_list_ptr = free_list_host_ptr + head = head_host + cap = cap_host + new_head_cuda = head_cuda + new_head_host = head_host + required_pages + else: + if pid == 0: + tl.store(overflow_ptr, 1) + tl.store(spilled_to_host_ptr, 1) + tl.store(new_free_list_head_ptr, head_cuda) + tl.store(new_free_list_head_ptr + 1, head_host) + return + + if pid == 0: + tl.store(spilled_to_host_ptr, spill) + + # Copy loop: strided over tokens token_idx = pid while token_idx < num_tokens: - # Determine which page this token belongs to page_slot = token_idx // PAGE_SIZE token_in_page = token_idx % PAGE_SIZE - - # Read page ID from free list (with wraparound) - free_list_idx = (free_list_head + page_slot) % free_list_capacity + free_list_idx = (head + page_slot) % cap page_id = tl.load(free_list_ptr + free_list_idx) - - # First token in page: record the page ID (only if this block handles token 0 of the page) if token_in_page == 0: tl.store(page_record_ptr + page_slot, page_id) - - # Calculate destination address in paged buffer dst_token_idx = page_id * PAGE_SIZE + token_in_page - # Copy token data (2D: hidden dimension) elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 num_iters = elements_per_thread + (1 if need_mask else 0) - - # Use int64 for address math to avoid int32 overflow when indices get large. token_idx_i64 = token_idx.to(tl.int64) dst_token_idx_i64 = dst_token_idx.to(tl.int64) src_base = src_ptr + token_idx_i64 * HIDDEN_SIZE @@ -164,67 +212,80 @@ def _paged_stash_copy_kernel( hidden_offsets = tl.arange(0, BLOCK_SIZE) + iter * BLOCK_SIZE data = tl.load(src_base + hidden_offsets) tl.store(dst_base + hidden_offsets, data) - - # Stride to next token for this block token_idx += num_blocks - # Calculate and store new free list head (only block 0) - # We consumed pages, so advance head forward (unwrapped: no modulo) - # Write to temporary tensor to avoid race conditions if pid == 0: - new_head = free_list_head + required_pages - tl.store(new_free_list_head_ptr, new_head) + tl.store(new_free_list_head_ptr, new_head_cuda) + tl.store(new_free_list_head_ptr + 1, new_head_host) @triton.jit def _paged_stash_pop_kernel( - src_ptr, + cuda_src_ptr, + host_src_ptr, dst_ptr, num_tokens_ptr, - page_record_ptr, # Input: which pages to read - free_list_ptr, - free_list_head_ptr, # Read-only: current head position (not used) - free_list_tail_ptr, # Read-only: current tail position + page_record_ptr, + spilled_to_host_ptr, # 0 = read from CUDA, 1 = read from host + overflow_ptr, + free_list_cuda_ptr, + free_list_host_ptr, + free_list_tail_ptr, # shape (2,) free_list_capacity_ptr, - new_free_list_tail_ptr, # Output: new tail position (written by kernel) + new_free_list_tail_ptr, # Output: shape (2,) updated tails PAGE_SIZE: tl.constexpr, HIDDEN_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - """Triton kernel to reload tokens from paged stash buffer. - - Returns pages to free list (writes to tail, advances tail). - Uses strided access pattern: block i handles tokens [i, i+num_blocks, i+2*num_blocks, ...]. - Grid: (num_blocks,) where blocks process tokens in a strided pattern. - Writes new tail to temporary tensor to avoid race conditions. 
- """ + """Reload tokens from paged stash; CUDA path fast, host path when spilled_to_host.""" pid = tl.program_id(axis=0) num_blocks = tl.num_programs(axis=0) - # Load parameters + # Load overflow first (get in flight early); branch on it only before any write + overflow = tl.load(overflow_ptr) + num_tokens = tl.load(num_tokens_ptr) - free_list_tail = tl.load(free_list_tail_ptr) - free_list_capacity = tl.load(free_list_capacity_ptr) + spill = tl.load(spilled_to_host_ptr) + required_pages = tl.cdiv(num_tokens, PAGE_SIZE) + + # Common case: load only CUDA state (and tail_host for output when spill=0) + tail_cuda = tl.load(free_list_tail_ptr) + tail_host = tl.load(free_list_tail_ptr + 1) + cap_cuda = tl.load(free_list_capacity_ptr) + + # Assume CUDA path + src_ptr = cuda_src_ptr + free_list_ptr = free_list_cuda_ptr + tail = tail_cuda + cap = cap_cuda + new_tail_cuda = tail_cuda + required_pages + new_tail_host = tail_host + + # Only when spilled to host: load host state and switch + if spill == 1: + cap_host = tl.load(free_list_capacity_ptr + 1) + if cap_host == 0: + return + src_ptr = host_src_ptr + free_list_ptr = free_list_host_ptr + tail = tail_host + cap = cap_host + new_tail_cuda = tail_cuda + new_tail_host = tail_host + required_pages + + if overflow == 1: + return - # Strided access: block pid handles tokens [pid, pid+num_blocks, pid+2*num_blocks, ...] token_idx = pid while token_idx < num_tokens: - # Determine which page this token belongs to page_slot = token_idx // PAGE_SIZE token_in_page = token_idx % PAGE_SIZE - - # Read page ID from page record page_id = tl.load(page_record_ptr + page_slot) - - # Calculate source address in paged buffer src_token_idx = page_id * PAGE_SIZE + token_in_page - # Copy token data (2D: hidden dimension) elements_per_thread = HIDDEN_SIZE // BLOCK_SIZE need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0 num_iters = elements_per_thread + (1 if need_mask else 0) - - # Use int64 for address math to avoid int32 overflow when indices get large. 
src_token_idx_i64 = src_token_idx.to(tl.int64) token_idx_i64 = token_idx.to(tl.int64) src_base = src_ptr + src_token_idx_i64 * HIDDEN_SIZE @@ -242,22 +303,14 @@ def _paged_stash_pop_kernel( data = tl.load(src_base + hidden_offsets) tl.store(dst_base + hidden_offsets, data) - # First token in page: release page back to free list if token_in_page == 0: - # Write page ID back to free list at tail position (with wraparound) - write_idx = (free_list_tail + page_slot) % free_list_capacity + write_idx = (tail + page_slot) % cap tl.store(free_list_ptr + write_idx, page_id) - - # Stride to next token for this block token_idx += num_blocks - # Calculate and store new free list tail (only block 0) - # We returned pages, so advance tail forward (unwrapped: no modulo) - # Write to temporary tensor to avoid race conditions if pid == 0: - required_pages = tl.cdiv(num_tokens, PAGE_SIZE) - new_tail = free_list_tail + required_pages - tl.store(new_free_list_tail_ptr, new_tail) + tl.store(new_free_list_tail_ptr, new_tail_cuda) + tl.store(new_free_list_tail_ptr + 1, new_tail_host) class PagedTensor: @@ -315,6 +368,8 @@ def __init__( # Page record: stores which pages are being used for this tensor self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device) + # Set by copy kernel: 0 = data in CUDA stash, 1 = data in host (pinned) stash + self.spilled_to_host = torch.zeros(1, dtype=torch.int64, device=self.device) @property def schedule_layer(self): @@ -322,12 +377,7 @@ def schedule_layer(self): return self.schedule_layer_no def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): - """Offload the paged tensor to paged stash buffer. - - Args: - paged_stash_buffer: The paged stash buffer to offload to - max_blocks: Maximum number of blocks for Triton kernel - """ + """Offload the paged tensor to paged stash buffer (CUDA or host if CUDA full).""" self._tensor = self._tensor.contiguous() if self.num_tokens_tensor.dim() == 0: self.num_tokens_tensor = self.num_tokens_tensor.reshape(1) @@ -338,49 +388,48 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048 num_tokens_tensor = self.num_tokens_tensor max_num_tokens = self.max_num_tokens - # Get 1D tensor tensor_to_copy = self._tensor - - # Determine grid size BLOCK_SIZE = GLOBAL_BLOCK_SIZE num_blocks = min(max_num_tokens, max_blocks) grid = (num_blocks,) - # Create temporary tensor for new head - new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device) + new_free_list_head = torch.empty(2, dtype=torch.int64, device=self.device) + has_host = 1 if paged_stash_buffer.host_buffer is not None else 0 + host_dst = ( + paged_stash_buffer.host_buffer + if paged_stash_buffer.host_buffer is not None + else paged_stash_buffer.cuda_buffer + ) - # Launch paged stash copy kernel _paged_stash_copy_kernel[grid]( - tensor_to_copy.view(paged_stash_buffer.buffer.dtype), - paged_stash_buffer.buffer, + tensor_to_copy.view(paged_stash_buffer.cuda_buffer.dtype), + paged_stash_buffer.cuda_buffer, + host_dst, num_tokens_tensor, - paged_stash_buffer.free_list, + paged_stash_buffer.free_list_cuda, + paged_stash_buffer.free_list_host, paged_stash_buffer.free_list_head, paged_stash_buffer.free_list_tail, paged_stash_buffer.free_list_capacity, - self.page_record, # Triton kernel will populate page_record + self.page_record, paged_stash_buffer.overflow, + self.spilled_to_host, new_free_list_head, PAGE_SIZE=self.page_size, HIDDEN_SIZE=self.hidden_size, BLOCK_SIZE=BLOCK_SIZE, + 
HAS_HOST_BUFFER=has_host, ) - - # Update free list head + # if self.spilled_to_host.item() == 1: + # print(f"PagedTensor {self.layer_name} spilled to host", flush=True) + # else: + # print(f"PagedTensor {self.layer_name} stashed to cuda", flush=True) paged_stash_buffer.free_list_head.copy_(new_free_list_head) - - # Save reference to original tensor self._original_tensor = self._tensor self._tensor = None def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048): - """Reload the paged tensor from paged stash buffer. - - Args: - paged_stash_buffer: The paged stash buffer to reload from - max_blocks: Maximum number of blocks for Triton kernel - """ - # Allocate output tensor + """Reload the paged tensor from paged stash buffer (CUDA or host from spilled_to_host).""" self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device) tensor_to_reload = self._tensor @@ -390,22 +439,26 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 else: num_tokens_tensor = self.num_tokens_tensor max_num_tokens = self.max_num_tokens - # Determine grid size BLOCK_SIZE = GLOBAL_BLOCK_SIZE num_blocks = min(max_num_tokens, max_blocks) grid = (num_blocks,) - # Create temporary tensor for new tail - new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device) - - # Launch paged stash pop kernel + new_free_list_tail = torch.empty(2, dtype=torch.int64, device=self.device) + host_src = ( + paged_stash_buffer.host_buffer + if paged_stash_buffer.host_buffer is not None + else paged_stash_buffer.cuda_buffer + ) _paged_stash_pop_kernel[grid]( - paged_stash_buffer.buffer, - tensor_to_reload.view(paged_stash_buffer.buffer.dtype), + paged_stash_buffer.cuda_buffer, + host_src, + tensor_to_reload.view(paged_stash_buffer.cuda_buffer.dtype), num_tokens_tensor, - self.page_record, # Triton kernel will read from page_record - paged_stash_buffer.free_list, - paged_stash_buffer.free_list_head, + self.page_record, + self.spilled_to_host, + paged_stash_buffer.overflow, + paged_stash_buffer.free_list_cuda, + paged_stash_buffer.free_list_host, paged_stash_buffer.free_list_tail, paged_stash_buffer.free_list_capacity, new_free_list_tail, @@ -414,7 +467,6 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204 BLOCK_SIZE=BLOCK_SIZE, ) - # Update free list tail paged_stash_buffer.free_list_tail.copy_(new_free_list_tail) @@ -682,17 +734,28 @@ def allocate_stash_buffers(self, stash_buffer_size_factor=1.10): if not max_tokens_dict: max_tokens_dict = self.max_tokens_across_vp_stages + cpu_size_factor = float(os.getenv('STASH_BUFFER_CPU_SIZE_FACTOR', '0')) for dtype, hidden_size in max_tokens_dict: if dtype not in self.stash_buffers: self.stash_buffers[dtype] = {} assert hidden_size not in self.stash_buffers[dtype] - num_tokens = int( - max_tokens_dict[dtype, hidden_size] * scale - ) + num_tokens = int(max_tokens_dict[dtype, hidden_size] * scale) + num_tokens_host = int(max_tokens_dict[dtype, hidden_size] * cpu_size_factor) if cpu_size_factor > 0 else 0 + buf_dtype = torch.uint8 if dtype in [torch.float8_e4m3fn, torch.float8_e8m0fnu] else dtype self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( - num_tokens, hidden_size, self.page_size, self.device, self.overflow, torch.uint8 if dtype in [torch.float8_e4m3fn, torch.float8_e8m0fnu] else dtype + num_tokens, + hidden_size, + self.page_size, + self.device, + self.overflow, + buf_dtype, + num_tokens_host=num_tokens_host, ) - print (f'allocate_stash_buffers num_tokens: 
{self.stash_buffers[dtype][hidden_size].buffer.shape}-{self.stash_buffers[dtype][hidden_size].dtype} ({dtype})') + sb = self.stash_buffers[dtype][hidden_size] + msg = f'allocate_stash_buffers cuda: {sb.cuda_buffer.shape}' + if sb.host_buffer is not None: + msg += f' host: {sb.host_buffer.shape}' + print(f'{msg} dtype={sb.dtype} ({dtype})') def update_pp_schedule(self, vp_stage, layer_no=None, microbatch_no=None): """Update the pp schedule.""" From 6ddc49b425aba5a9c28a28d46a16bed68f05e44f Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 18 Mar 2026 20:17:39 +0800 Subject: [PATCH 53/57] Move paged stashing knobs from env vars to transformer_config knobs --- megatron/core/pipeline_parallel/schedules.py | 6 +- megatron/core/transformer/moe/paged_stash.py | 58 +++++++++++++------ .../core/transformer/transformer_config.py | 7 +++ .../transformer/moe/test_paged_stashing.py | 1 - 4 files changed, 49 insertions(+), 23 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 04b22bfc297..03dbcf1f79c 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -591,7 +591,7 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - paged_stash_reset(enabled=config.moe_paged_stash and not forward_only) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) no_sync_func = config.no_sync_func if no_sync_func is None: @@ -1052,7 +1052,7 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" - paged_stash_reset(enabled=config.moe_paged_stash and not forward_only) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -2237,7 +2237,7 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - paged_stash_reset(enabled=config.moe_paged_stash and not forward_only) + paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config) # Disable async grad reductions no_sync_func = config.no_sync_func diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 45769b2b39c..decb012def8 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -1,8 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -import os from contextlib import nullcontext -from typing import Any +from typing import Any, Tuple, Union import torch import triton @@ -14,6 +13,16 @@ GLOBAL_BLOCK_SIZE = 1024 SCALE_INV_BLOCK_SIZE = 32 + +def _normalize_stash_buffer_size_factor( + value: Union[float, Tuple[float, float], list], +) -> Tuple[float, float]: + """Normalize stash_buffer_size_factor to (cuda_factor, cpu_factor).""" + if isinstance(value, (list, tuple)) and len(value) == 2: + return (float(value[0]), float(value[1])) + return (float(value), 0.0) + + class PagedStashBuffer: """ A paged stash buffer with page-level memory management. 
@@ -89,9 +98,6 @@ def __init__(self, num_tokens, hidden_size, page_size, device, overflow, dtype, else: self._reset_free_list_host = None - # Legacy single-buffer API (used by kernels when host disabled): same as cuda - self.free_list = self.free_list_cuda - def reset(self): """Reset both CUDA and host free lists (CUDA graph safe: no new allocations).""" self.free_list_cuda.copy_(self._reset_free_list_cuda) @@ -610,8 +616,8 @@ def __init__(self): self.overflow = None self.device = None - # Page size for paged memory management - self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page + # Page size for paged memory management (default; overwritten from config in paged_stash_reset) + self.page_size = 64 @property def pack_stream(self): @@ -713,28 +719,33 @@ def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): f"{self.paged_tensors_to_reload[pp_schedule_layer]}" ) - def allocate_stash_buffers(self, stash_buffer_size_factor=1.10): - """Allocate stash buffers organized by [dtype][hidden_size].""" + def allocate_stash_buffers( + self, stash_buffer_size_factor: Union[float, Tuple[float, float]] = 1.10 + ): + """Allocate stash buffers organized by [dtype][hidden_size]. + + stash_buffer_size_factor: single float for CUDA factor (CPU disabled), or (cuda, cpu). + """ self.stash_buffers = {} self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) - # stash_buffer_size_factor controls both which sizing signal to use and how much headroom - # to allocate: + cuda_factor, cpu_size_factor = _normalize_stash_buffer_size_factor(stash_buffer_size_factor) + + # cuda_factor controls both which sizing signal to use and how much headroom to allocate: # - positive: size based on avg_num_tokens-derived maxima # - negative: size based on actual num_tokens-derived maxima (legacy behavior) - # In both cases we scale by abs(stash_buffer_size_factor). - if stash_buffer_size_factor >= 0: + # In both cases we scale by abs(cuda_factor). + if cuda_factor >= 0: max_tokens_dict = self.max_avg_tokens_across_vp_stages - scale = stash_buffer_size_factor + scale = cuda_factor else: max_tokens_dict = self.max_tokens_across_vp_stages - scale = -stash_buffer_size_factor + scale = -cuda_factor # Fallback safety: if avg-based dict is not available/populated yet, use actual-max dict. if not max_tokens_dict: max_tokens_dict = self.max_tokens_across_vp_stages - cpu_size_factor = float(os.getenv('STASH_BUFFER_CPU_SIZE_FACTOR', '0')) for dtype, hidden_size in max_tokens_dict: if dtype not in self.stash_buffers: self.stash_buffers[dtype] = {} @@ -1011,11 +1022,18 @@ def paged_stash_set_last_layer(is_last_layer=False): return stash_manager._last_layer = is_last_layer -def paged_stash_reset(enabled=True): - """Reset the chunk handler, called at the start of a training iteration.""" +def paged_stash_reset(enabled=True, config=None): + """Reset the chunk handler, called at the start of a training iteration. + + config: optional TransformerConfig; if provided, stash_buffer_size_factor and + moe_paged_stash_page_size are read from it. Otherwise defaults to 1.10 (CUDA only) + and page_size 64. 
+ """ stash_manager = PagedStashManager.get_instance() stash_manager.enabled = enabled stash_manager.iteration += 1 + if config is not None: + stash_manager.page_size = config.moe_paged_stash_page_size # current layer and microbatch for each vp stage for forward pass stash_manager.current_schedule_index = 0 @@ -1027,7 +1045,9 @@ def paged_stash_reset(enabled=True): elif stash_manager.status == 'capture': stash_manager.status = 'captured' print (f'schedule {stash_manager._pp_schedule}') - stash_buffer_size_factor = float(os.getenv('STASH_BUFFER_SIZE_FACTOR', '1.10')) + stash_buffer_size_factor = ( + config.stash_buffer_size_factor if config is not None else 1.10 + ) stash_manager.allocate_stash_buffers(stash_buffer_size_factor=stash_buffer_size_factor) elif stash_manager.status == 'captured': pass diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index b57cd4974a0..d4d178eef87 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1004,6 +1004,9 @@ class TransformerConfig(ModelParallelConfig): moe_paged_stash: bool = False """If True, enable paged stash for MoE expert activations.""" + moe_paged_stash_page_size: int = 64 + """Number of tokens per page for paged stash memory management.""" + stash_modules: Optional[list[str]] = None """The MoE submodules to stash activations for. choices: "expert_fc1", "moe_act", "expert_fc2". @@ -1012,6 +1015,10 @@ class TransformerConfig(ModelParallelConfig): "expert_fc2": stash the input of the expert fc2 part. """ + stash_buffer_size_factor: Union[float, Tuple[float, float]] = 1.10 + """Scale factor(s) for paged stash buffer allocation. A single float sets the CUDA buffer factor + (CPU buffer disabled). Two numbers (cuda, cpu) set both: e.g. (1.10, 0.5) for 10% CUDA headroom + and 0.5x host buffer. For CUDA, sign selects sizing: positive = avg-based, negative = actual-max.""" def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. diff --git a/tests/unit_tests/transformer/moe/test_paged_stashing.py b/tests/unit_tests/transformer/moe/test_paged_stashing.py index 1a759bd55c5..a34503092f6 100644 --- a/tests/unit_tests/transformer/moe/test_paged_stashing.py +++ b/tests/unit_tests/transformer/moe/test_paged_stashing.py @@ -307,7 +307,6 @@ def test_overload_factor_and_over_budget(self): assert tokens_per_expert is not None tokens_per_ep_rank = tokens_per_expert.sum().item() - overload_factor = tokens_per_ep_rank / (seq_length * topk) over_budget_tensor = container.moe_layer.token_dispatcher.check_over_budget() over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False From 7c7ab96fd321c2039f9d57b0295a0a985fdaf06f Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Wed, 18 Mar 2026 20:33:26 +0800 Subject: [PATCH 54/57] Refactor the knobs a bit so it is more intuitive --- megatron/core/transformer/moe/paged_stash.py | 62 ++++++++++--------- .../core/transformer/transformer_config.py | 11 ++-- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index decb012def8..24ed71eae1d 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
from contextlib import nullcontext -from typing import Any, Tuple, Union +from typing import Any import torch import triton @@ -14,15 +14,6 @@ SCALE_INV_BLOCK_SIZE = 32 -def _normalize_stash_buffer_size_factor( - value: Union[float, Tuple[float, float], list], -) -> Tuple[float, float]: - """Normalize stash_buffer_size_factor to (cuda_factor, cpu_factor).""" - if isinstance(value, (list, tuple)) and len(value) == 2: - return (float(value[0]), float(value[1])) - return (float(value), 0.0) - - class PagedStashBuffer: """ A paged stash buffer with page-level memory management. @@ -720,38 +711,52 @@ def reload_paged_tensors(self, pp_schedule_layer, no_wait=False): ) def allocate_stash_buffers( - self, stash_buffer_size_factor: Union[float, Tuple[float, float]] = 1.10 + self, + stash_buffer_size_factor_cuda: float = 1.10, + stash_buffer_size_factor_cpu: float = 0.0, ): - """Allocate stash buffers organized by [dtype][hidden_size]. - - stash_buffer_size_factor: single float for CUDA factor (CPU disabled), or (cuda, cpu). - """ + """Allocate stash buffers organized by [dtype][hidden_size].""" self.stash_buffers = {} self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device) - cuda_factor, cpu_size_factor = _normalize_stash_buffer_size_factor(stash_buffer_size_factor) + cuda_factor = stash_buffer_size_factor_cuda + cpu_factor = stash_buffer_size_factor_cpu - # cuda_factor controls both which sizing signal to use and how much headroom to allocate: + # Both factors use the same sign convention: # - positive: size based on avg_num_tokens-derived maxima # - negative: size based on actual num_tokens-derived maxima (legacy behavior) - # In both cases we scale by abs(cuda_factor). + # Scale is always abs(factor). For CPU, 0 means no host buffer. if cuda_factor >= 0: max_tokens_dict = self.max_avg_tokens_across_vp_stages - scale = cuda_factor + cuda_scale = cuda_factor else: max_tokens_dict = self.max_tokens_across_vp_stages - scale = -cuda_factor + cuda_scale = -cuda_factor # Fallback safety: if avg-based dict is not available/populated yet, use actual-max dict. if not max_tokens_dict: max_tokens_dict = self.max_tokens_across_vp_stages + if cpu_factor > 0: + host_tokens_dict = self.max_avg_tokens_across_vp_stages or self.max_tokens_across_vp_stages + cpu_scale = cpu_factor + elif cpu_factor < 0: + host_tokens_dict = self.max_tokens_across_vp_stages + cpu_scale = -cpu_factor + else: + host_tokens_dict = None + cpu_scale = 0.0 + for dtype, hidden_size in max_tokens_dict: if dtype not in self.stash_buffers: self.stash_buffers[dtype] = {} assert hidden_size not in self.stash_buffers[dtype] - num_tokens = int(max_tokens_dict[dtype, hidden_size] * scale) - num_tokens_host = int(max_tokens_dict[dtype, hidden_size] * cpu_size_factor) if cpu_size_factor > 0 else 0 + num_tokens = int(max_tokens_dict[dtype, hidden_size] * cuda_scale) + num_tokens_host = ( + int(host_tokens_dict[dtype, hidden_size] * cpu_scale) + if host_tokens_dict is not None and (dtype, hidden_size) in host_tokens_dict + else 0 + ) buf_dtype = torch.uint8 if dtype in [torch.float8_e4m3fn, torch.float8_e8m0fnu] else dtype self.stash_buffers[dtype][hidden_size] = PagedStashBuffer( num_tokens, @@ -1025,9 +1030,8 @@ def paged_stash_set_last_layer(is_last_layer=False): def paged_stash_reset(enabled=True, config=None): """Reset the chunk handler, called at the start of a training iteration. - config: optional TransformerConfig; if provided, stash_buffer_size_factor and - moe_paged_stash_page_size are read from it. 
Otherwise defaults to 1.10 (CUDA only) - and page_size 64. + config: optional TransformerConfig; if provided, stash_buffer_size_factor_cuda/cpu and + moe_paged_stash_page_size are read from it. Otherwise defaults to 1.10 (CUDA), 0.0 (CPU). """ stash_manager = PagedStashManager.get_instance() stash_manager.enabled = enabled @@ -1045,10 +1049,12 @@ def paged_stash_reset(enabled=True, config=None): elif stash_manager.status == 'capture': stash_manager.status = 'captured' print (f'schedule {stash_manager._pp_schedule}') - stash_buffer_size_factor = ( - config.stash_buffer_size_factor if config is not None else 1.10 + cuda_factor = config.stash_buffer_size_factor_cuda if config is not None else 1.10 + cpu_factor = config.stash_buffer_size_factor_cpu if config is not None else 0.0 + stash_manager.allocate_stash_buffers( + stash_buffer_size_factor_cuda=cuda_factor, + stash_buffer_size_factor_cpu=cpu_factor, ) - stash_manager.allocate_stash_buffers(stash_buffer_size_factor=stash_buffer_size_factor) elif stash_manager.status == 'captured': pass diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index d4d178eef87..1fadc3da242 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1015,10 +1015,13 @@ class TransformerConfig(ModelParallelConfig): "expert_fc2": stash the input of the expert fc2 part. """ - stash_buffer_size_factor: Union[float, Tuple[float, float]] = 1.10 - """Scale factor(s) for paged stash buffer allocation. A single float sets the CUDA buffer factor - (CPU buffer disabled). Two numbers (cuda, cpu) set both: e.g. (1.10, 0.5) for 10% CUDA headroom - and 0.5x host buffer. For CUDA, sign selects sizing: positive = avg-based, negative = actual-max.""" + stash_buffer_size_factor_cuda: float = 1.10 + """Scale factor for paged stash CUDA buffer allocation. Sign selects sizing: positive = avg-based, + negative = actual-max. Magnitude is headroom (e.g. 1.10 = 10%).""" + + stash_buffer_size_factor_cpu: float = 0.0 + """Scale factor for paged stash host buffer. 0 disables host buffer. Same sign convention as + stash_buffer_size_factor_cuda: positive = avg-based, negative = actual-max; scale = abs(factor).""" def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
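To make the sign convention from the refactor above concrete, here is a minimal standalone sketch of how the two factors translate into buffer token counts. It is illustrative only: `avg_based_max` and `actual_max` stand in for the per-(dtype, hidden_size) maxima tracked by PagedStashManager (they are not real attribute names), and the empty-dict fallback in allocate_stash_buffers is omitted.

# Sketch of the stash_buffer_size_factor_{cuda,cpu} convention (assumed helper, not library code).
def _sized_tokens(factor: float, avg_based_max: int, actual_max: int) -> int:
    """Positive factor sizes from the avg-based maximum, negative from the actual maximum,
    zero disables the buffer; the scale applied is always abs(factor)."""
    if factor == 0:
        return 0
    base = avg_based_max if factor > 0 else actual_max
    return int(base * abs(factor))

# stash_buffer_size_factor_cuda=1.10 -> 10% headroom over the avg-based maximum
assert _sized_tokens(1.10, avg_based_max=1000, actual_max=1200) == 1100
# stash_buffer_size_factor_cpu=-0.5 -> half of the actual-max maximum on the host
assert _sized_tokens(-0.5, avg_based_max=1000, actual_max=1200) == 600
# stash_buffer_size_factor_cpu=0.0 -> host buffer disabled
assert _sized_tokens(0.0, avg_based_max=1000, actual_max=1200) == 0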
From 2dc0c53ddbfbad562f92337187078629a8f7d3f8 Mon Sep 17 00:00:00 2001 From: Vasudevan Rengasamy Date: Wed, 18 Mar 2026 15:31:58 -0700 Subject: [PATCH 55/57] Use get_attr_wrapped_model util to access moe and mtp layers --- megatron/core/transformer/moe/paged_stash.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/moe/paged_stash.py b/megatron/core/transformer/moe/paged_stash.py index 24ed71eae1d..06281583a9f 100644 --- a/megatron/core/transformer/moe/paged_stash.py +++ b/megatron/core/transformer/moe/paged_stash.py @@ -9,6 +9,7 @@ from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer from megatron.core.full_cuda_graph import FullCudaGraphWrapper +from megatron.core.utils import get_attr_wrapped_model GLOBAL_BLOCK_SIZE = 1024 SCALE_INV_BLOCK_SIZE = 32 @@ -1097,14 +1098,17 @@ def __init__(self, config, copy_main_params, model, optimizer, forward_backward_ self.forward_backward_func = forward_backward_func self.moe_layers = [] for model_chunk in self.model: - for layer in model_chunk.module.module.decoder.layers: + model_with_decoder = get_attr_wrapped_model( + model_chunk, "decoder", allow_none=False, return_model_obj=True + ) + for layer in model_with_decoder.decoder.layers: mlp = layer.mlp if hasattr(mlp, 'token_dispatcher') and hasattr( mlp.token_dispatcher, 'check_over_budget' ): self.moe_layers.append(mlp) - if model_chunk.module.module.mtp_process: - for layer in model_chunk.module.module.mtp.layers: + if model_with_decoder.mtp_process: + for layer in model_with_decoder.mtp.layers: mlp = layer.mtp_model_layer.mlp if hasattr(mlp, 'token_dispatcher') and hasattr( mlp.token_dispatcher, 'check_over_budget' From bfb9dd47ead2cd6d68c4c2f4541cf7ca51cc57a1 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Fri, 20 Mar 2026 20:10:38 +0800 Subject: [PATCH 56/57] Refactor the unit test for paged stashing --- .../transformer/moe/test_paged_stashing.py | 230 ++++++++++++------ 1 file changed, 156 insertions(+), 74 deletions(-) diff --git a/tests/unit_tests/transformer/moe/test_paged_stashing.py b/tests/unit_tests/transformer/moe/test_paged_stashing.py index a34503092f6..02bf3bb542f 100644 --- a/tests/unit_tests/transformer/moe/test_paged_stashing.py +++ b/tests/unit_tests/transformer/moe/test_paged_stashing.py @@ -11,6 +11,7 @@ from megatron.core.transformer.moe.moe_utils import get_align_size_for_quantization from megatron.core.transformer.moe.experts import TEGroupedMLP from megatron.core.transformer.moe.paged_stash import ( + check_paged_stash_overflow, paged_stash_init_chunk_handler, paged_stash_reset, ) @@ -19,6 +20,34 @@ from tests.unit_tests.test_utilities import Utils +def _global_tokens_per_expert_from_local_routing_map(routing_map: torch.Tensor) -> torch.Tensor: + """Per-expert token counts from a local routing map, summed across the default process group. + + ``routing_map`` is shaped [num_local_token_rows, num_experts] (as in + ``_HybridEPManager``). Tests here assume world size equals expert-parallel size (all GPUs + are EP ranks); ``all_reduce`` on the world group aggregates disjoint local maps. 
+ """ + counts = routing_map.sum(dim=0).to(torch.int64) + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + torch.distributed.all_reduce(counts, op=torch.distributed.ReduceOp.SUM) + return counts + + +def _tokens_per_expert_from_routing_map(routing_map: torch.Tensor, layer: MoELayer) -> torch.Tensor: + """Per-local-expert assignment counts from the routing map (columns for this EP rank).""" + counts = _global_tokens_per_expert_from_local_routing_map(routing_map) + idx = torch.as_tensor(layer.local_expert_indices, device=counts.device, dtype=torch.long) + return counts[idx].to(torch.int64).clone() + + +def _pad_token_counts_to_align_size( + tokens_per_expert: torch.Tensor, pad_multiple: int +) -> torch.Tensor: + """Round each count up to a multiple of ``pad_multiple`` (``n + (-n % m)`` like budget).""" + t = tokens_per_expert.to(torch.int64) + return t + (-t % pad_multiple) + + class MoEModelTestContainer: def __init__( self, @@ -92,12 +121,19 @@ def __init__( moe_router_padding_for_fp8=kwargs.get("moe_router_padding_for_fp8", True), use_transformer_engine_op_fuser=kwargs.get("use_transformer_engine_op_fuser", False), moe_mlp_glu_interleave_size=kwargs.get("moe_mlp_glu_interleave_size", None), - moe_router_padding_for_quantization=kwargs.get("moe_router_padding_for_quantization", False), + moe_router_padding_for_quantization=kwargs.get( + "moe_router_padding_for_quantization", False + ), gated_linear_unit=kwargs.get("gated_linear_unit", False), activation_func=kwargs.get("activation_func", F.gelu), moe_router_force_biased=kwargs.get("moe_router_force_biased", None), + stash_buffer_size_factor_cuda=0.5, + stash_buffer_size_factor_cpu=1.5, ) - self.moe_layer = self._create_moe_layer(layer_number=0) + self.moe_layers = [ + self._create_moe_layer(layer_number=i) for i in range(num_layers) + ] + self.moe_layer = self.moe_layers[0] def _create_moe_layer(self, layer_number=0): transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( @@ -114,43 +150,44 @@ def _create_moe_layer(self, layer_number=0): return moe_layer def zero_grad(self): - self.moe_layer.zero_grad() + for layer in self.moe_layers: + layer.zero_grad() def __del__(self): torch.distributed.barrier() torch.cuda.synchronize() Utils.destroy_model_parallel() - def forward_backward(self, hidden_states): - """Run one forward and backward pass through the MoE layer. - - Returns: - output: MoE layer output (detached). - hidden_states_grad: Gradient w.r.t. hidden_states. - routing_map: Token-to-expert routing map from the dispatcher (after forward). - tokens_per_expert: Number of tokens per local expert on this EP rank (after forward). 
- """ - hidden_states = hidden_states.cuda().requires_grad_(True) - quantization_context = get_fp8_context(self.config) - with quantization_context: - output, _ = self.moe_layer(hidden_states) - # Capture routing_map and tokens_per_expert after forward (before backward) - comm = getattr(self.moe_layer.token_dispatcher, "_comm_manager", None) - routing_map = getattr(comm, "routing_map", None) - tokens_per_expert = ( - comm.get_number_of_tokens_per_expert() - if comm is not None and hasattr(comm, "get_number_of_tokens_per_expert") - else None - ) - # Use contiguous gradient to avoid non-contiguous grad in HybridEP combine backward - # (output.sum().backward() produces a broadcast gradient that is non-contiguous) - output.backward(torch.ones_like(output)) - return output.detach(), hidden_states.grad, routing_map, tokens_per_expert - def destroy(self): Utils.destroy_model_parallel() +def _forward_backward_all_layers(container: MoEModelTestContainer, hidden_states: torch.Tensor): + """Forward/backward all MoE layers; returns output, input grad, last layer routing state.""" + initial_hidden_states = hidden_states.cuda().requires_grad_(True) + hidden_states = initial_hidden_states + quantization_context = get_fp8_context(container.config) + with quantization_context: + for layer in container.moe_layers: + hidden_states, _ = layer(hidden_states) + output = hidden_states + last_layer = container.moe_layers[-1] + comm = getattr(last_layer.token_dispatcher, "_comm_manager", None) + routing_map = getattr(comm, "routing_map", None) + tokens_per_expert = ( + comm.get_number_of_tokens_per_expert() + if comm is not None and hasattr(comm, "get_number_of_tokens_per_expert") + else None + ) + output.backward(torch.ones_like(output)) + return ( + output.detach(), + initial_hidden_states.grad, + routing_map, + tokens_per_expert, + ) + + def is_hybrid_ep_available(): from megatron.core.transformer.moe.fused_a2a import HAVE_HYBRIDEP return HAVE_HYBRIDEP @@ -166,7 +203,8 @@ def teardown_method(self, method): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal - def test_forward_backward(self): + def test_forward_backward_4_layers(self): + """Test paged stashing with 4 MoE layers: ref run vs paged run match.""" if not is_hybrid_ep_available(): pytest.skip("Hybrid EP is not available") @@ -177,7 +215,7 @@ def test_forward_backward(self): ep_size=4, pp_size=1, num_moe_experts=8, - num_layers=2, + num_layers=4, moe_router_topk=2, moe_router_load_balancing_type="aux_loss", moe_token_dispatcher_type="flex", @@ -197,11 +235,12 @@ def test_forward_backward(self): gated_linear_unit=True, activation_func=F.silu, ) - if not isinstance(container.moe_layer.experts, TEGroupedMLP) or not container.moe_layer.experts._is_fused_impl_supported(): + experts = container.moe_layer.experts + fused_ok = isinstance(experts, TEGroupedMLP) and experts._is_fused_impl_supported() + if not fused_ok: container.destroy() pytest.skip("TEGroupedMLP fused impl not supported") - # [sequence_length, batch_size, hidden_size] for MoELayer.forward seq_length = 1024 batch_size = 1 hidden_size = container.config.hidden_size @@ -210,32 +249,42 @@ def test_forward_backward(self): ) # First iteration: capture schedule, capacity, etc. 
- paged_stash_reset(True) + paged_stash_reset(True, config=container.config) paged_stash_init_chunk_handler(1, 0) output_ref, hidden_states_grad_ref, routing_map_ref, tokens_per_expert_ref = ( - container.forward_backward(hidden_states) + _forward_backward_all_layers(container, hidden_states) ) container.zero_grad() # Second iteration: run with paged stash. - paged_stash_reset(True) + paged_stash_reset(True, config=container.config) paged_stash_init_chunk_handler(1, 0) - output, hidden_states_grad, routing_map, tokens_per_expert = container.forward_backward( - hidden_states + output, hidden_states_grad, routing_map, tokens_per_expert = _forward_backward_all_layers( + container, hidden_states ) - # Verify output and input gradient match the first iteration. - torch.testing.assert_close(output, output_ref, atol=1e-4, rtol=1e-4) - torch.testing.assert_close( - hidden_states_grad, hidden_states_grad_ref, atol=1e-4, rtol=1e-4 + overflow = check_paged_stash_overflow() + assert overflow.any().item() == 0 + + assert torch.allclose(output, output_ref, atol=1e-4, rtol=1e-4), ( + f"output != output_ref: max diff = {(output - output_ref).abs().max().item()}" + ) + assert torch.allclose(hidden_states_grad, hidden_states_grad_ref, atol=1e-4, rtol=1e-4), ( + f"hidden_states_grad != ref: max diff = " + f"{(hidden_states_grad - hidden_states_grad_ref).abs().max().item()}" ) - # Routing and token counts available after forward (e.g. for debugging or further checks) if routing_map is not None and tokens_per_expert is not None: num_tokens_per_ep_rank = tokens_per_expert.sum().item() - assert num_tokens_per_ep_rank > 0 + assert num_tokens_per_ep_rank > 0, ( + f"num_tokens_per_ep_rank={num_tokens_per_ep_rank} (expected > 0)" + ) assert routing_map_ref is not None and tokens_per_expert_ref is not None - torch.testing.assert_close(tokens_per_expert, tokens_per_expert_ref) + tpe_f = tokens_per_expert.float() + ref_f = tokens_per_expert_ref.float() + assert torch.allclose(tpe_f, ref_f, atol=1e-4, rtol=1e-4), ( + f"tokens_per_expert != ref: max diff = {(tpe_f - ref_f).abs().max().item()}" + ) @pytest.mark.skipif(not is_hybrid_ep_available(), reason="Hybrid EP are not available") @@ -249,8 +298,7 @@ def teardown_method(self, method): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal def test_overload_factor_and_over_budget(self): - """Test budget computation (same as token_dispatcher lines 1017-1025) and assert - over_budget flag is set when tokens_per_ep_rank exceeds budget.""" + """Budget matches HybridEP setup_metadata; over_budget matches map-derived load.""" if not is_hybrid_ep_available(): pytest.skip("Hybrid EP is not available") @@ -261,8 +309,8 @@ def test_overload_factor_and_over_budget(self): ep_size=4, pp_size=1, num_moe_experts=8, - num_layers=1, - moe_router_topk=4, + num_layers=4, + moe_router_topk=2, moe_router_load_balancing_type="aux_loss", moe_token_dispatcher_type="flex", moe_permute_fusion=True, @@ -274,7 +322,7 @@ def test_overload_factor_and_over_budget(self): moe_use_legacy_grouped_gemm=False, moe_paged_stash=True, stash_modules=["expert_fc1", "moe_act", "expert_fc2"], - moe_expert_rank_capacity_factor=1.0, + moe_expert_rank_capacity_factor=1.5, use_transformer_engine_op_fuser=True, moe_mlp_glu_interleave_size=32, moe_router_padding_for_quantization=True, @@ -282,42 +330,76 @@ def test_overload_factor_and_over_budget(self): activation_func=F.silu, moe_router_force_biased=1, ) - if not isinstance(container.moe_layer.experts, TEGroupedMLP) or 
not container.moe_layer.experts._is_fused_impl_supported(): + experts = container.moe_layer.experts + fused_ok = isinstance(experts, TEGroupedMLP) and experts._is_fused_impl_supported() + if not fused_ok: container.destroy() pytest.skip("TEGroupedMLP fused impl not supported") - seq_length = 4096 + seq_length = 1024 batch_size = 1 topk = container.config.moe_router_topk capacity_factor = container.config.moe_expert_rank_capacity_factor - hidden_size = container.config.hidden_size hidden_states = torch.randn( - (seq_length, batch_size, hidden_size), dtype=torch.bfloat16 + (seq_length, batch_size, container.config.hidden_size), dtype=torch.bfloat16 ) - # Budget computed like token_dispatcher._HybridEPManager.setup_metadata (lines 1017-1025) - num_tokens = seq_length * batch_size + num_tokens = seq_length * batch_size * topk pad_multiple = get_align_size_for_quantization(container.config) - budget = int(num_tokens * topk * capacity_factor) + budget = int(num_tokens * capacity_factor) budget += -budget % pad_multiple - paged_stash_reset(True) + paged_stash_reset(True, config=container.config) paged_stash_init_chunk_handler(1, 0) - _, _, _, tokens_per_expert = container.forward_backward(hidden_states) - - assert tokens_per_expert is not None - tokens_per_ep_rank = tokens_per_expert.sum().item() - over_budget_tensor = container.moe_layer.token_dispatcher.check_over_budget() - over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False - - # When tokens_per_ep_rank > budget, over_budget flag must be raised - if tokens_per_ep_rank >= budget: - assert over_budget, ( - f"tokens_per_ep_rank ({tokens_per_ep_rank}) > budget ({budget}), " - "but over_budget flag was not set" + _forward_backward_all_layers(container, hidden_states) + + overflow = check_paged_stash_overflow() + num_layers = len(container.moe_layers) + stash_cuda = container.config.stash_buffer_size_factor_cuda + stash_cpu = container.config.stash_buffer_size_factor_cpu + stash_buffer_size = num_tokens * num_layers * (stash_cuda + stash_cpu) + + total_tokens = 0 + for layer_idx, layer in enumerate(container.moe_layers): + comm = getattr(layer.token_dispatcher, "_comm_manager", None) + routing_map = getattr(comm, "routing_map", None) if comm is not None else None + over_budget_tensor = ( + layer.token_dispatcher.check_over_budget() + if hasattr(layer.token_dispatcher, "check_over_budget") + else None ) - else: - assert not over_budget, ( - f"tokens_per_ep_rank ({tokens_per_ep_rank}) <= budget ({budget}), " - "but over_budget flag was set" + over_budget = over_budget_tensor.item() if over_budget_tensor is not None else False + + assert routing_map is not None, f"layer {layer_idx}: routing_map is None" + assert routing_map.dim() == 2, f"layer {layer_idx}: expected 2D routing_map" + assert routing_map.shape[1] == container.config.num_moe_experts, ( + f"layer {layer_idx}: routing_map has {routing_map.shape[1]} experts, " + f"expected {container.config.num_moe_experts}" + ) + tokens_per_expert_from_map = _tokens_per_expert_from_routing_map(routing_map, layer) + tokens_per_expert_from_map_padded = _pad_token_counts_to_align_size( + tokens_per_expert_from_map, pad_multiple ) + tokens_per_ep_rank_from_map = tokens_per_expert_from_map_padded.sum().item() + total_tokens += tokens_per_ep_rank_from_map + + # Padded map-derived tokens strictly over budget iff dispatcher reports over_budget + if tokens_per_ep_rank_from_map > budget: + assert over_budget, ( + f"layer {layer_idx}: tokens_per_ep_rank_from_map " + 
f"({tokens_per_ep_rank_from_map}) > budget ({budget}), " + f"but over_budget flag was not set" + ) + else: + assert not over_budget, ( + f"layer {layer_idx}: tokens_per_ep_rank_from_map " + f"({tokens_per_ep_rank_from_map}) <= budget ({budget}), " + f"but over_budget flag was set" + ) + + overflow_set = overflow.any().item() + stash_exceeded = total_tokens > stash_buffer_size + assert overflow_set == stash_exceeded, ( + f"overflow {overflow_set} should match total_tokens > stash_buffer_size " + f"({total_tokens} > {stash_buffer_size})" + ) From cd89d4ffe8569380013db70b3ce1984cb6998114 Mon Sep 17 00:00:00 2001 From: Nan Zheng Date: Sat, 21 Mar 2026 19:54:34 +0800 Subject: [PATCH 57/57] Clean up after rebase --- megatron/core/fp8_utils.py | 2 +- megatron/core/transformer/transformer_config.py | 4 ---- tests/unit_tests/transformer/moe/test_paged_stashing.py | 5 ----- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/megatron/core/fp8_utils.py b/megatron/core/fp8_utils.py index 0a85ea42e19..fa6be91dfbf 100644 --- a/megatron/core/fp8_utils.py +++ b/megatron/core/fp8_utils.py @@ -168,7 +168,7 @@ def _get_custom_recipe(quantizer_factory_python_path: str) -> Union[Fp8Recipe, F def get_fp8_align_size(fp8_recipe: Fp8Recipe) -> int: """Get the alignment size required for fp8 GEMM.""" if fp8_recipe == Fp8Recipe.mxfp8: - return 128 + return 32 else: return 16 diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 1fadc3da242..fa1c0e1e215 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -690,10 +690,6 @@ class TransformerConfig(ModelParallelConfig): GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm). """ - moe_use_device_initiated_grouped_gemm: bool = False - """Use the cutlass grouped gemm kernel, which allows for the token_per_expert tensor on GPU. - This can prevent the GPU-CPU synchronization during the grouped gemm.""" - moe_use_legacy_grouped_gemm: bool = False """Use legacy GroupedMLP rather than TEGroupedMLP. 
Note: The legacy one will be deprecated soon.""" diff --git a/tests/unit_tests/transformer/moe/test_paged_stashing.py b/tests/unit_tests/transformer/moe/test_paged_stashing.py index 02bf3bb542f..62a22e04054 100644 --- a/tests/unit_tests/transformer/moe/test_paged_stashing.py +++ b/tests/unit_tests/transformer/moe/test_paged_stashing.py @@ -111,9 +111,6 @@ def __init__( moe_permute_fusion=kwargs.get("moe_permute_fusion", False), moe_flex_dispatcher_backend=kwargs.get("moe_flex_dispatcher_backend", None), moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False), - moe_use_device_initiated_grouped_gemm=kwargs.get( - "moe_use_device_initiated_grouped_gemm", False - ), moe_use_legacy_grouped_gemm=kwargs.get("moe_use_legacy_grouped_gemm", False), moe_paged_stash=kwargs.get("moe_paged_stash", False), stash_modules=kwargs.get("stash_modules", None), @@ -224,7 +221,6 @@ def test_forward_backward_4_layers(self): moe_flex_dispatcher_backend="hybridep", test_dtype=torch.bfloat16, moe_grouped_gemm=True, - moe_use_device_initiated_grouped_gemm=True, moe_use_legacy_grouped_gemm=False, moe_paged_stash=True, stash_modules=["expert_fc1", "moe_act", "expert_fc2"], @@ -318,7 +314,6 @@ def test_overload_factor_and_over_budget(self): moe_flex_dispatcher_backend="hybridep", test_dtype=torch.bfloat16, moe_grouped_gemm=True, - moe_use_device_initiated_grouped_gemm=True, moe_use_legacy_grouped_gemm=False, moe_paged_stash=True, stash_modules=["expert_fc1", "moe_act", "expert_fc2"],
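For reference, the round-up padding that both the budget computation and the _pad_token_counts_to_align_size helper in the refactored test rely on is the standard n + (-n % m) identity. A small self-contained sketch with illustrative numbers follows; the actual align size comes from get_align_size_for_quantization and may differ from the 32 assumed here.

import torch

def round_up_to_multiple(n: int, m: int) -> int:
    # n + (-n % m) rounds n up to the next multiple of m and leaves exact multiples unchanged.
    return n + (-n % m)

# Same shape as the budget in test_overload_factor_and_over_budget:
# budget = int(num_tokens * capacity_factor), then rounded up to the quantization align size.
num_tokens, capacity_factor, pad_multiple = 2048, 1.5, 32  # 1024 tokens * topk 2; align size assumed
budget = round_up_to_multiple(int(num_tokens * capacity_factor), pad_multiple)
assert budget == 3072 and budget % pad_multiple == 0

# Tensor form, matching _pad_token_counts_to_align_size:
tokens_per_expert = torch.tensor([5, 32, 60], dtype=torch.int64)
padded = tokens_per_expert + (-tokens_per_expert % pad_multiple)
assert padded.tolist() == [32, 32, 64]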