Commit 572ce8d

fix

1 parent 04e5490 commit 572ce8d

4 files changed: +66 −1 lines changed

lightllm/common/basemodel/batch_objs.py

20 additions, 0 deletions

@@ -2,6 +2,7 @@
 from dataclasses import dataclass, field
 from typing import Optional
 from typing import List
+from lightllm.utils.envs_utils import get_env_start_args


 @dataclass
@@ -21,6 +22,15 @@ class ModelInput:
     b_req_idx: torch.Tensor = None
     b_mtp_index: torch.Tensor = None
     b_seq_len: torch.Tensor = None
+    # Only actually used in the decode phase under diverse_mode; records the shared length in the radix cache.
+    b_shared_seq_len: torch.Tensor = None
+    # Only actually used in the decode phase under diverse_mode; records the sharing relationship between requests.
+    # For example:
+    # b_shared_seq_len   : [10, 10, 10, 11, 11, 11, 11]
+    # b_mark_shared_group: [0, 0, 3, 0, 0, 0, 4]
+    # Each non-zero value in b_mark_shared_group marks the last request of a shared-prefix group and gives how many
+    # requests (itself plus the preceding ones) form that group; requests in the same group hold identical b_shared_seq_len values.
+    b_mark_shared_group: torch.Tensor = None
     mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
@@ -52,6 +62,16 @@ def to_cuda(self):
             self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
         if self.b_prefill_start_loc is not None:
             self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True)
+        if not self.is_prefill and get_env_start_args().diverse_mode:
+            batch_size = len(self.b_req_idx)
+            if self.b_mark_shared_group is None:
+                self.b_mark_shared_group = torch.ones(size=(batch_size,), dtype=torch.int32, device="cuda")
+            else:
+                self.b_mark_shared_group = self.b_mark_shared_group.cuda(non_blocking=True)
+            if self.b_shared_seq_len is None:
+                self.b_shared_seq_len = torch.zeros(size=(batch_size,), dtype=torch.int32, device="cuda")
+            else:
+                self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)


 @dataclass
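
To make the semantics of the two new ModelInput fields concrete, here is a minimal standalone sketch (not part of the commit; the helper name is hypothetical) that decodes a b_mark_shared_group marking back into explicit request-index groups, using the example values from the comment above:

import torch

def decode_shared_groups(b_mark_shared_group: torch.Tensor) -> list:
    # Each non-zero entry n at position i closes a group covering request indices [i - n + 1, i].
    groups = []
    for i, n in enumerate(b_mark_shared_group.tolist()):
        if n != 0:
            groups.append(list(range(i - n + 1, i + 1)))
    return groups

b_shared_seq_len = torch.tensor([10, 10, 10, 11, 11, 11, 11], dtype=torch.int32)
b_mark_shared_group = torch.tensor([0, 0, 3, 0, 0, 0, 4], dtype=torch.int32)
print(decode_shared_groups(b_mark_shared_group))  # [[0, 1, 2], [3, 4, 5, 6]]
# Within each group the shared radix-cache length is constant: 10 for the first group, 11 for the second.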

lightllm/server/router/model_infer/infer_batch.py

3 additions, 0 deletions

@@ -434,6 +434,9 @@ def remove_master_req(self):
         else:
             logger.warning(f"try to remove master req, but related_master_req is None, req id {self.req_id}")

+    def get_radix_cache_shared_len(self):
+        return 0 if self.shared_kv_node is None else self.shared_kv_node.node_prefix_total_len
+
     def get_output_len(self):
         return self.cur_output_len
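
The new helper's contract is small enough to pin down with a toy check. A minimal sketch, assuming a stand-in object with a node_prefix_total_len attribute in place of LightLLM's real radix-tree node:

from types import SimpleNamespace

class FakeReq:
    # Stand-in for InferReq, reproducing only the new helper's logic.
    def __init__(self, shared_kv_node=None):
        self.shared_kv_node = shared_kv_node

    def get_radix_cache_shared_len(self):
        return 0 if self.shared_kv_node is None else self.shared_kv_node.node_prefix_total_len

print(FakeReq().get_radix_cache_shared_len())  # 0, no shared radix-cache node
print(FakeReq(SimpleNamespace(node_prefix_total_len=10)).get_radix_cache_shared_len())  # 10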

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

37 additions, 0 deletions

@@ -4,6 +4,7 @@
 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock
 from lightllm.common.basemodel.batch_objs import ModelInput
+from lightllm.utils.envs_utils import get_env_start_args, get_diverse_max_batch_shared_group_size


 def prepare_prefill_inputs(
@@ -99,12 +100,16 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     b_mtp_index = []
     b_seq_len = []
     b_q_seq_len = []
+    b_shared_seq_len = []
+    max_batch_shared_group_size = get_diverse_max_batch_shared_group_size()
     for req in req_objs:
+        _radix_shared_len = req.get_radix_cache_shared_len()
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
         seq_len = req.get_cur_total_len()
         assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}"
         b_seq_len.append(seq_len)
+        b_shared_seq_len.append(_radix_shared_len)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
         b_mtp_index.append(0)
@@ -114,6 +119,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
             b_req_idx.append(req.req_idx)
             seq_len += 1
             b_seq_len.append(seq_len)
+            b_shared_seq_len.append(_radix_shared_len)
             total_token_num += seq_len
             max_len_in_batch = max(max_len_in_batch, seq_len)
             b_mtp_index.append(step + 1)
@@ -124,7 +130,36 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In

     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_shared_seq_len = torch.tensor(b_shared_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    if get_env_start_args().diverse_mode:
+        b_mark_shared_group = []
+        shared_nodes = [req.shared_kv_node for req in run_reqs]
+        _current_group = []
+        for node in shared_nodes:
+            if not _current_group:
+                _current_group.append(node)
+            elif node == _current_group[-1]:
+                _current_group.append(node)
+            else:
+                b_mark_shared_group.extend([0 for _ in range(len(_current_group))])
+                b_mark_shared_group[-1] = len(_current_group)
+                _current_group.clear()
+                _current_group.append(node)
+
+            if len(_current_group) == max_batch_shared_group_size:
+                b_mark_shared_group.extend([0 for _ in range(len(_current_group))])
+                b_mark_shared_group[-1] = len(_current_group)
+                _current_group.clear()
+        if _current_group:
+            b_mark_shared_group.extend([0 for _ in range(len(_current_group))])
+            b_mark_shared_group[-1] = len(_current_group)
+            _current_group.clear()
+
+        assert len(b_mark_shared_group) == len(run_reqs)
+        b_mark_shared_group = torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu")
+    else:
+        b_mark_shared_group = None

     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
@@ -144,6 +179,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
         b_req_idx=b_req_idx,
         b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
+        b_shared_seq_len=b_shared_seq_len,
+        b_mark_shared_group=b_mark_shared_group,
         is_prefill=False,
     )
     return model_input, run_reqs
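
The grouping pass above is easiest to follow on a toy input. Below is a standalone sketch of the same consecutive-grouping logic (plain strings stand in for shared_kv_node objects, and the cap is the default group size of 4); it reproduces the example documented in ModelInput:

def build_mark_shared_group(shared_nodes, max_group_size=4):
    # Consecutive requests that point at the same shared node form one group. A group is
    # closed, and its size written at its last position, when the node changes or when
    # the group reaches max_group_size; every other position in the group stays 0.
    marks, current = [], []
    for node in shared_nodes:
        if current and node != current[-1]:
            marks.extend([0] * len(current))
            marks[-1] = len(current)
            current = []
        current.append(node)
        if len(current) == max_group_size:
            marks.extend([0] * len(current))
            marks[-1] = len(current)
            current = []
    if current:
        marks.extend([0] * len(current))
        marks[-1] = len(current)
    return marks

# Three requests sharing prefix node "A", four sharing prefix node "B".
print(build_mark_shared_group(["A", "A", "A", "B", "B", "B", "B"]))
# -> [0, 0, 3, 0, 0, 0, 4]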

lightllm/utils/envs_utils.py

6 additions, 1 deletion

@@ -193,4 +193,9 @@ def enable_radix_tree_timer_merge() -> bool:

 @lru_cache(maxsize=None)
 def get_radix_tree_merge_update_delta() -> int:
-    return int(os.getenv("LIGHTLMM_RADIX_TREE_MERGE_DELTA", 6000))
+    return int(os.getenv("LIGHTLLM_RADIX_TREE_MERGE_DELTA", 6000))
+
+
+@lru_cache(maxsize=None)
+def get_diverse_max_batch_shared_group_size() -> int:
+    return int(os.getenv("LIGHTLLM_MAX_BATCH_SHARED_GROUP_SIZE", 4))
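
The cap read by get_diverse_max_batch_shared_group_size is an ordinary environment variable; a minimal usage sketch (the value 8 is only illustrative), noting that it must be set before the first call because the result is cached by lru_cache:

import os

os.environ["LIGHTLLM_MAX_BATCH_SHARED_GROUP_SIZE"] = "8"

from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size

print(get_diverse_max_batch_shared_group_size())  # 8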
