Skip to content

Commit 9423c57

Browse files
authored
[stop_seq] fix out-bound value for stop sequence (#3216)
* fix out-bound value for stop sequence * catch error if there are out-of-bounds value * check in offline mode * add ut tests
1 parent 5885285 commit 9423c57

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

fastdeploy/engine/engine.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,26 @@ def add_requests(self, task, sampling_params=None, **kwargs):
530530
llm_logger.error(error_msg)
531531
raise EngineError(error_msg, error_code=400)
532532

533+
if request.get("stop_seqs_len") is not None:
534+
stop_seqs_len = request.get("stop_seqs_len")
535+
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
536+
if len(stop_seqs_len) > max_stop_seqs_num:
537+
error_msg = (
538+
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
539+
"Please reduce the number of stop or set a lager max_stop_seqs_num by `FD_MAX_STOP_SEQS_NUM`"
540+
)
541+
llm_logger.error(error_msg)
542+
raise EngineError(error_msg, error_code=400)
543+
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
544+
for single_stop_seq_len in stop_seqs_len:
545+
if single_stop_seq_len > stop_seqs_max_len:
546+
error_msg = (
547+
f"Length of stop_seqs({single_stop_seq_len}) exceeds the limit stop_seqs_max_len({stop_seqs_max_len})."
548+
"Please reduce the length of stop sequences or set a larger stop_seqs_max_len by `FD_STOP_SEQS_MAX_LEN`"
549+
)
550+
llm_logger.error(error_msg)
551+
raise EngineError(error_msg, error_code=400)
552+
533553
if self.guided_decoding_checker is not None:
534554
request, err_msg = self.guided_decoding_checker.schema_format(request)
535555
if err_msg is not None:

fastdeploy/entrypoints/engine_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import numpy as np
2121

22+
from fastdeploy import envs
2223
from fastdeploy.engine.config import ModelConfig
2324
from fastdeploy.input.preprocess import InputPreprocessor
2425
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
@@ -154,6 +155,26 @@ def add_requests(self, task):
154155
api_server_logger.error(error_msg)
155156
raise EngineError(error_msg, error_code=400)
156157

158+
if "stop_seqs_len" in task:
159+
stop_seqs_len = task["stop_seqs_len"]
160+
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
161+
if len(stop_seqs_len) > max_stop_seqs_num:
162+
error_msg = (
163+
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
164+
"Please reduce the number of stop or set a lager max_stop_seqs_num by `FD_MAX_STOP_SEQS_NUM`"
165+
)
166+
api_server_logger.error(error_msg)
167+
raise EngineError(error_msg, error_code=400)
168+
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
169+
for single_stop_seq_len in stop_seqs_len:
170+
if single_stop_seq_len > stop_seqs_max_len:
171+
error_msg = (
172+
f"Length of stop_seqs({single_stop_seq_len}) exceeds the limit stop_seqs_max_len({stop_seqs_max_len})."
173+
"Please reduce the length of stop sequences or set a larger stop_seqs_max_len by `FD_STOP_SEQS_MAX_LEN`"
174+
)
175+
api_server_logger.error(error_msg)
176+
raise EngineError(error_msg, error_code=400)
177+
157178
task["preprocess_end_time"] = time.time()
158179
preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
159180
api_server_logger.info(

test/ce/server/test_evil_cases.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,30 @@ def test_mixed_valid_invalid_fields():
9595
resp = send_request(URL, payload).json()
9696
assert "error" not in resp, "非法字段不应导致请求失败"
9797

98+
99+
def test_stop_seq_exceed_num():
100+
"""stop 字段包含超过 FD_MAX_STOP_SEQS_NUM 个元素,服务应报错"""
101+
data = {
102+
"stream": False,
103+
"messages": [{"role": "user", "content": "非洲的首都是?"}],
104+
"top_p": 0,
105+
"stop": ["11", "22", "33", "44", "55", "66", "77"],
106+
}
107+
payload = build_request_payload(TEMPLATE, data)
108+
resp = send_request(URL, payload).json()
109+
assert resp.get("object") == "error", "stop 超出个数应触发异常"
110+
assert "exceeds the limit max_stop_seqs_num" in resp.get("message", ""), "未返回预期的报错信息"
111+
112+
113+
def test_stop_seq_exceed_length():
114+
"""stop 中包含长度超过 FD_STOP_SEQS_MAX_LEN 的元素,服务应报错"""
115+
data = {
116+
"stream": False,
117+
"messages": [{"role": "user", "content": "非洲的首都是?"}],
118+
"top_p": 0,
119+
"stop": ["11", "今天天气比明天好多了,请问你会出门还是和我一起玩"],
120+
}
121+
payload = build_request_payload(TEMPLATE, data)
122+
resp = send_request(URL, payload).json()
123+
assert resp.get("object") == "error", "stop 超出长度应触发异常"
124+
assert "exceeds the limit stop_seqs_max_len" in resp.get("message", ""), "未返回预期的报错信息"

0 commit comments

Comments
 (0)