
Commit 8a36651

fix
1 parent 4847d6b commit 8a36651

File tree

3 files changed: +173 -150 lines changed


lightllm/server/router/model_infer/infer_batch.py

Lines changed: 101 additions & 75 deletions
@@ -53,31 +53,22 @@ def get_overlap_stream(self) -> torch.cuda.Stream:
             self.overlap_stream = torch.cuda.Stream()
         return self.overlap_stream
 
-    def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_req_obj=True):
+    def add_reqs(self, requests: List[Tuple[int, int, Any, int]]):
         request_ids = []
         for r in requests:
-
             r_id, r_index, multimodal_params, _ = r
-            if r_id not in self.requests_mapping.keys():
-                r_obj = InferReq(
-                    req_id=r_id,
-                    req_idx=self.req_manager.alloc(),
-                    shm_index=r_index,
-                    multimodal_params=multimodal_params,
-                    vocab_size=self.vocab_size,
-                )
-                self.requests_mapping[r_id] = r_obj
-            else:
-                r_obj: InferReq = self.requests_mapping[r_id]
-                assert r_obj.paused is True
-
+            assert r_id not in self.requests_mapping.keys()
+            r_obj = InferReq(
+                req_id=r_id,
+                req_idx=self.req_manager.alloc(),
+                shm_index=r_index,
+                multimodal_params=multimodal_params,
+                vocab_size=self.vocab_size,
+            )
+            self.requests_mapping[r_id] = r_obj
             request_ids.append(r_id)
 
-        if init_req_obj:
-            r_obj.init_all()
-
         self.infer_req_ids.extend(request_ids)
-
         return
 
     def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finished: bool):
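With this hunk, add_reqs assumes every incoming request id is new (assert r_id not in self.requests_mapping.keys()), and the init_req_obj flag is gone because an InferReq now initializes itself fully in its constructor (see the later hunks). A minimal caller-side sketch of the assumed contract; register_new_requests and recv_reqs are illustrative names, only add_reqs and the tuple layout come from the diff:

from typing import Any, List, Tuple

def register_new_requests(g_infer_context, recv_reqs: List[Tuple[int, int, Any, int]]):
    # Each tuple is (req_id, shm_index, multimodal_params, _); every req_id must be
    # unseen, because paused requests now stay in requests_mapping and are resumed
    # via recover_paused_reqs() instead of being re-added here.
    g_infer_context.add_reqs(recv_reqs)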
@@ -169,27 +160,47 @@ def filter_reqs(self, finished_reqs: List["InferReq"]):
         return
 
     @torch.no_grad()
-    def pause_reqs(self, pause_req_ids: List[int]):
-        free_token_index = []
-        for request_id in pause_req_ids:
-            req: InferReq = self.requests_mapping[request_id]
-            self.infer_req_ids.remove(request_id)
+    def pause_reqs(self, pause_reqs: List["InferReq"]):
+        if pause_reqs:
+            g_infer_state_lock.acquire()
 
-            if req.initialized:
-                # pausing does not support the multi-output case
+            free_token_index = []
+            for req in pause_reqs:
+                # pausing does not support the multi-output case, i.e. diverse output mode is not supported.
                 self.free_a_req_mem(free_token_index, req, is_group_finished=True)
                 req.cur_kv_len = 0
                 req.shm_req.shm_cur_kv_len = req.cur_kv_len
-                req.paused = True  # pause marker.
-            else:
+                assert req.wait_pause is True
+                req.wait_pause = False
                 req.paused = True
 
-        if len(free_token_index) != 0:
-            free_token_index = custom_cat(free_token_index)
-            self.req_manager.free_token(free_token_index)
+            if len(free_token_index) != 0:
+                free_token_index = custom_cat(free_token_index)
+                self.req_manager.free_token(free_token_index)
 
+            g_infer_state_lock.release()
         return self
 
+    def recover_paused_reqs(self, paused_reqs: List["InferReq"]):
+        if paused_reqs:
+            g_infer_state_lock.acquire()
+
+            for req in paused_reqs:
+                req._match_radix_cache()
+                assert req.paused is True
+                req.paused = False
+
+            g_infer_state_lock.release()
+        return
+
+    def get_can_alloc_token_num(self):
+        radix_cache_unref_token_num = 0
+        if self.radix_cache is not None:
+            radix_cache_unref_token_num = (
+                self.radix_cache.get_tree_total_tokens_num() - self.radix_cache.get_refed_tokens_num()
+            )
+        return self.req_manager.mem_manager.can_use_mem_size + radix_cache_unref_token_num
+
 
 g_infer_context = InferenceContext()
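Taken together, this hunk changes the pause path from id-based to object-based, adds an explicit recovery step, and introduces a token-budget helper, all guarded by g_infer_state_lock. A hedged sketch of how a scheduler might drive the new flow; maybe_pause, maybe_recover, pick_reqs_to_pause, and needed_tokens are illustrative, only pause_reqs, recover_paused_reqs, get_can_alloc_token_num, and the wait_pause flag come from this commit:

from typing import List

def maybe_pause(g_infer_context, running_reqs: List, needed_tokens: int):
    # Pause some running requests when the kv token budget is short.
    if g_infer_context.get_can_alloc_token_num() < needed_tokens:
        victims = pick_reqs_to_pause(running_reqs)  # hypothetical selection policy
        for req in victims:
            req.wait_pause = True                   # mark: "needs to be paused"
        g_infer_context.pause_reqs(victims)         # frees kv tokens, sets paused=True
        return victims
    return []

def maybe_recover(g_infer_context, paused_reqs: List, needed_tokens: int):
    # Resume paused requests once enough tokens are free again.
    if paused_reqs and g_infer_context.get_can_alloc_token_num() >= needed_tokens:
        g_infer_context.recover_paused_reqs(paused_reqs)  # re-matches radix cache, clears paused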

@@ -256,9 +267,14 @@ def __init__(
         self.shm_index = shm_index
         self.multimodal_params = multimodal_params
         self.vocab_size = vocab_size
-        self.initialized = False
+
+        # the request needs to be paused
+        self.wait_pause = False
+        # the request has already been paused
         self.paused = False
+
         self.infer_aborted = False
+        self.filter_mark = False
         self.need_out_token_id_statistics = True
         self.out_token_id_count: Dict[int, int] = None
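The initialized flag disappears because an InferReq is now fully set up in its constructor (next hunk); pausing is tracked by two flags instead. A brief summary of the lifecycle these flags appear to encode (my reading of the diff, not text from the repo):

# wait_pause  - the scheduler has decided the request must be paused, but
#               InferenceContext.pause_reqs() has not processed it yet.
# paused      - pause_reqs() has freed the request's kv tokens; the request stays
#               in requests_mapping until recover_paused_reqs() clears the flag.
# filter_mark - new bookkeeping flag, presumably used when filtering request lists.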

@@ -268,51 +284,48 @@ def __init__(
         # needs to be re-verified in later steps.
         self.mtp_gen_token_ids: List[int] = []
 
-    def init_all(self):
-        if self.initialized is False:
-            self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
-            self.shm_req.link_prompt_ids_shm_array()
-            self.shm_req.link_logprobs_shm_array()
-            self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
-            self.cur_kv_len = 0
-            self.cur_output_len = 0
-
-            g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
-
-            self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
-            # management object used only in token healing mode
-            if self.shm_req.prefix_token_ids.size != 0:
-                self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
-            else:
-                self.prefix_token_ids = []
-            self.multimodal_params = self.multimodal_params.to_dict()
-            self.shared_kv_node: TreeNode = None
-
-            self.finish_status = FinishStatus()
-
-        if self.paused or not self.initialized:
-            # if the prompt_cache feature is used, the cache needs to be filled and restored in advance.
-            if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1:
-                input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
-                key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
-                key = key[0 : len(key) - 1]  # the last token is not needed: prefill keeps one extra token so it can predict the next token
-                share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
-                if share_node is not None:
-                    self.shared_kv_node = share_node
-                    ready_cache_len = share_node.node_prefix_total_len
-                    # the cpu-to-gpu copy is a stream-blocking operation
-                    g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
-                    self.cur_kv_len = int(ready_cache_len)  # serialization issue: this may be numpy.int64, convert with int(*)
-                    self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
-
-            self.shm_req.shm_cur_kv_len = self.cur_kv_len
-
-        self.initialized = True
-        self.paused = False
+        self._init_all_state()
+        self._match_radix_cache()
         return
 
-    def is_uninitialized(self):
-        return not self.initialized or self.paused
+    def _init_all_state(self):
+        self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
+        self.shm_req.link_prompt_ids_shm_array()
+        self.shm_req.link_logprobs_shm_array()
+        self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)
+        self.cur_kv_len = 0
+        self.cur_output_len = 0
+
+        g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
+
+        self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
+        # management object used only in token healing mode
+        if self.shm_req.prefix_token_ids.size != 0:
+            self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
+        else:
+            self.prefix_token_ids = []
+        self.multimodal_params = self.multimodal_params.to_dict()
+        self.shared_kv_node: TreeNode = None
+
+        self.finish_status = FinishStatus()
+        return
+
+    def _match_radix_cache(self):
+        if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1 and self.cur_kv_len == 0:
+            input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
+            key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
+            key = key[0 : len(key) - 1]  # the last token is not needed: prefill keeps one extra token so it can predict the next token
+            share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
+            if share_node is not None:
+                self.shared_kv_node = share_node
+                ready_cache_len = share_node.node_prefix_total_len
+                # the cpu-to-gpu copy is a stream-blocking operation
+                g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
+                self.cur_kv_len = int(ready_cache_len)  # serialization issue: this may be numpy.int64, convert with int(*)
+                self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
+
+            self.shm_req.shm_cur_kv_len = self.cur_kv_len
+        return
 
     def get_output_len(self):
         return self.cur_output_len
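The refactor splits the old init_all into _init_all_state (shared-memory wiring, sampling params, stop sequences) and _match_radix_cache (prompt-cache prefix matching), and calls both from __init__. The self.cur_kv_len == 0 guard is what lets recover_paused_reqs reuse _match_radix_cache on its own. A minimal sketch of that reuse, mirroring the recover_paused_reqs body shown earlier (the standalone function name is illustrative):

def recover_one(req):
    # Re-attach any shared prefix kv from the radix cache; a no-op when the
    # request already holds kv (cur_kv_len != 0) or no radix cache is configured.
    req._match_radix_cache()
    assert req.paused is True
    req.paused = False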
@@ -372,6 +385,19 @@ def _stop_sequences_matched(self, output_len: int):
                 return True
         return False
 
+    def prefill_need_token_num(self, is_chuncked_prefill: bool):
+        if is_chuncked_prefill:
+            input_token_ids = self.get_chuncked_input_token_ids()
+        else:
+            input_token_ids = self.get_input_token_ids()
+
+        seq_len = len(input_token_ids)
+        input_token_len = seq_len - self.cur_kv_len
+        return input_token_len
+
+    def decode_need_token_num(self):
+        return 1 + len(self.mtp_gen_token_ids)
+
 
 class InferReqGroup:
     def __init__(
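These two helpers make a request's next-step token demand explicit: a (possibly chunked) prefill needs the uncached part of its input, and a decode step needs one token plus any pending MTP draft tokens. A hedged sketch of an admission check that pairs them with get_can_alloc_token_num from the earlier hunk; the wrapper function and its arguments are illustrative:

def can_schedule(g_infer_context, req, is_prefill: bool, is_chuncked_prefill: bool = False) -> bool:
    # A step is admissible only if the free kv pool plus unreferenced
    # radix-cache tokens covers the tokens the step will allocate.
    if is_prefill:
        need = req.prefill_need_token_num(is_chuncked_prefill)
    else:
        need = req.decode_need_token_num()  # 1 + pending mtp_gen_token_ids
    return need <= g_infer_context.get_can_alloc_token_num()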
