Skip to content

Commit eaae4a5

Browse files
authored
Split cases (#3297)
* add repetition early stop cases * add repetition early stop cases * split repetition_early_stop from the base test
1 parent c011cb8 commit eaae4a5

File tree

2 files changed

+55
-55
lines changed

2 files changed

+55
-55
lines changed

test/ce/server/test_base_chat.py

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,7 @@
99

1010
import json
1111

12-
from core import (
13-
TEMPLATE,
14-
URL,
15-
build_request_payload,
16-
get_probs_list,
17-
get_token_list,
18-
send_request,
19-
)
12+
from core import TEMPLATE, URL, build_request_payload, get_token_list, send_request
2013

2114

2215
def test_stream_response():
@@ -278,50 +271,3 @@ def test_bad_words_filtering1():
278271
assert word in token_list, f"'{word}' 应出现在生成结果中"
279272

280273
print("test_bad_words_filtering1 正例验证通过")
281-
282-
283-
def test_repetition_early_stop():
284-
"""
285-
用于验证 repetition early stop 功能是否生效:
286-
设置 window_size=6,threshold=0.93,输入内容设计成易重复,观察模型是否提前截断输出。
287-
threshold = 0.93
288-
window_size = 6 这个必须是启动模型的时候加上这个参数 否则不能用!!!!
289-
"""
290-
291-
data = {
292-
"stream": False,
293-
"messages": [
294-
{"role": "user", "content": "输出'我爱吃果冻' 10次"},
295-
],
296-
"max_tokens": 10000,
297-
"temperature": 0.8,
298-
"top_p": 0,
299-
}
300-
301-
payload = build_request_payload(TEMPLATE, data)
302-
response = send_request(URL, payload).json()
303-
content = response["choices"][0]["message"]["content"]
304-
305-
print("🧪 repetition early stop 输出内容:\n", content)
306-
probs_list = get_probs_list(response)
307-
308-
threshold = 0.93
309-
window_size = 6
310-
311-
assert len(probs_list) >= window_size, "列表长度不足 window_size"
312-
313-
# 条件 1:末尾 6 个都 > threshold
314-
tail = probs_list[-window_size:]
315-
assert all(v > threshold for v in tail), "末尾 window_size 个数不全大于阈值"
316-
317-
# 条件 2:前面不能有连续 >=6 个值 > threshold
318-
head = probs_list[:-window_size]
319-
count = 0
320-
for v in head:
321-
if v > threshold:
322-
count += 1
323-
assert count < window_size, f"在末尾之前出现了连续 {count} 个大于阈值的数"
324-
else:
325-
count = 0
326-
327-
print("repetition early stop 功能验证通过")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
# @author DDDivano
4+
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
5+
6+
7+
from core import TEMPLATE, URL, build_request_payload, get_probs_list, send_request
8+
9+
10+
def test_repetition_early_stop():
11+
"""
12+
用于验证 repetition early stop 功能是否生效:
13+
设置 window_size=6,threshold=0.93,输入内容设计成易重复,观察模型是否提前截断输出。
14+
threshold = 0.93
15+
window_size = 6 这个必须是启动模型的时候加上这个参数 否则不能用!!!!
16+
"""
17+
18+
data = {
19+
"stream": False,
20+
"messages": [
21+
{"role": "user", "content": "输出'我爱吃果冻' 10次"},
22+
],
23+
"max_tokens": 10000,
24+
"temperature": 0.8,
25+
"top_p": 0,
26+
}
27+
28+
payload = build_request_payload(TEMPLATE, data)
29+
response = send_request(URL, payload).json()
30+
content = response["choices"][0]["message"]["content"]
31+
32+
print("🧪 repetition early stop 输出内容:\n", content)
33+
probs_list = get_probs_list(response)
34+
35+
threshold = 0.93
36+
window_size = 6
37+
38+
assert len(probs_list) >= window_size, "列表长度不足 window_size"
39+
40+
# 条件 1:末尾 6 个都 > threshold
41+
tail = probs_list[-window_size:]
42+
assert all(v > threshold for v in tail), "末尾 window_size 个数不全大于阈值"
43+
44+
# 条件 2:前面不能有连续 >=6 个值 > threshold
45+
head = probs_list[:-window_size]
46+
count = 0
47+
for v in head:
48+
if v > threshold:
49+
count += 1
50+
assert count < window_size, f"在末尾之前出现了连续 {count} 个大于阈值的数"
51+
else:
52+
count = 0
53+
54+
print("repetition early stop 功能验证通过")

0 commit comments

Comments
 (0)