
Commit 3740e33

[Feature] ResourceManagerV1 support need block num notifying (#4220)

* support need block num notifying
* adapt t2i
* fix unexpected change
1 parent 70633c6 commit 3740e33

File tree

3 files changed: +211 −61 lines


fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 172 additions & 55 deletions
@@ -29,6 +29,7 @@
 from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
+from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger
@@ -69,6 +70,69 @@ class ScheduledExtendBlocksTask:
     task_type: RequestType = RequestType.EXTEND


+class SignalConsumer:
+    """
+    A class that consumes a signal value up to a specified limit.
+
+    This class maintains an internal signal value and allows controlled consumption
+    of that signal. The signal can be watched at any time, but can only be consumed
+    a limited number of times before being reset to zero.
+    """
+
+    def __init__(self, signal, consume_limit):
+        """
+        Initialize the SignalConsumer with a signal value and consumption limit.
+
+        Args:
+            signal: The initial signal value to be consumed.
+            consume_limit (int): The maximum number of times the signal can be consumed
+                before being reset to 0. Must be a positive integer.
+
+        Raises:
+            AssertionError: If consume_limit is not greater than 0.
+        """
+        assert consume_limit > 0
+
+        self._signal = signal
+        self._consume_limit = consume_limit
+
+    def watch(self):
+        """
+        Get the current signal value without consuming it.
+
+        This method allows reading the signal value any number of times without
+        affecting the consumption limit or the signal value itself.
+
+        Returns:
+            The current signal value.
+        """
+        return self._signal
+
+    def consume(self):
+        """
+        Consume the signal value, decrementing the consumption limit.
+
+        This method returns the current signal value and decrements the consumption
+        counter. When the consumption limit reaches zero, the signal is automatically
+        reset to 0. The consumption happens in a finally block to ensure the limit is
+        decremented even if an exception occurs while processing the signal.
+
+        Returns:
+            The current signal value before consumption.
+
+        Note:
+            After the consumption limit is reached, this method will continue to
+            return 0 on subsequent calls.
+        """
+        try:
+            return self._signal
+        finally:
+            if self._consume_limit > 0:
+                self._consume_limit -= 1
+                if self._consume_limit == 0:
+                    self._signal = 0
+
+
 class ResourceManagerV1(ResourceManager):
     """
     Resource manager for scheduler v1.
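A minimal standalone sketch of the watch/consume semantics added above (the class body mirrors the diff; the example value 8 is made up):

```python
# Standalone sketch of SignalConsumer: watch() reads freely, consume() is
# limited, and the stored value resets to 0 once the limit is exhausted.
class SignalConsumer:
    def __init__(self, signal, consume_limit):
        assert consume_limit > 0
        self._signal = signal
        self._consume_limit = consume_limit

    def watch(self):
        return self._signal

    def consume(self):
        try:
            return self._signal
        finally:
            if self._consume_limit > 0:
                self._consume_limit -= 1
                if self._consume_limit == 0:
                    self._signal = 0

c = SignalConsumer(8, consume_limit=1)
print(c.watch())    # 8 -- watching never touches the limit
print(c.consume())  # 8 -- returns the value; the limit hits 0, so it resets
print(c.watch())    # 0 -- all later reads see the reset value
```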
@@ -95,6 +159,19 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
         main_process_metrics.max_batch_size.set(max_num_seqs)

         self.using_extend_tables_req_id = set()
+        self.reuse_block_num_map = dict()
+
+        # need block nums
+        need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
+        self.need_block_num_signal = IPCSignal(
+            name="need_block_num_signal",
+            array=need_block_num_data,
+            dtype=np.int32,
+            suffix=local_data_parallel_id,
+            create=True,
+        )
+
+        self.need_block_num_map = dict()

     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size
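A hypothetical sketch of the notify/poll handshake over the per-slot int32 buffer. A plain numpy array stands in for the IPCSignal-backed shared memory, whose internals are not part of this diff; the slot index and block count are made up:

```python
# Stand-in for the shared need-block-num buffer: one int32 slot per batch seat.
import numpy as np

max_num_seqs = 4
need_block_num = np.zeros([max_num_seqs], dtype=np.int32)

# Worker side: the request in batch slot 2 reports it needs 6 more blocks.
need_block_num[2] = 6

# Scheduler side (as in schedule() below): poll each slot, latch any nonzero
# value into a one-shot consumer, then clear the slot so it is seen only once.
latched = {}
for idx in range(max_num_seqs):
    if need_block_num[idx] != 0:
        latched[idx] = int(need_block_num[idx])  # SignalConsumer(value, 1) in the diff
        need_block_num[idx] = 0

print(latched)         # {2: 6}
print(need_block_num)  # [0 0 0 0]
```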
@@ -127,14 +204,35 @@ def reschedule_preempt_task(self, request_id)
             self.waiting.appendleft(request)
             self.to_be_rescheduled_request_id_set.remove(request_id)

+    def _info_each_block(self):
+        """
+        print each req block
+        """
+        for req in self.running:
+            llm_logger.debug(
+                f"req idx {req.idx} occupy {len(req.block_tables)} block_tables and {len(req.extend_block_tables)} extend_block_tables"
+            )
+
+    def _can_preempt(self):
+        """
+        cannot preempt request which use extend block
+        """
+        for req in self.running:
+            if not req.use_extend_tables:
+                return True
+        return False
+
     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
         """
         If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
         """
-        can_schedule = True
-        while True:
+        can_schedule = False
+        while self._can_preempt():
             if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
                 preempted_req = self.running.pop()
+                if preempted_req.use_extend_tables:
+                    self.running.insert(0, preempted_req)
+                    continue
                 preempted_req.status = RequestStatus.PREEMPTED
                 preempted_req.num_computed_tokens = 0
                 if self.config.scheduler_config.splitwise_role == "decode":
@@ -156,6 +254,13 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
                     main_process_metrics.num_requests_running.dec(1)
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
+
+                llm_logger.debug(
+                    f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
+                )
+                llm_logger.debug(self.info())
+                self._info_each_block()
+
                 if preempted_req == request:
                     # No more request to preempt.
                     can_schedule = False
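A minimal sketch (invented `Req` type, simplified to just the victim-selection step) of the preemption change above: pop from the tail LIFO, but requests using extend tables are rotated to the front of the queue instead of being preempted:

```python
from collections import deque
from dataclasses import dataclass

@dataclass
class Req:
    request_id: str
    use_extend_tables: bool

running = deque([Req("a", False), Req("c", False), Req("b", True)])

def pick_victim(running):
    # Mirrors _can_preempt(): stop once only extend-table requests remain.
    while any(not r.use_extend_tables for r in running):
        candidate = running.pop()
        if candidate.use_extend_tables:
            running.insert(0, candidate)  # protected: rotate it to the front
            continue
        return candidate
    return None

print(pick_victim(running).request_id)   # "c" -- last non-extend request
print([r.request_id for r in running])   # ["b", "a"] -- "b" was rotated, not evicted
```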
@@ -314,6 +419,11 @@ def schedule(self)
         num_decoding_req_nums = 0
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
+            need_block_num = self.need_block_num_signal.value[request.idx]
+            if need_block_num != 0:
+                self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1)
+                self.need_block_num_signal.value[request.idx] = 0
+
             if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                 if (
                     self.config.scheduler_config.splitwise_role == "prefill"
@@ -351,6 +461,60 @@ def schedule(self)
                     scheduled_reqs.append(self._prepare_decode_task(request))
                     num_decoding_req_nums += 1
                     token_budget -= 1
+
+                if (
+                    request.use_extend_tables
+                    and request.request_id not in self.using_extend_tables_req_id
+                    and self.need_block_num_map[request.request_id].watch() > 0
+                ):
+
+                    def _allocate_decode_and_extend():
+                        allocate_block_num = self.need_block_num_map[request.request_id].consume()
+                        # Prepare decoding task
+                        request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(allocate_block_num))
+                        scheduled_reqs.append(self._prepare_decode_task(request))
+
+                        # Prepare extend task
+                        reuse_block_num = request.num_total_tokens // self.config.cache_config.block_size
+                        llm_logger.info(
+                            f"req {request.request_id} at batch id {request.idx} with reuse_block_num {reuse_block_num} is going to enable extend tables,"
+                            f"need_block_num {allocate_block_num}"
+                        )
+                        self.using_extend_tables_req_id.add(request.request_id)
+                        self.reuse_block_num_map[request.request_id] = reuse_block_num
+
+                        request.extend_block_tables = request.block_tables[:reuse_block_num]  # copy prompt cache
+                        request.extend_block_tables.extend(
+                            self.cache_manager.allocate_gpu_blocks(allocate_block_num)
+                        )
+                        scheduled_reqs.append(
+                            ScheduledExtendBlocksTask(
+                                idx=request.idx,
+                                request_id=request.request_id,
+                                extend_block_tables=request.extend_block_tables,
+                            )
+                        )
+                        llm_logger.debug(f"extend blocks is {request.extend_block_tables}")
+
+                    if self.cache_manager.can_allocate_gpu_blocks(
+                        2 * self.need_block_num_map[request.request_id].watch()
+                    ):
+                        _allocate_decode_and_extend()
+                    else:
+                        llm_logger.info(
+                            f"{request.idx} using extend block need {2 * self.need_block_num_map[request.request_id].watch()} blocks but got not enough blocks, ready to preempt"
+                        )
+                        can_schedule = self._trigger_preempt(
+                            request,
+                            2 * self.need_block_num_map[request.request_id].watch(),
+                            preempted_reqs,
+                            scheduled_reqs,
+                        )
+
+                        if can_schedule:
+                            _allocate_decode_and_extend()
+                        else:
+                            break
             else:  # need to prefill
                 llm_logger.debug(
                     f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}"
@@ -476,56 +640,6 @@ def schedule(self)
             else:
                 llm_logger.error("Unknown request status type")

-        # schedule when extend block tables is needed
-        for req in self.running:
-            num_prefill_blocks = req.need_prefill_tokens // self.config.cache_config.block_size
-            # allocate
-            if req.use_extend_tables and req.request_id not in self.using_extend_tables_req_id:
-                llm_logger.info(
-                    f"req {req.request_id} at batch id {req.idx} with num_prefill_blocks {num_prefill_blocks} is going to enable extend tables"
-                )
-                self.using_extend_tables_req_id.add(req.request_id)
-                if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
-                    req.extend_block_tables = req.block_tables[:num_prefill_blocks]  # copy prompt cache
-                    req.extend_block_tables.extend(
-                        self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
-                    )
-                    scheduled_reqs.append(
-                        ScheduledExtendBlocksTask(
-                            idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
-                        )
-                    )
-                    llm_logger.info(f"extend blocks is {req.extend_block_tables}")
-                else:
-                    continue
-            # recycle
-            elif not req.use_extend_tables and req.request_id in self.using_extend_tables_req_id:
-                llm_logger.info(f"req {req.request_id} is going to disable extend tables")
-                self.using_extend_tables_req_id.remove(req.request_id)
-                self.cache_manager.recycle_gpu_blocks(req.extend_block_tables[num_prefill_blocks:])
-                req.extend_block_tables = []
-
-            # allocate extend blocks when blocks is going to exhaust
-            elif req.request_id in self.using_extend_tables_req_id:
-                if (
-                    self.allocated_slots(req) - req.num_total_tokens
-                    <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
-                ):
-                    llm_logger.info(
-                        f"req {req.request_id} is going to allocate more extend tables because allocated_slots {self.allocated_slots(req)} and prealloc_dec_block_slot_num_threshold {self.config.cache_config.prealloc_dec_block_slot_num_threshold} req.num_total_tokens {req.num_total_tokens}"
-                    )
-                    if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
-                        req.extend_block_tables.extend(
-                            self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
-                        )
-                        scheduled_reqs.append(
-                            ScheduledExtendBlocksTask(
-                                idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
-                            )
-                        )
-                else:
-                    continue
-
         if scheduled_reqs:
             task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
             main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
@@ -725,13 +839,16 @@ def _free_blocks(self, request: Request)
         request.block_tables = []

         if request.request_id in self.using_extend_tables_req_id:
-            num_prefill_blocks = request.need_prefill_tokens // self.config.cache_config.block_size
+            reuse_block_num = self.reuse_block_num_map[request.request_id]
+
             self.using_extend_tables_req_id.remove(request.request_id)
-            self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[num_prefill_blocks:])
+            self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:])
             llm_logger.info(
-                f"req {request.request_id} recycle extend blocks {request.extend_block_tables[num_prefill_blocks:]}"
+                f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}"
             )
             request.extend_block_tables = []
+            del self.reuse_block_num_map[request.request_id]
+            del self.need_block_num_map[request.request_id]

     def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
         return self.finish_execution_pool.submit(self.finish_requests, request_ids)
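A sketch (invented block ids) of the recycle slice in `_free_blocks` above: the first `reuse_block_num` entries of `extend_block_tables` are shared with the request's regular `block_tables` (the copied prompt-cache prefix), so only the tail, the blocks allocated specifically for the extend table, is recycled:

```python
# Prefix [10, 11, 12] is the shared prompt cache; [40, 41, 42] was freshly
# allocated for the extend table and must be returned to the cache manager.
extend_block_tables = [10, 11, 12, 40, 41, 42]
reuse_block_num = 3
to_recycle = extend_block_tables[reuse_block_num:]
print(to_recycle)  # [40, 41, 42]
```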

fastdeploy/input/tokenzier_client.py

Lines changed: 17 additions & 4 deletions
@@ -15,6 +15,7 @@
 """

 import asyncio
+import time
 from typing import Any, Optional, Union

 import httpx
@@ -154,10 +155,22 @@ async def _async_decode_request(self, type: str, request: dict)
                 url = f"{self.base_url}/image/decode"
             else:
                 raise ValueError("Invalid type")
-            resp = await client.post(url, json=request)
-            resp.raise_for_status()
-            if resp.json().get("code") != 0:
-                raise RuntimeError(f"Tokenize task creation failed, {resp.json().get('message')}")
+
+            max_retries = 10
+            for attempt in range(max_retries):
+                try:
+                    resp = await client.post(url, json=request)
+                    resp.raise_for_status()
+                    if resp.json().get("code") != 0:
+                        raise RuntimeError(f"Tokenize task creation failed, {resp.json().get('message')}")
+                    break
+                except Exception as e:
+                    data_processor_logger.error(f"Attempt to decode_request {attempt + 1} failed: {e}")
+                    if attempt == max_retries - 1:
+                        data_processor_logger.error(
+                            f"Max retries of decode_request reached. Giving up. request is {request}"
+                        )
+                    time.sleep(10)
             return resp.json().get("result")
         except httpx.RequestError as e:
             raise RuntimeError(f"Failed to decode: {e}") from e

fastdeploy/output/token_processor.py

Lines changed: 22 additions & 2 deletions
@@ -142,6 +142,27 @@ def run(self)
         self.worker.daemon = True
         self.worker.start()

+    def _reschedule_preempt_task_use_zmq(self, datas):
+        """reschedule when real batch size is smaller than the insert position of preemted_task"""
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
+            if len(need_to_be_reschedule_req_ids) > 0:
+                batch_id_set = set()
+                for data in datas:
+                    batch_id_set.add(data.batch_id)
+                llm_logger.debug(f"_reschedule_preempt_task_use_zmq batch_id_set {batch_id_set}")
+                for request_id in need_to_be_reschedule_req_ids:
+                    if (
+                        self.resource_manager.requests[request_id].idx not in batch_id_set
+                    ):  # No more token generated for preempted request
+                        llm_logger.debug(
+                            f"reschedule_preempt_task request_id {request_id} at {self.resource_manager.requests[request_id].idx}"
+                        )
+                        self.resource_manager.reschedule_preempt_task(request_id)
+                        llm_logger.debug(
+                            f"finish reschedule_preempt_task request_id {request_id} at {self.resource_manager.requests[request_id].idx}"
+                        )
+
     def _reschedule_preempt_task(self, batch_size):
         """reschedule when real batch size is smaller than the insert position of preemted_task"""
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -264,8 +285,7 @@ def process_sampling_results_use_zmq(self):
264285
assert isinstance(receive_datas, list)
265286
llm_logger.debug(f"token_processor receive_data {receive_datas}")
266287

267-
batch_size = len(receive_datas)
268-
self._reschedule_preempt_task(batch_size)
288+
self._reschedule_preempt_task_use_zmq(receive_datas)
269289

270290
batch_result = self._process_batch_output_use_zmq(receive_datas)
271291
self.postprocess(batch_result)
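A minimal sketch (made-up data) of the membership test in `_reschedule_preempt_task_use_zmq` above: a preempted request is rescheduled only when none of this step's outputs came from its batch slot, i.e. its `idx` is absent from the set of `batch_id`s actually received:

```python
from dataclasses import dataclass

@dataclass
class Output:
    batch_id: int

outputs = [Output(0), Output(2)]              # tokens produced this step
batch_id_set = {o.batch_id for o in outputs}

preempted = {"req-a": 1, "req-b": 2}          # request_id -> batch slot idx
to_reschedule = [rid for rid, idx in preempted.items() if idx not in batch_id_set]
print(to_reschedule)  # ['req-a'] -- slot 1 emitted nothing, so req-a is rescheduled
```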
