
Commit c011cb8

[Bug Fix] Fix scheduler bug in develop (#3292)
1 parent 1e4968e commit c011cb8
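
Scope of the fix, as read from the diff below: previously, `_trigger_preempt` pushed a preempted request straight back onto the `waiting` queue while the token processor could still be draining tokens for it, and `finish_requests` unconditionally assumed a finishing request was in `running`. This commit parks preempted request ids in a new `to_be_rescheduled_request_id_set`; the token processor re-queues a request via `reschedule_preempt_task` only once no further tokens can arrive for it. `finish_requests` now distinguishes normally-finished, preempted-then-finished, and erroneously re-queued requests, and the GPU model runner keeps the engine loop alive when a batch contains only continuing decode steps.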

File tree (3 files changed: +34, -9 lines):
fastdeploy/engine/sched/resource_manager_v1.py
fastdeploy/output/token_processor.py
fastdeploy/worker/gpu_model_runner.py

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 21 additions & 6 deletions
@@ -75,6 +75,7 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
         self.running: list[Request] = []
         self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
         self.lock = threading.Lock()
+        self.to_be_rescheduled_request_id_set = set()

     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size
@@ -96,6 +97,13 @@ def _prepare_decode_task(self, request):

     def _prepare_preempt_task(self, request):
         return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
+
+    def reschedule_preempt_task(self, request_id):
+        with self.lock:
+            if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
+                request = self.requests[request_id]
+                self.waiting.appendleft(request)
+                self.to_be_rescheduled_request_id_set.remove(request_id)

     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
         can_schedule = True
@@ -106,7 +114,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
                 preempted_req.num_computed_tokens = 0
                 preempted_req.prefill_block_num = 0
                 self._free_blocks(preempted_req)
-                self.waiting.appendleft(preempted_req)
+                self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
                 if preempted_req == request:
@@ -381,8 +389,9 @@ def get_prefix_cached_blocks(self, request: Request):
             return False

     def add_request(self, request: Request) -> None:
-        self.waiting.append(request)
-        self.requests[request.request_id] = request
+        with self.lock:
+            self.waiting.append(request)
+            self.requests[request.request_id] = request

     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
@@ -409,9 +418,15 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]):
                 if request is None:
                     # Invalid request ID.
                     continue
-                request.status = RequestStatus.FINISHED
-                self.running.remove(request)
-                self._free_blocks(request)
+                if request in self.running:  # ran normally and finished
+                    self.running.remove(request)
+                    request.status = RequestStatus.FINISHED
+                    self._free_blocks(request)
+                if request.request_id in self.to_be_rescheduled_request_id_set:  # finished after being preempted; its blocks were already recycled
+                    self.to_be_rescheduled_request_id_set.remove(request.request_id)
+                if request in self.waiting:  # a finished request must never be re-queued from preempted back to waiting
+                    raise RuntimeError(f"request {request.request_id} scheduled into waiting list, after finished")
+
                 self.tasks_list[request.idx] = None
                 self.stop_flags[request.idx] = True
                 del self.requests[req_id]
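
For reference, a minimal, runnable sketch of the handoff pattern this file now implements. `TinyScheduler` and its plain-dict payloads are hypothetical stand-ins for illustration, not the real `ResourceManagerV1` API:

import threading
from collections import deque

class TinyScheduler:
    """Hypothetical stand-in for ResourceManagerV1's preempt/reschedule handoff."""

    def __init__(self):
        self.lock = threading.Lock()
        self.waiting = deque()          # requests waiting to be (re)scheduled
        self.running = []               # requests currently executing
        self.requests = {}              # request_id -> request payload
        self.to_be_rescheduled = set()  # preempted ids awaiting token-processor ack

    def add_request(self, request_id, payload):
        with self.lock:                 # the same lock now guards add_request in the diff
            self.waiting.append(payload)
            self.requests[request_id] = payload

    def preempt(self, request_id):
        # Key change: do NOT appendleft onto `waiting` here; park the id
        # and let the token processor decide when it is safe to re-queue.
        self.running.remove(self.requests[request_id])
        self.to_be_rescheduled.add(request_id)

    def reschedule_preempt_task(self, request_id):
        # Called by the token processor once the last token for the
        # preempted request has been drained.
        with self.lock:
            if request_id in self.to_be_rescheduled and request_id in self.requests:
                self.waiting.appendleft(self.requests[request_id])
                self.to_be_rescheduled.remove(request_id)

    def finish(self, request_id):
        with self.lock:
            request = self.requests.get(request_id)
            if request is None:
                return                                  # invalid id
            if request in self.running:                 # finished normally
                self.running.remove(request)
            self.to_be_rescheduled.discard(request_id)  # finished while awaiting reschedule
            if request in self.waiting:                 # must be unreachable after the fix
                raise RuntimeError(f"request {request_id} re-queued after finishing")
            del self.requests[request_id]

# A request's life under preemption:
sched = TinyScheduler()
sched.add_request("req-0", {"prompt": "hi"})
sched.running.append(sched.waiting.popleft())  # scheduler starts it
sched.preempt("req-0")                         # memory pressure hits
sched.reschedule_preempt_task("req-0")         # token-processor ack
assert sched.requests["req-0"] in sched.waiting
sched.running.append(sched.waiting.popleft())  # resumed
sched.finish("req-0")
assert "req-0" not in sched.requests

The invariant worth calling out: only the token processor's acknowledgement path moves a request from the preempted set back to `waiting`, so `finish_requests` can safely treat set membership as "blocks already recycled, nothing left to free".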

fastdeploy/output/token_processor.py

Lines changed: 8 additions & 1 deletion
@@ -431,8 +431,13 @@ def _process_batch_output(self):
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
-
+
         batch_result = list()
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
+            for request_id in need_to_be_reschedule_req_ids:
+                if self.resource_manager.requests[request_id].idx >= (batch - 1):  # no more tokens generated for this preempted request
+                    self.resource_manager.reschedule_preempt_task(request_id)
         for i in range(batch):
             if self.resource_manager.stop_flags[i]:
                 continue
@@ -459,6 +464,8 @@ def _process_batch_output(self):
                 if recovery_stop:
                     llm_logger.info(f"recovery stop signal found at task {task_id}")
                 if not recovery_stop and token_id < 0:
+                    if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                        self.resource_manager.reschedule_preempt_task(task_id)
                     continue

                 if task.get("prefill_chunk_info", None) is not None:
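
Two independent signals now feed the reschedule, as I read these hunks: a batched check before the per-token loop (a request whose slot `idx` lies at or beyond the emitted batch got no token this step, per the in-line comment), and a per-task check (a negative `token_id` that is not a recovery stop). Both funnel into `reschedule_preempt_task`, which is a no-op unless the id is still in `to_be_rescheduled_request_id_set`, so the two triggers cannot double-queue a request.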

fastdeploy/worker/gpu_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -215,11 +215,11 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =

         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
-                logger.debug(f"Handle prefill request {request} at idx {idx}")
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
@@ -265,6 +265,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 )

                 input_ids = request.prompt_token_ids + request.output_token_ids
+                logger.debug(f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}")
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -293,6 +294,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                     request.block_tables, dtype="int32"
                 )
+                if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                    has_decode_task = True
                 continue
             else:  # preempted task
                 logger.debug(f"Handle preempted request {request} at idx {idx}")
@@ -338,7 +341,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             else:
                 self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

-        if has_prefill_task:
+        if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True
         self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
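
One more behavioral detail from gpu_model_runner.py, as I read it: `not_need_stop` previously flipped to True only when the batch carried a prefill task, so a batch consisting solely of resumed decode slots (`is_block_step`) could let the engine loop stop and strand those requests; tracking `has_decode_task` keeps the loop alive in that case.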