1+ from transformers import AutoTokenizer , PreTrainedTokenizerFast
2+ from http .server import HTTPServer , BaseHTTPRequestHandler
3+ import json
4+ import argparse
5+ import uuid
6+
# Global registry: maps each session uid (string) to its own
# Tokenizer_Http instance.
tokenizers: dict = dict()
9+
class Tokenizer_Http():
    """Per-session chat tokenizer that keeps the conversation history.

    Two tokenizers are loaded: ``qwen2.5_tokenizer`` is used only to render
    the chat template to text, while ``qwen3_tokenizer`` performs the actual
    encode/decode of token ids.

    State kept per instance:
        messages:         chat history as a list of role/content dicts.
        token_ids:        token ids of the most recently encoded prompt (or
                          of the history prefix after a reply was appended).
        token_ids_cache:  token ids received by :meth:`decode` that cannot be
                          emitted yet because they end in an incomplete
                          character.
    """

    DEFAULT_SYSTEM_PROMPT = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

    # Number of trailing tokens dropped from a rendered prompt to obtain the
    # history prefix. NOTE(review): presumably the tokens of the generation
    # prompt appended by add_generation_prompt=True -- confirm against the
    # chat template in use.
    _GENERATION_PROMPT_TOKENS = 3

    _THINK_START = '<think>'
    _THINK_END = '</think>'
    _REPLACEMENT_CHAR = "\ufffd"

    def __init__(self):
        model_id = "qwen3_tokenizer"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        model_id = "qwen2.5_tokenizer"
        self.tokenizer_25 = AutoTokenizer.from_pretrained(model_id)
        self.messages = [
            {"role": "system", "content": self.DEFAULT_SYSTEM_PROMPT},
        ]
        self.token_ids = []

        self.token_ids_cache = []

    def _render_prompt(self):
        """Render ``self.messages`` to prompt text with a generation prompt."""
        return self.tokenizer_25.apply_chat_template(
            self.messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def remove_think(self, text: str):
        """Return ``text`` without its ``<think>...</think>`` section.

        Everything from the opening tag (or from the start of the text when
        only the closing tag is present) up to and including the first
        closing tag is removed. Text without a closing tag is left as is.
        The result is stripped of surrounding newlines and spaces.
        """
        end = text.find(self._THINK_END)
        if end != -1:
            start = text.find(self._THINK_START)
            # Keep whatever precedes the opening tag; when the opening tag is
            # missing (or only appears later) the block starts at index 0.
            head = text[:start] if 0 <= start < end else ""
            text = head + text[end + len(self._THINK_END):]
        return text.strip("\n ")

    def encode(self, prompt: str, last_reply: str = None):
        """Append a user turn and tokenize the whole conversation.

        Args:
            prompt: text of the new user message.
            last_reply: the assistant reply to the previous turn, if any. Its
                think section is removed before it is added to the history.

        Returns:
            A ``(token_ids, diff)`` tuple: all token ids of the rendered
            conversation, and the suffix that follows the previously stored
            ``self.token_ids`` prefix.
        """
        if last_reply is not None:
            last_reply = self.remove_think(last_reply)
            self.messages.append({"role": "assistant", "content": last_reply})
            text = self._render_prompt()
            print("fff生成的文本:\n ============\n ", text, "============\n ")
            # Re-base the known prefix on the history that now includes the
            # assistant reply.
            self.token_ids = self.tokenizer.encode(text)[:-self._GENERATION_PROMPT_TOKENS]
            # Debug output decodes directly so that the streaming state in
            # token_ids_cache is not touched.
            print("diff:", self.tokenizer.decode(self.token_ids))
        print("prompt:", prompt)
        self.messages.append({"role": "user", "content": prompt})

        text = self._render_prompt()
        print("生成的文本:\n ============\n ", text, "============\n ")
        token_ids = self.tokenizer.encode(text)
        # Only the part after the already known prefix is new.
        diff = token_ids[len(self.token_ids):]
        self.token_ids = token_ids
        print("diff:", self.tokenizer.decode(diff))
        return token_ids, diff

    def decode(self, token_ids):
        """Incrementally decode ``token_ids`` into text.

        The ids are appended to ``token_ids_cache`` and the whole cache is
        decoded. While the decoded text ends with U+FFFD (an incomplete
        multi-byte character) an empty string is returned and the ids stay
        cached, so the next call can complete the character. Otherwise the
        cache is emptied and the text returned. A U+FFFD that is not at the
        end cannot be repaired by later tokens, so it is emitted instead of
        blocking all further output.
        """
        self.token_ids_cache.extend(token_ids)
        text = self.tokenizer.decode(self.token_ids_cache)
        if text.endswith(self._REPLACEMENT_CHAR):
            print("text 中包含非法字符")
            return ""
        self.token_ids_cache.clear()
        return text

    @property
    def bos_id(self):
        """Beginning-of-sequence token id reported by the tokenizer."""
        return self.tokenizer.bos_token_id

    @property
    def eos_id(self):
        """End-of-sequence token id reported by the tokenizer."""
        return self.tokenizer.eos_token_id

    @property
    def bos_token(self):
        """Beginning-of-sequence token string reported by the tokenizer."""
        return self.tokenizer.bos_token

    @property
    def eos_token(self):
        """End-of-sequence token string reported by the tokenizer."""
        return self.tokenizer.eos_token

    def reset(self, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """Start a new conversation containing only ``system_prompt``.

        Pending streaming-decode state is discarded as well, since it belongs
        to the previous conversation.

        Returns:
            The token ids of the rendered system turn with the trailing
            ``_GENERATION_PROMPT_TOKENS`` tokens removed.
        """
        self.messages = [
            {"role": "system", "content": system_prompt},
        ]
        text = self._render_prompt()
        token_ids = self.tokenizer.encode(text)[:-self._GENERATION_PROMPT_TOKENS]
        self.token_ids = token_ids
        self.token_ids_cache = []
        print(self.tokenizer.decode(token_ids))
        return token_ids
103+
104+
class Request(BaseHTTPRequestHandler):
    """HTTP front end for the per-uid ``Tokenizer_Http`` instances.

    GET endpoints (matched as substrings of the request path):
        /get_uid           create a new tokenizer session, reply {'uid': ...}
        /bos_id?uid=...    reply {'bos_id': id or -1}
        /eos_id?uid=...    reply {'eos_id': id or -1}

    POST endpoints (JSON object body, matched as substrings of the path):
        /encode   {'uid', 'text', optional 'last_reply'}
        /decode   {'uid', 'token_ids'}
        /reset    {'uid', optional 'system_prompt'}

    Every reply is a JSON object. Unknown endpoints and unknown uids are
    answered with status 200 and an 'error' key, so existing clients that
    only look at the body keep working. Malformed requests get status 400
    and unexpected failures while handling a request get status 500; the
    status line is only sent once the outcome is known.
    """

    timeout = 5
    server_version = 'Apache'

    def do_GET(self):
        """Dispatch a GET request and send the JSON reply."""
        print("GET 请求路径:", self.path)
        try:
            status, payload = self._handle_get()
        except Exception as exc:  # request boundary: report instead of dropping the connection
            self.log_error("GET %s failed: %r", self.path, exc)
            status, payload = 500, {'error': 'Internal server error'}
        self._send_json(status, payload)

    def do_POST(self):
        """Parse the JSON body, dispatch a POST request and send the reply."""
        print("POST 请求路径:", self.path)
        try:
            # A negative length would make read() block until EOF/timeout.
            content_length = max(int(self.headers.get('content-length', 0)), 0)
            data = self.rfile.read(content_length).decode()
            print("接收到的数据:", data)
            req = json.loads(data)
        except ValueError:
            # Covers a non-numeric Content-Length, invalid UTF-8 and invalid
            # JSON, which are all ValueError subclasses.
            self._send_json(400, {'error': 'Invalid JSON body'})
            return
        if not isinstance(req, dict):
            self._send_json(400, {'error': 'Request body must be a JSON object'})
            return

        try:
            status, payload = self._handle_post(req)
        except Exception as exc:  # request boundary: report instead of dropping the connection
            self.log_error("POST %s failed: %r", self.path, exc)
            status, payload = 500, {'error': 'Internal server error'}
        self._send_json(status, payload)

    def _handle_get(self):
        """Return ``(status, payload)`` for the current GET path."""
        if '/get_uid' in self.path:
            new_uid = str(uuid.uuid4())
            print("新 uid:", new_uid)
            # Each uid owns a separate Tokenizer_Http instance.
            tokenizers[new_uid] = Tokenizer_Http()
            return 200, {'uid': new_uid}
        if '/bos_id' in self.path:
            return self._token_id_reply('bos_id')
        if '/eos_id' in self.path:
            return self._token_id_reply('eos_id')
        return 200, {'error': 'Invalid GET endpoint'}

    def _token_id_reply(self, key):
        """Build the reply for /bos_id or /eos_id; ``key`` names the property."""
        instance = self._find_instance(self.get_query_param("uid"))
        if instance is None:
            return 200, {'error': 'Invalid uid'}
        value = getattr(instance, key)
        # -1 signals that the tokenizer defines no such token.
        return 200, {key: value if value is not None else -1}

    def _handle_post(self, req):
        """Return ``(status, payload)`` for the current POST path and body."""
        if '/encode' in self.path:
            handler = self._handle_encode
        elif '/decode' in self.path:
            handler = self._handle_decode
        elif '/reset' in self.path:
            handler = self._handle_reset
        else:
            return 200, {'error': 'Invalid POST endpoint'}

        instance = self._find_instance(req.get('uid'))
        if instance is None:
            return 200, {'error': 'Invalid uid'}
        return handler(instance, req)

    @staticmethod
    def _handle_encode(instance, req):
        """Handle /encode: requires 'text', accepts an optional 'last_reply'."""
        prompt = req.get('text')
        last_reply = req.get('last_reply')
        if not isinstance(prompt, str):
            return 400, {'error': "'text' must be a string"}
        if last_reply is not None and not isinstance(last_reply, str):
            return 400, {'error': "'last_reply' must be a string"}
        token_ids, diff = instance.encode(prompt, last_reply)
        return 200, {'token_ids': token_ids, 'diff': diff}

    @staticmethod
    def _handle_decode(instance, req):
        """Handle /decode: requires 'token_ids' as a list of integers."""
        token_ids = req.get('token_ids')
        if not (isinstance(token_ids, list)
                and all(isinstance(t, int) for t in token_ids)):
            return 400, {'error': "'token_ids' must be a list of integers"}
        return 200, {'text': instance.decode(token_ids)}

    @staticmethod
    def _handle_reset(instance, req):
        """Handle /reset: uses 'system_prompt' when given, else the default."""
        system_prompt = req.get("system_prompt")
        if system_prompt is None:
            return 200, {'token_ids': instance.reset()}
        if not isinstance(system_prompt, str):
            return 400, {'error': "'system_prompt' must be a string"}
        print("system_prompt:", system_prompt)
        return 200, {'token_ids': instance.reset(system_prompt)}

    @staticmethod
    def _find_instance(uid):
        """Return the session registered for ``uid``, or None."""
        # Non-string uids (e.g. a JSON list) can never be registered and may
        # not even be hashable, so they are rejected before the lookup.
        return tokenizers.get(uid) if isinstance(uid, str) else None

    def _send_json(self, status, payload):
        """Serialize ``payload`` and send it as the complete HTTP response."""
        msg = json.dumps(payload)
        print("响应消息:", msg)
        body = msg.encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def get_query_param(self, key):
        """
        Helper: return the first value of query parameter ``key`` from the
        request URL, or None when it is absent.
        Example: /bos_id?uid=xxx
        """
        from urllib.parse import urlparse, parse_qs
        query = urlparse(self.path).query
        params = parse_qs(query)
        values = params.get(key)
        return values[0] if values else None
206+
if __name__ == "__main__":
    # Command-line entry point: serve the tokenizer HTTP API until interrupted.
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=12345)
    args = parser.parse_args()

    host = (args.host, args.port)
    print('Server running at http://%s:%s' % host)
    # The context manager closes the listening socket on every exit path.
    with HTTPServer(host, Request) as server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the server; exit without a traceback.
            print('Server stopped')
0 commit comments