Skip to content

Commit 4847468

Browse files
authored
bug fix for max len prefill check error (#650)
1 parent 49b94a1 commit 4847468

File tree

4 files changed

+29
-4
lines changed

4 files changed

+29
-4
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,14 @@ def _check_mem_size(self):
166166
return
167167

168168
def _init_req_manager(self):
169-
self.req_manager = ReqManager(self.max_req_num, self.max_seq_length, self.mem_manager)
169+
create_max_seq_len = 0
170+
171+
if self.batch_max_tokens is not None:
172+
create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
173+
if self.max_seq_length is not None:
174+
create_max_seq_len = max(create_max_seq_len, self.max_seq_length)
175+
176+
self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, self.mem_manager)
170177
return
171178

172179
def _init_infer_layer(self):

lightllm/server/httpserver/manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL]
6868
self.id_gen = ReqIDGenerator()
6969
self.first_time_costs = MovingAverage()
70+
self.per_token_costs = MovingAverage()
7071
# 有的模型的vocab size 读取tokenizer和config.json中不一致
7172
self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size)
7273

@@ -340,6 +341,7 @@ async def _wait_to_token_package(
340341
pass
341342
total_cost_time_ms = (time.time() - start_time) * 1000
342343
mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter
344+
self.per_token_costs.add(mean_per_token_cost_time_ms)
343345
x_request_id = request.headers.get("X-Request-Id", "")
344346
x_session_id = request.headers.get("X-Session-Id", "")
345347
prompt_cache_len = metadata.pop("prompt_cache_len", 0)
@@ -441,6 +443,7 @@ async def timer_to_pd_master(self):
441443
await asyncio.sleep(3)
442444
if log_count % 5 == 0:
443445
logger.info(f"mean first cost: {self.first_time_costs.average()} ms")
446+
logger.info(f"mean per token cost: {self.per_token_costs.average()} ms")
444447

445448
except Exception as e:
446449
logger.error("connetion to pd_master has error")

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def __init__(
4141
self.id_to_event: Dict[int, asyncio.Event] = {}
4242
self.session = None
4343
self.first_time_costs = MovingAverage()
44-
self.create_session_costs = MovingAverage()
44+
self.prefill_create_session_costs = MovingAverage()
45+
self.decode_create_session_costs = MovingAverage()
46+
self.per_token_costs = MovingAverage()
4547
return
4648

4749
async def register_pd(self, pd_info_json):
@@ -181,7 +183,7 @@ async def fetch_stream(
181183
req = await self._to_req_info(prompt, sampling_params, multimodal_params)
182184
create_start_time = time.time()
183185
async with self.session.post(p_node.to_llm_url(), json=req) as response:
184-
self.create_session_costs.add((time.time() - create_start_time) * 1000)
186+
self.prefill_create_session_costs.add((time.time() - create_start_time) * 1000)
185187
if response.status == 200:
186188
async for line in response.content:
187189
line = line.decode("utf-8").strip()
@@ -217,7 +219,9 @@ async def fetch_stream(
217219
sampling_params.suggested_dp_index = event.upkv_status.dp_index
218220

219221
req = await self._to_req_info(prompt_ids, sampling_params, multimodal_params)
222+
create_start_time = time.time()
220223
async with self.session.post(d_node.to_llm_url(), json=req) as response:
224+
self.decode_create_session_costs.add((time.time() - create_start_time) * 1000)
221225
if response.status == 200:
222226
async for line in response.content:
223227
line = line.decode("utf-8").strip()
@@ -269,6 +273,7 @@ async def _wait_to_token_package(
269273

270274
total_cost_time_ms = (time.time() - start_time) * 1000
271275
mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter
276+
self.per_token_costs.add(mean_per_token_cost_time_ms)
272277
x_request_id = request.headers.get("X-Request-Id", "")
273278
x_session_id = request.headers.get("X-Session-Id", "")
274279
prompt_cache_len = metadata.pop("prompt_cache_len", 0)
@@ -312,5 +317,7 @@ async def handle_loop(self):
312317
# 可以做一个定时任务
313318
await asyncio.sleep(20)
314319
logger.info(f"mean first cost: {self.first_time_costs.average()} ms")
315-
logger.info(f"create_session_costs: {self.create_session_costs.average()} ms")
320+
logger.info(f"prefill mean create_session_costs: {self.prefill_create_session_costs.average()} ms")
321+
logger.info(f"decode mean create_session_costs: {self.decode_create_session_costs.average()} ms")
322+
logger.info(f"mean per token cost: {self.per_token_costs.average()} ms")
316323
return

lightllm/server/router/manager.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ def _update_init_status_to_batch(self, batch: Batch, req_to_req_status):
426426

427427
def _update_out_status_to_batch(self, batch: Batch, req_to_out_status):
428428
new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)] # 只有在 splitfuse 模式下有意义
429+
430+
start_time = 0
431+
# extral_info 字段如果推理后端输入时间标记, 则用来评估序列化所占用的时间, 主要用于调试时使用
429432
for req_id, (
430433
req_status,
431434
cur_kv_len,
@@ -434,6 +437,8 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status):
434437
finish_status_value,
435438
extral_info,
436439
) in req_to_out_status.items():
440+
if extral_info is not None:
441+
start_time = max(start_time, extral_info)
437442
req: Req = batch.id_to_reqs[req_id]
438443
req.req_status = req_status
439444
req.cur_kv_len = cur_kv_len
@@ -446,6 +451,9 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status):
446451
new_batch_decode_need_tokens[req_dp_index] += req.get_decode_need_tokens()
447452

448453
batch.batch_decode_need_tokens = new_batch_decode_need_tokens
454+
rpyc_cost_time = (time.time() - start_time) * 1000
455+
if 8 <= rpyc_cost_time <= 1000:
456+
logger.warning(f"rpyc use too much time {rpyc_cost_time} ms, batch_size {len(req_to_out_status)}")
449457
return
450458

451459
def _can_decode(self, batch: Batch):

0 commit comments

Comments (0)