     count_expert_num_tokens,
     disable_inplace,
 )
-from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.v1.worker.ubatching import (
     dbo_enabled,
@@ -682,14 +681,12 @@ def __init__(
         prepare_finalize: FusedMoEPrepareAndFinalize,
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
-        shared_experts_stream: torch.cuda.Stream | None = None,
         moe_parallel_config: FusedMoEParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
-        self.shared_experts_stream = shared_experts_stream

         # prefer an explicit FusedMoEParallelConfig when available (from
         # FusedMoE layers / tests).
@@ -904,34 +901,6 @@ def _slice_expert_tokens_metadata(
             expert_num_tokens_cpu=c_expert_num_tokens_cpu,
         )

-    def _maybe_setup_shared_experts_stream(
-        self, hidden_states: torch.Tensor
-    ) -> tuple[bool, torch.Tensor | None]:
-        # decide whether to run shared experts on a separate CUDA stream to
-        # overlap with the main fused MoE kernel.
-        use_shared_experts_stream = (
-            self.shared_experts is not None
-            and self.shared_experts_stream is not None
-            and hidden_states.is_cuda
-            and (
-                hidden_states.shape[0]
-                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
-            )
-        )
-
-        hidden_states_clone: torch.Tensor | None = None
-        if use_shared_experts_stream and self.shared_experts_stream is not None:
-            # TODO: Optimize this (complicated)
-            # Note: this clone adds overhead but is required
-            # for correctness with multiple CUDA streams and CUDA graph capture.
-            hidden_states_clone = hidden_states.clone()
-            # record that the clone will be used by the separate stream so its
-            # lifetime is correctly tracked.
-            hidden_states_clone.record_stream(self.shared_experts_stream)
-            self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
-
-        return use_shared_experts_stream, hidden_states_clone
-
     def _prepare(
         self,
         hidden_states: torch.Tensor,
@@ -1119,30 +1088,12 @@ def _finalize(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
-        hidden_states_clone: torch.Tensor | None = None,
-        use_shared_experts_stream: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         The _finalize method is a wrapper around self.prepare_finalize.finalize
         that handles DBO, async and shared expert overlap.
         """
-
-        def maybe_run_shared_experts() -> torch.Tensor | None:
-            if self.shared_experts is None:
-                return None
-
-            if (
-                not use_shared_experts_stream
-                or self.shared_experts_stream is not None
-                and (not hidden_states.is_cuda or not torch.cuda.is_available())
-            ):
-                # fall back to running on the current stream
-                return self.shared_experts(hidden_states)
-
-            assert hidden_states_clone is not None
-            # launch shared experts on the dedicated stream.
-            with torch.cuda.stream(self.shared_experts_stream):
-                return self.shared_experts(hidden_states_clone)
+        shared_output: torch.Tensor | None = None

         if not self.prepare_finalize.supports_async():
             assert not dbo_enabled()
@@ -1155,7 +1106,8 @@ def maybe_run_shared_experts() -> torch.Tensor | None:
                 apply_router_weight_on_input,
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
-            shared_output = maybe_run_shared_experts()
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)
         else:
             finalize_ret = self.prepare_finalize.finalize_async(
                 output,
@@ -1165,8 +1117,8 @@ def maybe_run_shared_experts() -> torch.Tensor | None:
                 apply_router_weight_on_input,
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
-
-            shared_output = maybe_run_shared_experts()
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(hidden_states)

         # TODO(lucas): refactor this in the alternative schedules followup
         # currently unpack if we have hook + receiver pair or just
@@ -1189,28 +1141,12 @@ def maybe_run_shared_experts() -> torch.Tensor | None:

             receiver()

-        self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
-
         if self.shared_experts is None:
             return output
         else:
             assert shared_output is not None
             return shared_output, output

-    def _wait_for_shared_experts_stream(
-        self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
-    ) -> None:
-        # ensure that any work enqueued on the shared_experts_stream is
-        # completed before the shared_output tensor is consumed
-        if (
-            self.shared_experts is not None
-            and use_shared_experts_stream
-            and self.shared_experts_stream is not None
-            and hidden_states.is_cuda
-            and current_platform.is_cuda()
-        ):
-            torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1257,10 +1193,6 @@ def forward(
         else:
             output = torch.zeros_like(hidden_states)

-        use_shared_experts_stream, hidden_states_clone = (
-            self._maybe_setup_shared_experts_stream(hidden_states)
-        )
-
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
             global_num_experts = local_num_experts
@@ -1297,6 +1229,4 @@ def forward(
             topk_weights,
             topk_ids,
             apply_router_weight_on_input,
-            hidden_states_clone=hidden_states_clone,
-            use_shared_experts_stream=use_shared_experts_stream,
         )