@@ -345,6 +345,9 @@ def __init__(
345345 self .cudagraph_tensor_store = MoECudaGraphTensorStore ()
346346 self .fwd_execution_map = ["route" , "expert_compute" , "postprocess" ]
347347
348+ # Setup events and streams for delayed wgrad computation.
349+ self .setup_delayed_wgrad_for_dispatch_backward_overlap ()
350+
348351 def _setup_inference_mode (self , pg_collection ):
349352 """Set up inference-optimized token dispatcher and state.
350353
@@ -365,6 +368,16 @@ def _setup_inference_mode(self, pg_collection):
365368 pg_collection = pg_collection ,
366369 )
367370
371+ def setup_delayed_wgrad_for_dispatch_backward_overlap (self ):
372+ """Initializes CUDA events and streams for overlapping expert
373+ weight gradient computation with dispatch backward.
374+ """
375+ self ._delayed_wgrad_event : Optional [torch .cuda .Event ] = None
376+ self ._delayed_wgrad_stream : Optional [torch .cuda .Stream ] = None
377+ if self .config .overlap_dispatch_backward_with_experts_wgrad :
378+ self ._delayed_wgrad_event = torch .cuda .Event ()
379+ self ._delayed_wgrad_stream = torch .cuda .Stream (device = "cuda" )
380+
368381 def set_inference_cuda_graphed_iteration (self ):
369382 """Enable CUDA-graphed iteration mode on this layer, its router, and its experts.
370383
@@ -435,6 +448,8 @@ def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
435448 tokens and their associated probabilities to the devices hosting their assigned
436449 experts.
437450 """
451+ if self .config .overlap_dispatch_backward_with_experts_wgrad :
452+ hidden_states = _RegisterDelayedWgradForExperts .apply (self , hidden_states )
438453 return self .token_dispatcher .token_dispatch (hidden_states , probs )
439454
440455 @maybe_skip_or_early_return_by_cudagraph ("shared_experts_compute" )
@@ -473,6 +488,10 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso
473488 for each expert. It then passes the tokens through the local experts.
474489 The output from the experts is preprocessed for the combine step.
475490 """
491+ if self .config .overlap_dispatch_backward_with_experts_wgrad :
492+ hidden_states = _RecordExpertDgradCompletion .apply (
493+ self ._delayed_wgrad_event , hidden_states
494+ )
476495 dispatched_input , tokens_per_expert , permuted_probs = (
477496 self .token_dispatcher .dispatch_postprocess (hidden_states , probs )
478497 )
@@ -618,24 +637,24 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None):
618637
    def backward_dw(self, routed_experts: bool = True, shared_experts: bool = False):
        """Compute weight gradients for experts and shared experts.

        Args:
            routed_experts: When True, run the delayed wgrad of the routed experts
                (and, in latent mode with EP-comm overlap, of ``fc2_latent_proj``).
            shared_experts: When True, run the delayed wgrad of the shared experts
                (and, in latent mode with EP-comm overlap, of ``fc1_latent_proj``).
        """
        # Function-scope import; presumably avoids an import cycle with the
        # pipeline-parallel utilities -- TODO confirm.
        from megatron.core.pipeline_parallel.utils import get_comm_stream

        # TODO(Wohox): replace the "routed_experts" and "shared_experts" arguments with better
        # naming to better explain that they are actually from different fine-grained callables,
        # or use scanning to decide which backward_dw should be called.
        if routed_experts:
            self.experts.backward_dw()
            # Latent projections only need separate wgrad handling when EP
            # communication overlap is enabled.
            if self.config.moe_latent_size and self.config.overlap_moe_expert_parallel_comm:
                # TODO(Wohox): fc2_latent_proj forward and backward are executed in comm stream,
                # so we execute its backward_dw in the comm stream too. But this may harm the
                # EP overlap performance. Better to check if there is a better way to handle this.
                comm_stream = get_comm_stream()
                with torch.cuda.stream(comm_stream):
                    self.fc2_latent_proj.backward_dw()
        if shared_experts:
            # Skip when the shared-expert wgrad is already handled by the
            # overlapped path (shared_expert_overlap).
            if self.use_shared_expert and not self.shared_expert_overlap:
                self.shared_experts.backward_dw()
            # NOTE(review): unlike fc2_latent_proj above, fc1_latent_proj's wgrad
            # runs on the current stream -- confirm this asymmetry is intended.
            if self.config.moe_latent_size and self.config.overlap_moe_expert_parallel_comm:
                self.fc1_latent_proj.backward_dw()
640659
641660 def set_for_recompute_pre_mlp_layernorm (self ):
@@ -646,3 +665,66 @@ def set_for_recompute_pre_mlp_layernorm(self):
646665 from megatron .core .extensions .transformer_engine import set_save_original_input
647666
648667 set_save_original_input (self .shared_experts .linear_fc1 )
668+
669+
class _RecordExpertDgradCompletion(torch.autograd.Function):
    """Marks expert dgrad completion by recording a CUDA event in backward.

    Inserted into the forward graph immediately before the expert computation.
    Because autograd runs nodes in reverse forward order, this node's backward
    fires right after the expert data-gradient work has been enqueued; it records
    the shared event on the current stream so that
    ``_RegisterDelayedWgradForExperts`` can wait on it before launching the
    delayed wgrad on its dedicated stream.
    """

    @staticmethod
    def forward(ctx, event: torch.cuda.Event, *inputs):
        """Stash the event on the context; inputs flow through untouched."""
        ctx.event = event
        if len(inputs) == 1:
            return inputs[0]
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Record the completion event on the current (backward) stream."""
        # Drop the context reference as soon as the event is consumed.
        event, ctx.event = ctx.event, None
        event.record(torch.cuda.current_stream())
        # No gradient for the event argument; pass incoming grads through.
        return (None, *grad_outputs)
691+
692+
class _RegisterDelayedWgradForExperts(torch.autograd.Function):
    """Autograd function that orchestrates delayed wgrad computation for MoE experts.

    Placed in the forward graph at the dispatch boundary. During the backward pass,
    this function:
    1. Records an event on the current (backward) stream to signal the dgrad is done.
    2. Executes the delayed wgrad computation on a dedicated CUDA stream.
    3. Waits for the wgrad computation to complete.
    4. Invokes the registered gradient processing callback (e.g., FSDP reduce-scatter).
    """

    @staticmethod
    def forward(ctx, module: MoELayer, *inputs):
        """Forward pass that stores the MoE module and passes through inputs unchanged."""
        # The module handle gives backward access to the event/stream pair
        # created in setup_delayed_wgrad_for_dispatch_backward_overlap.
        ctx.module = module
        return inputs[0] if len(inputs) == 1 else inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward pass that executes delayed wgrad computation on a separate stream."""
        module = ctx.module
        event = module._delayed_wgrad_event
        wgrad_stream = module._delayed_wgrad_stream

        # Wait for the expert dgrad completion event (recorded by
        # _RecordExpertDgradCompletion.backward) before starting wgrad,
        # so the wgrad stream does not race ahead of the dgrad work.
        wgrad_stream.wait_event(event)
        with torch.cuda.stream(wgrad_stream):
            with torch.cuda.nvtx.range("delayed_expert_wgrad"):
                # Routed-expert wgrad only; shared experts are handled elsewhere.
                module.backward_dw(routed_experts=True, shared_experts=False)
        # Reuse the same event to publish wgrad completion back to the
        # main backward stream.
        event.record(wgrad_stream)

        torch.cuda.current_stream().wait_event(event)

        # Fire any registered post-wgrad hooks (e.g., gradient reduce-scatter).
        # NOTE(review): hook semantics are inferred from the attribute name --
        # confirm against where post_wgrad_grad_acc_hook is attached.
        for param in module.parameters():
            if getattr(param, "post_wgrad_grad_acc_hook", None) is not None:
                param.post_wgrad_grad_acc_hook()

        # Break the reference cycle between ctx and the module.
        ctx.module = None
        return (None,) + grad_outputs
0 commit comments