Skip to content

Commit 9f1936a

Browse files
authored
Ce add repitation early stop cases (#3213)
* add repitation early stop cases * add repitation early stop cases
1 parent 1e9a8e8 commit 9f1936a

File tree

3 files changed

+106
-3
lines changed

3 files changed

+106
-3
lines changed

test/ce/server/core/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,22 @@
1313
from .request_template import TEMPLATES
1414
from .utils import (
1515
build_request_payload,
16+
get_logprobs_list,
17+
get_probs_list,
1618
get_stream_chunks,
1719
get_token_list,
1820
send_request,
1921
)
2022

21-
__all__ = ["build_request_payload", "send_request", "TEMPLATES", "get_stream_chunks", "get_token_list"]
23+
__all__ = [
24+
"build_request_payload",
25+
"send_request",
26+
"TEMPLATES",
27+
"get_stream_chunks",
28+
"get_token_list",
29+
"get_logprobs_list",
30+
"get_probs_list",
31+
]
2232

2333
# 检查环境变量是否存在
2434
URL = os.environ.get("URL")

test/ce/server/core/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
55

66
import json
7+
import math
78

89
import requests
910
from core import TEMPLATES, base_logger
@@ -97,3 +98,41 @@ def get_token_list(response):
9798

9899
base_logger.info(f"Token List:{token_list}")
99100
return token_list
101+
102+
103+
def get_logprobs_list(response):
104+
"""解析 response 中的 token 文本列表"""
105+
logprobs_list = []
106+
107+
try:
108+
content_logprobs = response["choices"][0]["logprobs"]["content"]
109+
except (KeyError, IndexError, TypeError) as e:
110+
base_logger.error(f"解析失败:{e}")
111+
return []
112+
113+
for token_info in content_logprobs:
114+
token = token_info.get("logprob")
115+
if token is not None:
116+
logprobs_list.append(token)
117+
118+
base_logger.info(f"Logprobs List:{logprobs_list}")
119+
return logprobs_list
120+
121+
122+
def get_probs_list(response):
123+
"""解析 response 中的 token 文本列表"""
124+
probs_list = []
125+
126+
try:
127+
content_logprobs = response["choices"][0]["logprobs"]["content"]
128+
except (KeyError, IndexError, TypeError) as e:
129+
base_logger.error(f"解析失败:{e}")
130+
return []
131+
132+
for token_info in content_logprobs:
133+
token = token_info.get("logprob")
134+
if token is not None:
135+
probs_list.append(math.exp(token))
136+
137+
base_logger.info(f"probs List:{probs_list}")
138+
return probs_list

test/ce/server/test_base_chat.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99

1010
import json
1111

12-
from core import TEMPLATE, URL, build_request_payload, get_token_list, send_request
12+
from core import (
13+
TEMPLATE,
14+
URL,
15+
build_request_payload,
16+
get_probs_list,
17+
get_token_list,
18+
send_request,
19+
)
1320

1421

1522
def test_stream_response():
@@ -150,7 +157,7 @@ def test_multi_turn_conversation():
150157

151158

152159
def test_bad_words_filtering():
153-
banned_tokens = [""]
160+
banned_tokens = ["香蕉"]
154161

155162
data = {
156163
"stream": False,
@@ -221,3 +228,50 @@ def test_bad_words_filtering1():
221228
assert word in token_list, f"'{word}' 应出现在生成结果中"
222229

223230
print("test_bad_words_filtering1 正例验证通过")
231+
232+
233+
def test_repetition_early_stop():
234+
"""
235+
用于验证 repetition early stop 功能是否生效:
236+
设置 window_size=6,threshold=0.93,输入内容设计成易重复,观察模型是否提前截断输出。
237+
threshold = 0.93
238+
window_size = 6 这个必须是启动模型的时候加上这个参数 负责不能用!!!!
239+
"""
240+
241+
data = {
242+
"stream": False,
243+
"messages": [
244+
{"role": "user", "content": "输出'我爱吃果冻' 10次"},
245+
],
246+
"max_tokens": 10000,
247+
"temperature": 0.8,
248+
"top_p": 0,
249+
}
250+
251+
payload = build_request_payload(TEMPLATE, data)
252+
response = send_request(URL, payload).json()
253+
content = response["choices"][0]["message"]["content"]
254+
255+
print("🧪 repetition early stop 输出内容:\n", content)
256+
probs_list = get_probs_list(response)
257+
258+
threshold = 0.93
259+
window_size = 6
260+
261+
assert len(probs_list) >= window_size, "列表长度不足 window_size"
262+
263+
# 条件 1:末尾 6 个都 > threshold
264+
tail = probs_list[-window_size:]
265+
assert all(v > threshold for v in tail), "末尾 window_size 个数不全大于阈值"
266+
267+
# 条件 2:前面不能有连续 >=6 个值 > threshold
268+
head = probs_list[:-window_size]
269+
count = 0
270+
for v in head:
271+
if v > threshold:
272+
count += 1
273+
assert count < window_size, f"在末尾之前出现了连续 {count} 个大于阈值的数"
274+
else:
275+
count = 0
276+
277+
print("repetition early stop 功能验证通过")

0 commit comments

Comments
 (0)