Skip to content

Commit 3dd8492

Browse files
sunlei1024 and DDDivano authored
[Bugfix] Fix uninitialized decoded_token and add corresponding unit test (#3201)
* Update test_base_chat.py (#3183) * [Bugfix] Fix uninitialized decoded_token and add corresponding unit test. --------- Co-authored-by: Divano <[email protected]>
1 parent bd77a3a commit 3dd8492

File tree

3 files changed

+305
-1
lines changed

3 files changed

+305
-1
lines changed

fastdeploy/entrypoints/llm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ def _add_request(
285285
self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
286286
return req_ids
287287

288+
def _decode_token(self, token_id: int) -> str:
289+
"""Decodes a single token ID into its string representation."""
290+
return self.llm_engine.data_processor.process_logprob_response([token_id], clean_up_tokenization_spaces=False)
291+
288292
def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
289293
"""
290294
Constructs a list of dictionaries mapping token IDs to Logprob objects,
@@ -318,8 +322,9 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: i
318322
sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
319323
result = []
320324
for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
325+
321326
logprob_dict = {
322-
token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None)
327+
token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=self._decode_token(token_id))
323328
for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
324329
}
325330
result.append(logprob_dict)

test/ce/server/test_base_chat.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
#!/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
# @author DDDivano
4+
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
5+
6+
"""
7+
some basic check for fd web api
8+
"""
9+
10+
import json
11+
12+
from core import TEMPLATE, URL, build_request_payload, send_request
13+
14+
15+
def test_stream_response():
    """Stream a chat completion and accumulate the delta contents."""
    data = {
        "stream": True,
        "messages": [
            {"role": "system", "content": "你是一个知识渊博的 AI 助手"},
            {"role": "user", "content": "讲讲爱因斯坦的相对论"},
        ],
        "max_tokens": 10,
    }
    resp = send_request(URL, build_request_payload(TEMPLATE, data), stream=True)

    pieces = []
    for raw in resp.iter_lines(decode_unicode=True):
        # SSE frames: skip keep-alives and anything that is not a data line.
        if raw.strip() == "" or not raw.startswith("data: "):
            continue
        payload_text = raw[len("data: "):]
        if payload_text.strip() == "[DONE]":
            break
        chunk = json.loads(payload_text)
        delta = chunk.get("choices", [{}])[0].get("delta", {})
        pieces.append(delta.get("content", ""))

    output = "".join(pieces)
    print("Stream输出:", output)
    # NOTE(review): the `or len(output) > 0` clause makes any non-empty
    # output pass; kept as-is to preserve the original (lenient) check.
    assert "相对论" in output or len(output) > 0
40+
41+
42+
def test_system_prompt_effect():
    """A system prompt asking for a one-sentence reply should keep it short."""
    data = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "请用一句话回答"},
            {"role": "user", "content": "什么是人工智能?"},
        ],
        "max_tokens": 30,
    }
    body = send_request(URL, build_request_payload(TEMPLATE, data)).json()
    content = body["choices"][0]["message"]["content"]
    print("内容输出:", content)
    # Short reply expected because of the system instruction and max_tokens.
    assert len(content) < 50
56+
57+
58+
def test_logprobs_enabled():
    """Requesting logprobs yields a per-token list whose items carry 'token'."""
    data = {
        "stream": False,
        "logprobs": True,
        "top_logprobs": 5,
        "messages": [{"role": "user", "content": "非洲的首都是?"}],
        "max_tokens": 3,
    }
    body = send_request(URL, build_request_payload(TEMPLATE, data)).json()
    logprob_data = body["choices"][0].get("logprobs")
    print("LogProbs:", logprob_data)
    assert logprob_data is not None
    entries = logprob_data.get("content", [])
    assert isinstance(entries, list)
    assert all("token" in entry for entry in entries)
74+
75+
76+
def test_stop_sequence():
    """A stop string must truncate generation before the second segment."""
    data = {
        "stream": False,
        "stop": ["果冻"],
        "messages": [
            {
                "role": "user",
                "content": "你要严格按照我接下来的话输出,输出冒号后面的内容,请输出:这是第一段。果冻这是第二段啦啦啦啦啦。",
            },
        ],
        "max_tokens": 20,
        "top_p": 0,
    }
    body = send_request(URL, build_request_payload(TEMPLATE, data)).json()
    content = body["choices"][0]["message"]["content"]
    print("截断输出:", content)
    # Everything after the stop string "果冻" must be cut off.
    assert "第二段" not in content
94+
95+
96+
def test_sampling_parameters():
    """Greedy settings (temperature=0, top_p=0) should answer 1+1 correctly."""
    data = {
        "stream": False,
        "temperature": 0,
        "top_p": 0,
        "messages": [
            {"role": "user", "content": "1+1=?,直接回答答案"},
        ],
        "max_tokens": 50,
    }
    body = send_request(URL, build_request_payload(TEMPLATE, data)).json()
    answer = body["choices"][0]["message"]["content"]
    print("Sampling输出:", answer)
    # Accept either the Arabic numeral or the Chinese numeral for "two".
    assert any(expected in answer for expected in ["2", "二"])
111+
112+
113+
def test_multi_turn_conversation():
    """The model should resolve "他" from the earlier turns about Newton."""
    data = {
        "stream": False,
        "messages": [
            {"role": "user", "content": "牛顿是谁?"},
            {"role": "assistant", "content": "牛顿是一位物理学家。"},
            {"role": "user", "content": "他提出了什么理论?"},
        ],
        "max_tokens": 30,
    }
    body = send_request(URL, build_request_payload(TEMPLATE, data)).json()
    content = body["choices"][0]["message"]["content"]
    print("多轮记忆:", content)
    assert "三大运动定律" in content or "万有引力" in content
128+
129+
130+
def test_bad_words_filtering():
    """bad_words must suppress the listed tokens; a control request without
    the filter, whose prompt asks the model to echo text containing those
    tokens, should reproduce them.
    """
    banned_tokens = ["和", "呀"]

    # Request 1: filter enabled — banned tokens must not appear.
    data = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "你是一个助手,回答简洁清楚"},
            {"role": "user", "content": "请输出冒号后面的字: 我爱吃果冻,和苹果,香蕉,和荔枝"},
        ],
        "top_p": 0,
        "max_tokens": 69,
        "bad_words": banned_tokens,
    }
    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()
    content = response["choices"][0]["message"]["content"]
    print("生成内容:", content)
    for word in banned_tokens:
        assert word not in content, f"bad_word '{word}' 不应出现在生成结果中"
    print("test_bad_words_filtering 通过:生成结果未包含被禁词")

    # Request 2 (control): filter disabled, prompt asks for a verbatim echo
    # of text that contains the banned tokens.
    data = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "你是一个助手,回答简洁清楚"},
            {"role": "user", "content": "请输出冒号后面的字,一模一样: 我爱吃果冻,苹果,香蕉,和荔枝呀呀呀"},
        ],
        "top_p": 0,
        "max_tokens": 69,
    }
    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()
    content = response["choices"][0]["message"]["content"]
    print("生成内容:", content)
    # BUGFIX: the original asserted the banned tokens were ABSENT here, but
    # this request sets no bad_words and explicitly asks the model to echo
    # text containing them — the check was inverted (compare
    # test_bad_words_filtering1, which asserts presence for this case).
    assert any(word in content for word in banned_tokens), "未启用 bad_words 时,生成结果应包含这些词"
    print("test_bad_words_filtering 通过:未启用过滤时生成结果包含这些词")
176+
177+
178+
def test_bad_words_filtering1():
    """Same filter check as test_bad_words_filtering, plus a verbatim-echo
    request (no bad_words) that must contain the repeated token "呀呀".
    """
    banned_tokens = ["和", "呀"]

    # Request 1: filter enabled — banned tokens must not appear.
    data = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "你是一个助手,回答简洁清楚"},
            {"role": "user", "content": "请输出冒号后面的字: 我爱吃果冻,和苹果,香蕉,和荔枝"},
        ],
        "top_p": 0,
        "max_tokens": 69,
        "bad_words": banned_tokens,
    }
    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()
    content = response["choices"][0]["message"]["content"]
    print("生成内容:", content)
    for word in banned_tokens:
        assert word not in content, f"bad_word '{word}' 不应出现在生成结果中"
    # BUGFIX (log only): the original print named the wrong test.
    print("test_bad_words_filtering1 通过:生成结果未包含被禁词")

    # Request 2: no filter; the echoed text must contain "呀呀".
    word = "呀呀"
    data = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "你是一个助手,回答简洁清楚"},
            {"role": "user", "content": "请输出冒号后面的字,一模一样: 我爱吃果冻,苹果,香蕉,和荔枝呀呀呀"},
        ],
        "top_p": 0,
        "max_tokens": 69,
    }
    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()
    content = response["choices"][0]["message"]["content"]
    print("生成内容:", content)
    assert word in content, f" '{word}' 应出现在生成结果中"
    # BUGFIX (log only): the original print claimed the banned words were
    # absent, contradicting the assertion just above (which checks presence).
    print("test_bad_words_filtering1 通过:生成结果包含预期词")
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from fastdeploy.entrypoints.llm import LLM
5+
from fastdeploy.worker.output import Logprob, LogprobsLists
6+
7+
8+
def get_patch_path(cls, method="__init__"):
    """Build the dotted target string for mock.patch pointing at *method* of *cls*."""
    return ".".join((cls.__module__, cls.__qualname__, method))
10+
11+
12+
class TestBuildSampleLogprobs(unittest.TestCase):
    """Unit tests for LLM._build_sample_logprobs and LLM._decode_token."""

    def setUp(self):
        """Create an LLM with __init__ bypassed and the engine mocked."""
        with patch(get_patch_path(LLM), return_value=None):
            self.llm = LLM()
        # Mock the data processor so token decoding is deterministic.
        engine = MagicMock()
        engine.data_processor.process_logprob_response.side_effect = (
            lambda ids, **kwargs: f"token_{ids[0]}"
        )
        self.llm.llm_engine = engine

    def test_build_sample_logprobs_basic(self):
        """A valid topk_logprobs yields ranked, decoded Logprob entries."""
        logprobs_lists = LogprobsLists(
            logprob_token_ids=[[100, 101, 102]],
            logprobs=[[-0.1, -0.5, -1.0]],
            sampled_token_ranks=[0],
        )

        result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)

        expected = [
            {
                101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
                102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
            }
        ]
        self.assertEqual(result, expected)

    def test_build_sample_logprobs_empty_input(self):
        """An empty `logprob_token_ids` produces None."""
        logprobs_lists = MagicMock(spec=LogprobsLists)
        logprobs_lists.logprob_token_ids = []
        self.assertIsNone(self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2))

    def test_build_sample_logprobs_invalid_topk(self):
        """A topk exceeding the first row's length produces None."""
        logprobs_lists = MagicMock(spec=LogprobsLists)
        logprobs_lists.logprob_token_ids = [[100]]
        self.assertIsNone(self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2))

    def test_decode_token(self):
        """A single token ID decodes via the mocked data processor."""
        self.assertEqual(self.llm._decode_token(123), "token_123")
75+
76+
77+
if __name__ == "__main__":
78+
unittest.main()

0 commit comments

Comments
 (0)