@@ -57,6 +57,9 @@ class InferOutput:
5757 # for logging
5858 req_metrics : RequestMetrics = None
5959
60+ # expert ids
61+ routed_experts : torch .Tensor = None
62+
6063
6164def _tensorlize_block_offsets (block_offsets , dtype = torch .int32 ):
6265 """Tensorlize block_offsets."""
@@ -876,13 +879,18 @@ def _make_infer_outputs(
876879 cur_logprobs = (logprobs .vals [idx ][:num_logprobs + 1 ], logprobs .indices [idx ][:num_logprobs + 1 ])
877880
878881 req_metrics = RequestMetrics (new_token_timestamp , msg .engine_events )
882+ routed_experts = msg .routed_experts if msg .return_routed_experts and finish else None
883+ if routed_experts is not None and self .engine_config .enable_transfer_obj_ref :
884+ # only serialize for api server
885+ routed_experts = self .executor .serialize (routed_experts )
879886 out = InferOutput (session_id = session_id ,
880887 resp = msg .resp ,
881888 finish = finish ,
882889 token_ids = token_ids ,
883890 cache_block_ids = cache_block_ids ,
884891 req_metrics = req_metrics ,
885- logprobs = cur_logprobs )
892+ logprobs = cur_logprobs ,
893+ routed_experts = routed_experts )
886894 outputs [session_id ] = out
887895
888896 if msg .return_logits :
@@ -896,6 +904,10 @@ def __need_logits(seqs: SeqList):
896904 """Need logits."""
897905 return any (seq .return_logits for seq in seqs )
898906
907+ def __need_routed_experts (seqs : SeqList ):
908+ """Need routed experts."""
909+ return any (seq .return_routed_experts for seq in seqs )
910+
899911 def __need_schedule_again (prefill : bool , scheduler_output ):
900912 """Need schedule again."""
901913 # only reschedule when prefill
@@ -939,6 +951,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
939951 inputs = self .create_model_inputs (running , prefill )
940952 sampling_inputs = self .sampling_strategy .make_sampling_inputs (running )
941953 return_logits = __need_logits (running )
954+ return_routed_experts = __need_routed_experts (running )
942955 extra_inputs = self .model_agent_strategy .make_extra_inputs (running )
943956 stopping_criteria = self .model_agent_strategy .make_stopping_criteria (running )
944957
@@ -956,6 +969,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
956969 is_dummy = False ,
957970 sync_long_context = sync_long_context ,
958971 extra_inputs = extra_inputs ,
972+ return_routed_experts = return_routed_experts ,
959973 )
960974
961975 async def _await_forward_event (self , forward_event : asyncio .Event ):
@@ -991,6 +1005,7 @@ def __send_resp(out: InferOutput):
9911005 logits = out .logits ,
9921006 cache_block_ids = out .cache_block_ids ,
9931007 req_metrics = out .req_metrics ,
1008+ routed_experts = out .routed_experts ,
9941009 logprobs = logprobs ))
9951010
9961011 def __update_logprobs (step_outputs : List [InferOutput ]):