Skip to content

Commit 62ccdab

Browse files
author
wangzaijun
committed
fix
1 parent 6770350 commit 62ccdab

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

lightllm/server/httpserver/manager.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from lightllm.utils.statics_utils import MovingAverage
3434
from lightllm.utils.config_utils import get_vocab_size
3535
from lightllm.utils.envs_utils import get_unique_server_name
36+
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
3637
from rpyc.utils.classic import obtain
3738

3839
logger = init_logger(__name__)
@@ -280,10 +281,17 @@ async def generate(
280281

281282
# 记录请求到达的相关信息
282283
await self._log_req_header(request_headers, group_request_id)
283-
# 监控
284-
284+
# encode
285285
prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
286286

287+
prompt_tokens = len(prompt_ids)
288+
# 监控
289+
if group_request_id > 0:
290+
self.metric_client.counter_inc("lightllm_request_count")
291+
self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens)
292+
self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)
293+
prompt_ids = await self._check_and_repair_length(prompt_ids, sampling_params)
294+
287295
if nixl_pd_upload_websocket is not None and not is_health_req and self.pd_mode.is_NP():
288296
# 在 nixl pd 模式下的 p 节点, 为了更好的兼容多模态的推理流程,np 节点需要先上报其 encode 好的 prompt ids 信息,然后
289297
# 再等待 pd_master 传输下来的对应的进行 decode 节点的decode信息,然后再执行后续的流程
@@ -302,13 +310,10 @@ async def generate(
302310
decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info
303311
sampling_params.nixl_params.set(pickle.dumps(decode_node_info))
304312

305-
prompt_tokens = len(prompt_ids)
306-
# 监控
307-
if group_request_id > 0:
308-
self.metric_client.counter_inc("lightllm_request_count")
309-
self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens)
310-
self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)
311-
prompt_ids = await self._check_and_repair_length(prompt_ids, sampling_params)
313+
if decode_node_info.ready_kv_len == len(prompt_ids) - 1:
314+
# 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill
315+
# 直接 raise NixlPrefillNodeStopGenToken
316+
raise NixlPrefillNodeStopGenToken(group_request_id=group_request_id)
312317

313318
# 申请资源并存储
314319
alloced_req_indexes = []

lightllm/server/httpserver/pd_loop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ..pd_io_struct import PD_Master_Obj
1818
from lightllm.server.core.objs import StartArgs
1919
from lightllm.server.core.objs import SamplingParams
20+
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
2021

2122
logger = init_logger(__name__)
2223

@@ -207,7 +208,8 @@ async def _pd_process_generate(
207208
is_health_check_req = sub_req_id < 0
208209
if not is_health_check_req:
209210
await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status))
210-
211+
except NixlPrefillNodeStopGenToken as e:
212+
logger.info(f"nixl prefill node stop gen token for group_request_id {e.group_request_id}")
211213
except BaseException as e:
212214
logger.error(str(e))
213215

lightllm/utils/error_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@ def __str__(self):
2323
return f"{self.message} (Status code: {self.status_code})"
2424

2525

26+
class NixlPrefillNodeStopGenToken(Exception):
    """Exception signalling that a nixl prefill node stops generating tokens.

    The instance records which request group the signal belongs to, so a
    handler can report it (for example ``e.group_request_id`` in a log line).

    Attributes:
        message: Human-readable description; also passed to ``Exception`` and
            therefore available as ``args[0]``.
        group_request_id: Identifier of the request group, stored exactly as
            given by the raiser.
    """

    def __init__(self, group_request_id, message: str = "Nixl prefill node stop gen token") -> None:
        """Initialize the NixlPrefillNodeStopGenToken.

        Args:
            group_request_id: Identifier of the request group this signal
                refers to. Required; stored unchanged on the instance.
            message: Message to display. Defaults to
                ``"Nixl prefill node stop gen token"``.
        """
        # Only the message is forwarded to Exception, so ``args`` stays
        # ``(message,)``; the request id lives on the instance attribute.
        super().__init__(message)
        self.message = message
        self.group_request_id = group_request_id

    def __str__(self) -> str:
        """Return the request id followed by the message."""
        return f"group_request_id: {self.group_request_id}, {self.message}"
41+
42+
2643
def log_exception(func):
2744
def wrapper(*args, **kwargs):
2845
try:

0 commit comments

Comments
 (0)