@@ -75,6 +75,7 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
75
75
self .running : list [Request ] = []
76
76
self .finish_execution_pool = ThreadPoolExecutor (max_workers = 1 )
77
77
self .lock = threading .Lock ()
78
+ self .to_be_rescheduled_request_id_set = set ()
78
79
79
80
def allocated_slots(self, request: "Request"):
    """Return the number of slots backed by the blocks held by ``request``.

    Computed as the number of entries in ``request.block_tables`` multiplied
    by ``self.config.cache_config.block_size``.

    Args:
        request: Request whose ``block_tables`` is measured.

    Returns:
        ``len(request.block_tables) * block_size``; ``0`` when the request
        holds no blocks.
    """
    block_size = self.config.cache_config.block_size
    return len(request.block_tables) * block_size
@@ -96,6 +97,13 @@ def _prepare_decode_task(self, request):
96
97
97
98
def _prepare_preempt_task(self, request):
    """Build the preempt task describing ``request``.

    Args:
        request: Request being preempted; its ``idx`` and ``request_id``
            are copied into the task.

    Returns:
        A ``ScheduledPreemptTask`` constructed with ``idx=request.idx`` and
        ``request_id=request.request_id``.
    """
    return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
100
+
101
def reschedule_preempt_task(self, request_id):
    """Move a preempted request back to the front of the waiting queue.

    The request is re-queued only when ``request_id`` is present in
    ``self.to_be_rescheduled_request_id_set`` and is still registered in
    ``self.requests``; in every other case the call is a no-op. The whole
    check-and-move runs while holding ``self.lock``.

    Args:
        request_id: Identifier of the request to put back into
            ``self.waiting``.
    """
    with self.lock:
        pending = self.to_be_rescheduled_request_id_set
        if request_id in pending and request_id in self.requests:
            request = self.requests[request_id]
            # appendleft: the request goes to the head of the waiting queue.
            self.waiting.appendleft(request)
            pending.remove(request_id)
99
107
100
108
def _trigger_preempt (self , request , num_new_blocks , preempted_reqs , scheduled_reqs ):
101
109
can_schedule = True
@@ -106,7 +114,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
106
114
preempted_req .num_computed_tokens = 0
107
115
preempted_req .prefill_block_num = 0
108
116
self ._free_blocks (preempted_req )
109
- self .waiting . appendleft (preempted_req )
117
+ self .to_be_rescheduled_request_id_set . add (preempted_req . request_id )
110
118
preempted_reqs .append (preempted_req )
111
119
scheduled_reqs .append (self ._prepare_preempt_task (preempted_req ))
112
120
if preempted_req == request :
@@ -381,8 +389,9 @@ def get_prefix_cached_blocks(self, request: Request):
381
389
return False
382
390
383
391
def add_request(self, request: "Request") -> None:
    """Register a new request with the manager.

    Appends ``request`` to the tail of ``self.waiting`` and indexes it in
    ``self.requests`` under ``request.request_id``. Both updates happen
    while holding ``self.lock`` so they are applied together with respect
    to other code paths that take the same lock.

    Args:
        request: The request to enqueue.
    """
    with self.lock:
        self.waiting.append(request)
        self.requests[request.request_id] = request
386
395
387
396
def _free_blocks (self , request : Request ):
388
397
if self .config .cache_config .enable_prefix_caching :
@@ -409,9 +418,15 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]):
409
418
if request is None :
410
419
# Invalid request ID.
411
420
continue
412
- request .status = RequestStatus .FINISHED
413
- self .running .remove (request )
414
- self ._free_blocks (request )
421
+ if request in self .running : # normally run and finished
422
+ self .running .remove (request )
423
+ request .status = RequestStatus .FINISHED
424
+ self ._free_blocks (request )
425
+ if request .request_id in self .to_be_rescheduled_request_id_set : # finished after preempted, blocks have been recycled.
426
+ self .to_be_rescheduled_request_id_set .remove (request .request_id ) # just remove from to_be_rescheduled_request_id_set
427
+ if request in self .waiting : # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
428
+ raise RuntimeError (f"request { request .request_id } scheduled into waiting list, after finished" )
429
+
415
430
self .tasks_list [request .idx ] = None
416
431
self .stop_flags [request .idx ] = True
417
432
del self .requests [req_id ]
0 commit comments