Skip to content

Commit 88596c0

Browse files
authored
Add more base chat cases (#3203)
* add test base class * fix codestyle * fix codestyle * add base chat
1 parent fe540f6 commit 88596c0

File tree

3 files changed

+60
-52
lines changed

3 files changed

+60
-52
lines changed

test/ce/server/core/__init__.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@
99

1010
base_logger = Logger(loggername="FDSentry", save_level="channel", log_path="./fd_logs").get_logger()
1111
base_logger.setLevel("INFO")
12+
1213
from .request_template import TEMPLATES
13-
from .utils import build_request_payload, get_stream_chunks, send_request
14+
from .utils import (
15+
build_request_payload,
16+
get_stream_chunks,
17+
get_token_list,
18+
send_request,
19+
)
1420

15-
__all__ = ["build_request_payload", "send_request", "TEMPLATES", "get_stream_chunks"]
21+
__all__ = ["build_request_payload", "send_request", "TEMPLATES", "get_stream_chunks", "get_token_list"]
1622

1723
# 检查环境变量是否存在
1824
URL = os.environ.get("URL")
@@ -23,14 +29,15 @@
2329
missing_vars.append("URL")
2430
if not TEMPLATE:
2531
missing_vars.append("TEMPLATE")
26-
if missing_vars:
27-
if not URL:
28-
msg = (
29-
f"❌ 缺少环境变量:{', '.join(missing_vars)},请先设置,例如:\n"
30-
f" export URL=http://localhost:8000/v1/chat/completions\n"
31-
f" export TEMPLATE=TOKEN_LOGPROB"
32-
)
33-
base_logger.error(msg)
34-
sys.exit(33) # 终止程序
35-
if not TEMPLATE:
36-
base_logger.warning("未启用请求模板,请在Case中自行设置请求模板")
32+
33+
if not URL:
34+
msg = (
35+
f"❌ 缺少环境变量:{', '.join(missing_vars)},请先设置,例如:\n"
36+
f" export URL=http://localhost:8000/v1/chat/completions\n"
37+
f" export TEMPLATE=TOKEN_LOGPROB"
38+
)
39+
base_logger.error(msg)
40+
sys.exit(33) # 终止程序
41+
42+
if not TEMPLATE:
43+
base_logger.warning("⚠️ 未设置 TEMPLATE,请确保在用例中显式传入请求模板。")

test/ce/server/core/utils.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,6 @@ def send_request(url, payload, timeout=600, stream=False):
3737
3838
Returns:
3939
response: 请求的响应结果,如果请求失败则返回None。
40-
41-
Raises:
42-
None
43-
4440
"""
4541
headers = {
4642
"Content-Type": "application/json",
@@ -60,7 +56,7 @@ def send_request(url, payload, timeout=600, stream=False):
6056

6157

6258
def get_stream_chunks(response):
63-
"""解析流式返回,生成chunk List[dict]"""
59+
"""解析流式返回,生成 chunk List[dict]"""
6460
chunks = []
6561

6662
if response.status_code == 200:
@@ -82,3 +78,22 @@ def get_stream_chunks(response):
8278
base_logger.error("返回内容:", response.text)
8379

8480
return chunks
81+
82+
83+
def get_token_list(response):
84+
"""解析 response 中的 token 文本列表"""
85+
token_list = []
86+
87+
try:
88+
content_logprobs = response["choices"][0]["logprobs"]["content"]
89+
except (KeyError, IndexError, TypeError) as e:
90+
base_logger.error(f"解析失败:{e}")
91+
return []
92+
93+
for token_info in content_logprobs:
94+
token = token_info.get("token")
95+
if token is not None:
96+
token_list.append(token)
97+
98+
base_logger.info(f"Token List:{token_list}")
99+
return token_list

test/ce/server/test_base_chat.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import json
1111

12-
from core import TEMPLATE, URL, build_request_payload, send_request
12+
from core import TEMPLATE, URL, build_request_payload, get_token_list, send_request
1313

1414

1515
def test_stream_response():
@@ -76,18 +76,24 @@ def test_logprobs_enabled():
7676
def test_stop_sequence():
7777
data = {
7878
"stream": False,
79-
"stop": ["果冻"],
79+
"stop": [""],
8080
"messages": [
81-
{"role": "user", "content": "你要严格按照我接下来的话输出,输出冒号后面的内容,请输出:这是第一段。果冻这是第二段啦啦啦啦啦。"},
81+
{
82+
"role": "user",
83+
"content": "你要严格按照我接下来的话输出,输出冒号后面的内容,请输出:这是第一段。这是第二段啦啦啦啦啦。",
84+
},
8285
],
8386
"max_tokens": 20,
8487
"top_p": 0,
8588
}
8689
payload = build_request_payload(TEMPLATE, data)
8790
resp = send_request(URL, payload).json()
8891
content = resp["choices"][0]["message"]["content"]
92+
token_list = get_token_list(resp)
8993
print("截断输出:", content)
9094
assert "第二段" not in content
95+
assert "第二段" not in token_list
96+
assert "。" in token_list, "没有找到。符号"
9197

9298

9399
def test_sampling_parameters():
@@ -125,7 +131,7 @@ def test_multi_turn_conversation():
125131

126132

127133
def test_bad_words_filtering():
128-
banned_tokens = ["和", "呀"]
134+
banned_tokens = [""]
129135

130136
data = {
131137
"stream": False,
@@ -140,36 +146,14 @@ def test_bad_words_filtering():
140146

141147
payload = build_request_payload(TEMPLATE, data)
142148
response = send_request(URL, payload).json()
143-
144149
content = response["choices"][0]["message"]["content"]
145150
print("生成内容:", content)
151+
token_list = get_token_list(response)
146152

147153
for word in banned_tokens:
148-
assert word not in content, f"bad_word '{word}' 不应出现在生成结果中"
154+
assert word not in token_list, f"bad_word '{word}' 不应出现在生成结果中"
149155

150-
print("test_bad_words_filtering 通过:生成结果未包含被禁词")
151-
152-
data = {
153-
"stream": False,
154-
"messages": [
155-
{"role": "system", "content": "你是一个助手,回答简洁清楚"},
156-
{"role": "user", "content": "请输出冒号后面的字,一模一样: 我爱吃果冻,苹果,香蕉,和荔枝呀呀呀"},
157-
],
158-
"top_p": 0,
159-
"max_tokens": 69,
160-
# "bad_words": banned_tokens,
161-
}
162-
163-
payload = build_request_payload(TEMPLATE, data)
164-
response = send_request(URL, payload).json()
165-
166-
content = response["choices"][0]["message"]["content"]
167-
print("生成内容:", content)
168-
169-
for word in banned_tokens:
170-
assert word not in content, f"bad_word '{word}' 不应出现在生成结果中"
171-
172-
print("test_bad_words_filtering 通过:生成结果未包含被禁词")
156+
print("test_bad_words_filtering 正例验证通过")
173157

174158

175159
def test_bad_words_filtering1():
@@ -195,8 +179,10 @@ def test_bad_words_filtering1():
195179
for word in banned_tokens:
196180
assert word not in content, f"bad_word '{word}' 不应出现在生成结果中"
197181

198-
print("test_bad_words_filtering 通过:生成结果未包含被禁词")
199-
word = "呀呀"
182+
print("test_bad_words_filtering1 通过:生成结果未包含被禁词")
183+
184+
# 正例验证
185+
word = "呀"
200186
data = {
201187
"stream": False,
202188
"messages": [
@@ -212,7 +198,7 @@ def test_bad_words_filtering1():
212198

213199
content = response["choices"][0]["message"]["content"]
214200
print("生成内容:", content)
201+
token_list = get_token_list(response)
202+
assert word in token_list, f"'{word}' 应出现在生成结果中"
215203

216-
assert word in content, f" '{word}' 应出现在生成结果中"
217-
218-
print("test_bad_words_filtering 通过:生成结果未包含被禁词")
204+
print("test_bad_words_filtering1 正例验证通过")

0 commit comments

Comments
 (0)