
Commit 805bc4f

Add qwen3 tokenizer

1 parent 0c67188

File tree

7 files changed: +909367 -0 lines


scripts/qwen3_tokenizer.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
from transformers import AutoTokenizer
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs
import json
import argparse
import uuid

# Global dict: maps each uid to its Tokenizer_Http instance
tokenizers = {}

class Tokenizer_Http():

    def __init__(self):
        model_id = "qwen3_tokenizer"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        model_id = "qwen2.5_tokenizer"
        self.tokenizer_25 = AutoTokenizer.from_pretrained(model_id)
        self.messages = [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        ]
        # Token ids of the templated conversation so far
        self.token_ids = []
        # Buffer for tokens that do not yet decode to complete characters
        self.token_ids_cache = []
    def remove_think(self, text: str):
        # Strip the <think>...</think> block so only the visible reply is
        # kept in the history; leave the text unchanged when the tags are
        # missing (the original unguarded slice would corrupt such input)
        thinking_start = '<think>'
        thinking_end = '</think>'
        start = text.find(thinking_start)
        end = text.find(thinking_end)
        if start != -1 and end != -1:
            text = text[:start] + text[end + len(thinking_end):]
        # trim
        return text.strip("\n")

    def encode(self, prompt: str, last_reply: str = None):
        if last_reply is not None:
            last_reply = self.remove_think(last_reply)
            self.messages.append({"role": "assistant", "content": last_reply})
            text = self.tokenizer_25.apply_chat_template(
                self.messages,
                tokenize=False,
                add_generation_prompt=True
            )
            print("Templated history:\n============\n", text, "============\n")
            # Drop the trailing generation-prompt tokens so the cached
            # history ends right after the assistant reply
            self.token_ids = self.tokenizer.encode(text)[:-3]
            print("diff:", self.decode(self.token_ids))
        # if not prompt.endswith("/no_think"):
        #     prompt += "/no_think"
        print("prompt:", prompt)
        self.messages.append({"role": "user", "content": prompt})

        text = self.tokenizer_25.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        print("Templated text:\n============\n", text, "============\n")
        token_ids = self.tokenizer.encode(text)
        # Keep only the tokens added since the last call
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        print("diff:", self.decode(diff))
        return token_ids, diff

    def decode(self, token_ids):
        # Accumulate tokens until they decode cleanly: a multi-byte UTF-8
        # character can be split across tokens, which surfaces as U+FFFD
        self.token_ids_cache += token_ids
        text = self.tokenizer.decode(self.token_ids_cache)
        if "\ufffd" in text:
            print("text contains an incomplete character")
            return ""
        else:
            self.token_ids_cache.clear()
            return text

    @property
    def bos_id(self):
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        return self.tokenizer.eos_token

    def reset(self, system_prompt="You are Qwen, created by Alibaba Cloud. You are a helpful assistant."):
        self.messages = [
            {"role": "system", "content": system_prompt},
        ]
        text = self.tokenizer_25.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )
        token_ids = self.tokenizer.encode(text)[:-3]
        self.token_ids = token_ids
        print(self.decode(token_ids))
        return token_ids


class Request(BaseHTTPRequestHandler):
    timeout = 5
    server_version = 'Apache'

    def do_GET(self):
        print("GET path:", self.path)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()

        # Endpoint: allocate a new uid
        if '/get_uid' in self.path:
            new_uid = str(uuid.uuid4())
            print("new uid:", new_uid)
            # Create a dedicated Tokenizer_Http instance for this uid
            tokenizers[new_uid] = Tokenizer_Http()
            msg = json.dumps({'uid': new_uid})
        elif '/bos_id' in self.path:
            # Read the uid query parameter (e.g. ?uid=xxx)
            uid = self.get_query_param("uid")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                bos_id = instance.bos_id
                msg = json.dumps({'bos_id': bos_id if bos_id is not None else -1})
        elif '/eos_id' in self.path:
            uid = self.get_query_param("uid")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                eos_id = instance.eos_id
                msg = json.dumps({'eos_id': eos_id if eos_id is not None else -1})
        else:
            msg = json.dumps({'error': 'Invalid GET endpoint'})

        print("response:", msg)
        self.wfile.write(msg.encode())

    def do_POST(self):
        content_length = int(self.headers.get('content-length', 0))
        data = self.rfile.read(content_length).decode()
        print("POST path:", self.path)
        print("received data:", data)
        req = json.loads(data)

        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()

        if '/encode' in self.path:
            # The request must carry uid and text; last_reply is optional
            uid = req.get('uid')
            prompt = req.get('text')
            last_reply = req.get('last_reply')
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                token_ids, diff = instance.encode(prompt, last_reply)
                msg = json.dumps({'token_ids': token_ids, 'diff': diff})
        elif '/decode' in self.path:
            uid = req.get('uid')
            token_ids = req.get('token_ids')
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                text = instance.decode(token_ids)
                msg = json.dumps({'text': text})
        elif '/reset' in self.path:
            uid = req.get("uid")
            system_prompt = req.get("system_prompt")
            instance: Tokenizer_Http = tokenizers.get(uid)
            if instance is None:
                msg = json.dumps({'error': 'Invalid uid'})
            else:
                if system_prompt is not None:
                    print("system_prompt:", system_prompt)
                    token_ids = instance.reset(system_prompt)
                else:
                    token_ids = instance.reset()
                msg = json.dumps({'token_ids': token_ids})
        else:
            msg = json.dumps({'error': 'Invalid POST endpoint'})

        print("response:", msg)
        self.wfile.write(msg.encode())

    def get_query_param(self, key):
        """
        Helper: read a query parameter from the GET request URL,
        e.g. /bos_id?uid=xxx
        """
        query = urlparse(self.path).query
        params = parse_qs(query)
        values = params.get(key)
        return values[0] if values else None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=12345)
    args = parser.parse_args()

    host = (args.host, args.port)
    print('Server running at http://%s:%s' % host)
    server = HTTPServer(host, Request)
    server.serve_forever()
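For reference, a minimal client session against this server might look like the sketch below. This is an editorial illustration, not part of the commit: it assumes the server is running locally on the default port 12345 and uses the third-party requests package; the endpoint names and JSON fields follow the handlers above.

# Hypothetical client sketch for the tokenizer server above
import requests

BASE = "http://127.0.0.1:12345"

# 1. Allocate a per-session tokenizer instance
uid = requests.get(f"{BASE}/get_uid").json()["uid"]

# 2. Query the special-token ids for this session
eos_id = requests.get(f"{BASE}/eos_id", params={"uid": uid}).json()["eos_id"]

# 3. Encode the first user turn; "diff" holds only the newly added tokens
r = requests.post(f"{BASE}/encode",
                  json={"uid": uid, "text": "Hello, who are you?"}).json()
token_ids, diff = r["token_ids"], r["diff"]

# 4. After the model replies, feed the reply back via last_reply so the
#    server can extend the cached history incrementally
r = requests.post(f"{BASE}/encode",
                  json={"uid": uid,
                        "text": "And what can you do?",
                        "last_reply": "<think>\n\n</think>\n\nI am Qwen."}).json()

# 5. Decode generated ids a few at a time; the server buffers tokens that
#    do not yet form a complete UTF-8 character and returns "" meanwhile
text = requests.post(f"{BASE}/decode",
                     json={"uid": uid, "token_ids": diff}).json()["text"]

# 6. Reset the conversation, optionally with a new system prompt
requests.post(f"{BASE}/reset",
              json={"uid": uid, "system_prompt": "You are a helpful assistant."})

The per-uid design lets several chat sessions share one server process, each with its own message history and decode buffer.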
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
{
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "intermediate_size": 9728,
  "max_position_embeddings": 40960,
  "max_window_layers": 36,
  "model_type": "qwen3",
  "num_attention_heads": 32,
  "num_hidden_layers": 36,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
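This JSON follows the standard Hugging Face model-config layout (the filename is not shown in this diff). The dimensions (hidden_size 2560, 36 layers, 32 query / 8 key-value heads) match the published Qwen3-4B configuration, and note that head_dim is explicit (128) rather than derived as hidden_size / num_attention_heads (2560 / 32 = 80). A hedged sketch of reading these fields, assuming the file is saved as config.json in a local "qwen3_tokenizer/" directory (the same placeholder path the script above loads from) and transformers matching the recorded version 4.51.0:

# Illustrative only: directory name is an assumption
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("qwen3_tokenizer")
print(cfg.model_type)  # "qwen3"
# head_dim is explicit (128), not hidden_size // num_attention_heads (80)
print(cfg.head_dim, cfg.hidden_size // cfg.num_attention_heads)
# Grouped-query attention: 32 query heads share 8 key/value heads
print(cfg.num_attention_heads // cfg.num_key_value_heads)  # 4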
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
{
  "bos_token_id": 151643,
  "do_sample": true,
  "eos_token_id": [
    151645,
    151643
  ],
  "pad_token_id": 151643,
  "temperature": 0.6,
  "top_k": 20,
  "top_p": 0.95,
  "transformers_version": "4.51.0"
}
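This second JSON follows Hugging Face's generation_config.json layout; the sampling defaults (do_sample true, temperature 0.6, top_p 0.95, top_k 20) are the values Qwen recommends for Qwen3's thinking mode. When the file ships alongside a checkpoint, model.generate() picks these values up automatically; passing them explicitly is equivalent, as in the hedged sketch below (the checkpoint id "Qwen/Qwen3-4B" is an assumption, not named in this commit):

# Illustrative only: checkpoint id is a placeholder
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B")

inputs = tok("Hello, who are you?", return_tensors="pt")
out = model.generate(
    **inputs,
    do_sample=True,   # sample rather than decode greedily
    temperature=0.6,
    top_p=0.95,
    top_k=20,
    max_new_tokens=64,
)
print(tok.decode(out[0], skip_special_tokens=True))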
