@@ -480,24 +480,6 @@ def _allocate_dp_chunking_outputs(
480480
481481 return final_shared_hidden_states , final_fused_hidden_states
482482
483- def _maybe_overlap_gate_with_shared_experts (
484- self ,
485- hidden_states : torch .Tensor ,
486- router_logits : torch .Tensor ,
487- shared_experts_input : torch .Tensor | None ,
488- ) -> torch .Tensor :
489- # If router/gate provided, then apply it here.
490- # (Note: This code runs only when "overlapped mode" is on to allow
491- # parallel execution of shared experts with the FusedMoE via
492- # separate cuda stream)
493- if self .shared_experts is not None :
494- self .shared_experts .maybe_setup_shared_experts_stream (shared_experts_input )
495-
496- if self .gate is not None :
497- router_logits , _ = self .gate (hidden_states )
498-
499- return router_logits
500-
501483 @property
502484 def do_naive_dispatch_combine (self ) -> bool :
503485 return (
@@ -621,11 +603,15 @@ def forward_dispatch(
621603 # TODO(bnell): this can be removed after MK migration is complete.
622604 layer .ensure_moe_quant_config_init ()
623605
624- router_logits = self ._maybe_overlap_gate_with_shared_experts (
625- hidden_states ,
626- router_logits ,
627- shared_experts_input ,
628- )
606+ # Sync aux and main stream for shared expert multi-stream overlap.
607+ if self .shared_experts is not None :
608+ self .shared_experts .maybe_setup_shared_experts_stream (shared_experts_input )
609+
610+ # If the Runner holds the gate, apply it after the stream sync,
611+ # so it can run overlapped with the shared experts on the aux stream.
612+ # NOTE: in future PR, MoE runner will always hold the gate.
613+ if self .gate is not None :
614+ router_logits , _ = self .gate (hidden_states )
629615
630616 self ._maybe_apply_shared_experts (
631617 shared_experts_input ,
0 commit comments