@@ -315,6 +315,11 @@ def _process_sampling_with_logprob_batch_output(self):
         scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
         ranks = self.output_ranks[:batch].numpy()
         batch_result = list()
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
+            for request_id in need_to_be_reschedule_req_ids:
+                if self.resource_manager.requests[request_id].idx >= (batch - 1):  # No more token generated for preempted request
+                    self.resource_manager.reschedule_preempt_task(request_id)
         for i in range(batch):
             if self.resource_manager.stop_flags[i]:
                 continue
@@ -326,6 +331,9 @@ def _process_sampling_with_logprob_batch_output(self):
             if recovery_stop:
                 llm_logger.info(f"recovery stop signal found at task {task_id}")
             if not recovery_stop and token_id < 0:
+                if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                    if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                        self.resource_manager.reschedule_preempt_task(task_id)
                 continue

             if task.get("prefill_chunk_info", None) is not None:
@@ -382,6 +390,7 @@ def _process_sampling_with_logprob_batch_output(self):
                 self.tokens_counter[task_id] += 1
                 if token_id != RECOVERY_STOP_SIGNAL:
                     result.outputs.token_ids.append(token_id)
+                    task.output_token_ids.append(token_id)
                 result.outputs.logprob = float(scores[i, 0])
                 # Construct top_logprobs
                 topk_token_ids = tokens[i, :].tolist()
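For context, here is a minimal standalone sketch of the rescheduling check added in the first hunk. The ResourceManager and Request stand-ins below are simplified stubs introduced only for illustration, not the actual classes in the codebase; the behavior of reschedule_preempt_task is assumed.

# Illustrative sketch only: a preempted request whose slot index lies at or
# beyond the last generated position produced no new token this step, so it
# is handed back to the scheduler (mirrors the logic in the added hunk).
from dataclasses import dataclass, field


@dataclass
class Request:
    idx: int  # slot index assigned to this request in the running batch


@dataclass
class ResourceManager:
    requests: dict = field(default_factory=dict)
    to_be_rescheduled_request_id_set: set = field(default_factory=set)

    def reschedule_preempt_task(self, request_id: str) -> None:
        # Assumed behavior: hand the preempted request back to the scheduler
        # so it can be queued for execution again.
        self.to_be_rescheduled_request_id_set.discard(request_id)
        print(f"rescheduled preempted request {request_id}")


def reschedule_stale_preempted_requests(rm: ResourceManager, batch: int) -> None:
    # Iterate over a snapshot of the set, since rescheduling mutates it.
    for request_id in list(rm.to_be_rescheduled_request_id_set):
        if rm.requests[request_id].idx >= (batch - 1):
            rm.reschedule_preempt_task(request_id)


if __name__ == "__main__":
    rm = ResourceManager(
        requests={"req-0": Request(idx=0), "req-1": Request(idx=3)},
        to_be_rescheduled_request_id_set={"req-0", "req-1"},
    )
    reschedule_stale_preempted_requests(rm, batch=2)  # only req-1 is rescheduled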