
Commit 7ae1d0c

Author: wangzaijun
Commit message: diversemode fix
1 parent 542cb48 · commit 7ae1d0c

5 files changed: +209 additions, -117 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 11 additions & 1 deletion

@@ -361,8 +361,13 @@ def _prefill(
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()
+
         infer_state.init_some_extra_state(self, model_input.input_ids)
-        return self._context_forward(model_input.input_ids, infer_state)
+        model_output = self._context_forward(model_input.input_ids, infer_state)
+        model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        return model_output

     def _decode(
         self,

@@ -514,13 +519,18 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         )
         infer_state1.init_some_extra_state(self, input_ids1)

+        prefill_mem_indexes_ready_event = torch.cuda.Event()
+        prefill_mem_indexes_ready_event.record()
+
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
         )

         # When deepep is enabled, clear_deepep_buffer must be called to release resources;
         # when it is not enabled, the call has no real effect.
         dist_group_manager.clear_deepep_buffer()
+        model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
+        model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         return model_output0, model_output1

     @torch.no_grad()
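The recorded event gives consumers of ModelOutput a way to confirm that the prefill mem_indexes writes issued on the current CUDA stream have completed before relying on them. A minimal sketch of how a consumer might wait on it, assuming only the field added in this commit; the handle_prefill_output helper and its argument are hypothetical and not part of lightllm:

import torch


def handle_prefill_output(model_output):
    # Hypothetical consumer: before touching state that depends on the prefill
    # mem_indexes having been written, wait for the recorded event.
    event = model_output.prefill_mem_indexes_ready_event
    if event is not None:
        # Block the host until the event completes ...
        event.synchronize()
        # ... or, to stay asynchronous, make the current stream wait instead:
        # torch.cuda.current_stream().wait_event(event)
    return model_output.logits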

lightllm/common/basemodel/batch_objs.py

Lines changed: 2 additions & 0 deletions

@@ -58,6 +58,8 @@ def to_cuda(self):
 class ModelOutput:
     # common fields
     logits: torch.Tensor
+    # Event object used to check whether the mem_indexes were successfully written into the req manager.
+    prefill_mem_indexes_ready_event: torch.Event = None

     # Model-specific fields, used by some special models in special modes to pass
     # extra output variables. They are only used and take effect in those special model modes.
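For reference, a simplified sketch of the resulting dataclass shape; ModelOutputSketch is a hypothetical stand-in for lightllm's ModelOutput, and the Optional[torch.cuda.Event] annotation is an assumed equivalent of the torch.Event = None default above:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ModelOutputSketch:
    # Common field: the logits produced by the forward pass.
    logits: torch.Tensor
    # Only populated by the prefill paths; decode outputs leave it as None.
    prefill_mem_indexes_ready_event: Optional[torch.cuda.Event] = None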

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 65 additions & 81 deletions
@@ -28,7 +28,6 @@ class InferenceContext:
     radix_cache: RadixCache = None
     shm_req_manager: ShmReqManager = None  # manages shared-memory request objects
     requests_mapping: Dict[int, "InferReq"] = None
-    group_mapping = None  # only actually used in multi-output mode
     infer_req_ids = None
     vocab_size = None

@@ -48,7 +47,6 @@ def register(
         self.shm_req_manager = shm_req_manager

         self.requests_mapping = {}
-        self.group_mapping: Dict[int, InferReqGroup] = {}
         self.infer_req_ids = []

         self.vocab_size = vocab_size
@@ -84,46 +82,42 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:

         self.infer_req_ids.extend(request_ids)

-        # In multi-output mode, each request must be added to its own InferReqGroup object
+        # In diverse mode, build the master/slave relationship among each group of requests
         if get_env_start_args().diverse_mode:
+            group_reqs: Dict[int, InferReq] = collections.defaultdict(lambda: [None, list()])
             for r_id in request_ids:
                 req: InferReq = g_infer_context.requests_mapping[r_id]
                 group_req_id = req.shm_req.group_req_id
-                if group_req_id not in g_infer_context.group_mapping:
-                    g_infer_context.group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id)
-                g_infer_context.group_mapping[group_req_id].add_req(r_id)
+                if req.req_id == group_req_id:
+                    group_reqs[group_req_id][0] = req
+                else:
+                    group_reqs[group_req_id][1].append(req)
+
+            for group_req_id, (master_req, slave_reqs) in group_reqs.items():
+                master_req: InferReq = master_req
+                master_req.slave_reqs.extend(slave_reqs)
+                for slave_req in slave_reqs:
+                    slave_req: InferReq = slave_req
+                    slave_req.related_master_req = master_req

         return req_objs

-    def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finished: bool):
+    def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
         if self.radix_cache is None:
-            if is_group_finished:
-                free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
-            else:
-                free_token_index.append(
-                    self.req_manager.req_to_token_indexs[req.req_idx][req.shm_req.input_len : req.cur_kv_len]
-                )
+            free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
         else:
             input_token_ids = req.get_input_token_ids()
             key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
             # .cpu() is a blocking operation within the stream
             value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()

-            if is_group_finished:
-                prefix_len, _ = self.radix_cache.insert(key, value)
-                old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
-                free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
-                if req.shared_kv_node is not None:
-                    assert req.shared_kv_node.node_prefix_total_len <= prefix_len
-                    self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
-                    req.shared_kv_node = None
-            else:
-                free_token_index.append(
-                    self.req_manager.req_to_token_indexs[req.req_idx][req.shm_req.input_len : req.cur_kv_len]
-                )
-                if req.shared_kv_node is not None:
-                    self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
-                    req.shared_kv_node = None
+            prefix_len, _ = self.radix_cache.insert(key, value)
+            old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+            free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
+            if req.shared_kv_node is not None:
+                assert req.shared_kv_node.node_prefix_total_len <= prefix_len
+                self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
+                req.shared_kv_node = None

     def _save_promptcache_kvbuffer(self):
         """
@@ -148,14 +142,10 @@ def _filter(self, finished_request_ids: List[int]):
         free_token_index = []
         for request_id in finished_request_ids:
             req: InferReq = self.requests_mapping.pop(request_id)
-            group_req_id = convert_sub_id_to_group_id(req.shm_req.request_id)
-            if group_req_id in self.group_mapping:
-                is_group_finished = self.group_mapping[group_req_id].remove_req(req.shm_req.request_id)
-                if is_group_finished:
-                    del self.group_mapping[group_req_id]
-                self.free_a_req_mem(free_token_index, req, is_group_finished)
-            else:
-                self.free_a_req_mem(free_token_index, req, True)
+            if self.args.diverse_mode:
+                req.clear_master_slave_state()
+            self.free_a_req_mem(free_token_index, req)
+
             free_req_index.append(req.req_idx)
             # logger.info(f"infer release req id {req.shm_req.request_id}")
             req.shm_req.shm_infer_released = True
@@ -192,8 +182,10 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):

         free_token_index = []
         for req in pause_reqs:
-            # Pausing is not supported for multi-output requests, i.e. the diverse output mode is not supported.
-            self.free_a_req_mem(free_token_index, req, is_group_finished=True)
+            if self.args.diverse_mode:
+                # When a pause happens, the diverse-mode master/slave relationship must be cleared
+                req.clear_master_slave_state()
+            self.free_a_req_mem(free_token_index, req)
             req.cur_kv_len = 0
             req.shm_req.shm_cur_kv_len = req.cur_kv_len
             assert req.wait_pause is True
@@ -337,6 +329,10 @@ def __init__(
         self.need_out_token_id_statistics = True
         self.out_token_id_count: Dict[int, int] = None

+        # In diverse mode, used to record the dependency relationship between requests in a group
+        self.slave_reqs: List[InferReq] = []
+        self.related_master_req: InferReq = None
+
         # Variables used in nixl pd-disaggregation mode; in normal mode they have no specific use
         self.nixl_trans_kv_start_index: int = 0
         self.nixl_pd_task_num: int = 0
@@ -407,6 +403,37 @@ def _match_radix_cache(self):
         self.shm_req.shm_cur_kv_len = self.cur_kv_len
         return

+    def is_master_req(self):
+        """
+        In diverse mode, checks whether this request is an independent master request.
+        After its prefill, its kv is shared with the slave requests through the radix cache;
+        once shared, the slave requests are also promoted to master, able to infer and pause independently.
+        """
+        return self.related_master_req is None
+
+    def is_slave_req(self):
+        return self.related_master_req is not None
+
+    def clear_master_slave_state(self):
+        if self.is_slave_req():
+            self.remove_master_req()
+        elif self.is_master_req():
+            # The list must be copied before iterating, because it is mutated in the loop.
+            for slave_req in self.slave_reqs.copy():
+                slave_req.remove_master_req()
+
+    def remove_master_req(self):
+        """
+        When a request in slave state breaks its dependency on the master request, it is
+        promoted to master state itself, able to run inference and pause independently.
+        """
+        master_req = self.related_master_req
+        if master_req is not None:
+            master_req.slave_reqs.remove(self)
+            self.related_master_req = None
+        else:
+            logger.warning(f"try to remove master req, but related_master_req is None, req id {self.req_id}")
+
     def get_output_len(self):
         return self.cur_output_len
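A toy reproduction of the promotion behaviour these methods implement, to show the state transitions in isolation; DemoReq is hypothetical and mirrors only the slave_reqs / related_master_req fields described above:

from typing import List, Optional


class DemoReq:
    # Hypothetical, minimal model of the diverse-mode master/slave state.
    def __init__(self, req_id: int):
        self.req_id = req_id
        self.slave_reqs: List["DemoReq"] = []
        self.related_master_req: Optional["DemoReq"] = None

    def is_master_req(self) -> bool:
        return self.related_master_req is None

    def remove_master_req(self) -> None:
        master = self.related_master_req
        if master is not None:
            master.slave_reqs.remove(self)
            self.related_master_req = None

    def clear_master_slave_state(self) -> None:
        if not self.is_master_req():
            self.remove_master_req()
        else:
            # Iterate over a copy, because remove_master_req mutates slave_reqs.
            for slave in self.slave_reqs.copy():
                slave.remove_master_req()


# A master with two slaves: clearing the master promotes both slaves.
master, s1, s2 = DemoReq(1), DemoReq(2), DemoReq(3)
master.slave_reqs = [s1, s2]
s1.related_master_req = master
s2.related_master_req = master
master.clear_master_slave_state()
assert s1.is_master_req() and s2.is_master_req() and master.slave_reqs == []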

@@ -482,49 +509,6 @@ def _mtp_decode_need_token_num(self) -> int:
         return (1 + self.mtp_step) * 2


-class InferReqGroup:
-    def __init__(
-        self,
-        group_req_id: int,
-    ) -> None:
-        self.group_req_id = group_req_id
-        self.req_ids_group = []
-
-    def get_req(self, index):
-        return g_infer_context.requests_mapping[self.req_ids_group[index]]
-
-    def get_all_reqs(self):
-        return [g_infer_context.requests_mapping[self.req_ids_group[i]] for i in range(len(self.req_ids_group))]
-
-    def add_req(self, req_id):
-        self.req_ids_group.append(req_id)
-
-    def remove_req(self, req_id):
-        assert req_id in self.req_ids_group
-        self.req_ids_group.remove(req_id)
-        return len(self.req_ids_group) == 0
-
-    def best_of(self):
-        return len(self.req_ids_group)
-
-    def diverse_copy(self, req_manager, is_prefill):
-        # record previous status
-        master_req = g_infer_context.requests_mapping[convert_sub_id_to_group_id(self.req_ids_group[0])]
-        new_kv_len = master_req.get_chuncked_input_token_len()
-
-        # update the InferReq status and mem_manager status for cache sharing
-        for req_id in self.req_ids_group[:]:
-            if req_id == convert_sub_id_to_group_id(req_id):
-                continue
-            req = g_infer_context.requests_mapping[req_id]
-            req.finish_status.set_status(FinishStatus.NO_FINISH)
-            assert req.cur_kv_len <= master_req.cur_kv_len
-            copy_token_index = req_manager.req_to_token_indexs[master_req.req_idx][req.cur_kv_len : new_kv_len]
-
-            req_manager.req_to_token_indexs[req.req_idx][req.cur_kv_len : new_kv_len] = copy_token_index
-            req.cur_kv_len = master_req.cur_kv_len
-
-
 class InferReqUpdatePack:
     """
     Used to defer InferReq request updates, mainly to make a more efficient overlap mechanism easier to implement. Decouples

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py

Lines changed: 7 additions & 6 deletions

@@ -79,12 +79,13 @@ def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, finished_reqs:
             self.model.mem_manager.free(
                 self.model.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]
             )
-            if req.shared_kv_node is not None:
-                # Replace the old shared node with the new shared node; the new node's length is the current cur_kv_len
-                self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
-                self.radix_cache.add_node_ref_counter(new_shared_kv_node)
-                req.shared_kv_node = new_shared_kv_node
-                assert new_shared_kv_node.node_prefix_total_len == req.cur_kv_len
+            # Replace the old shared node with the new shared node; the new node's length is the current cur_kv_len
+
+            self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
+            self.radix_cache.add_node_ref_counter(new_shared_kv_node)
+            req.shared_kv_node = new_shared_kv_node
+
+            assert new_shared_kv_node.node_prefix_total_len == req.cur_kv_len

             if req.shm_req.sample_params.move_kv_to_decode_node.exists:
                 # Note: keep the logic compatible with both pure tp and mixed tp-dp modes
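Since the `if req.shared_kv_node is not None` guard is gone, this path now assumes a shared node is always present when the swap runs. A toy sketch of the reference-count swap it performs; ToyNode, ToyReq and swap_shared_node are hypothetical stand-ins for the radix-cache types, not lightllm APIs:

from dataclasses import dataclass


@dataclass
class ToyNode:
    # Hypothetical stand-in for a radix-cache node.
    node_prefix_total_len: int
    ref_count: int = 0


@dataclass
class ToyReq:
    # Hypothetical stand-in carrying only the fields the swap touches.
    cur_kv_len: int
    shared_kv_node: ToyNode


def swap_shared_node(req: ToyReq, new_shared_kv_node: ToyNode) -> None:
    # Same ordering as the diff: release the old node's reference, take a
    # reference on the new node, then repoint the request. The new node is
    # expected to cover exactly cur_kv_len tokens.
    req.shared_kv_node.ref_count -= 1       # dec_node_ref_counter(old)
    new_shared_kv_node.ref_count += 1       # add_node_ref_counter(new)
    req.shared_kv_node = new_shared_kv_node
    assert new_shared_kv_node.node_prefix_total_len == req.cur_kv_len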
