@@ -28,8 +28,8 @@ use tokio::{
2828use tokio_stream:: wrappers:: UnboundedReceiverStream ;
2929
/// Model-service state; it is built inside an `Arc` and a clone is moved into the
/// blocking receive loop.
///
/// - `max_tokens`: fallback generation limit used when a request carries no
///   `max_tokens`; itself defaults to `2 << 10` when the config leaves it unset.
/// - `terminal`: clone of the service terminal, used to decode tokens and to stop
///   a session.
/// - `sessions`: mutex-guarded map from `SessionId` to `SessionInfo`.
/// - `cache_manager`: mutex-guarded `CacheManager`; the cache of a session that
///   finishes normally is inserted back through it.
3030pub ( super ) struct Model {
31+ max_tokens : usize ,
3132 terminal : Terminal ,
32- max_steps : usize ,
3333 sessions : Mutex < BTreeMap < SessionId , SessionInfo > > ,
3434 cache_manager : Mutex < CacheManager > ,
3535}
@@ -48,7 +48,7 @@ impl Model {
4848 let ModelConfig {
4949 path,
5050 gpus,
51- max_steps ,
51+ max_tokens ,
5252 think,
5353 } = config;
5454
@@ -67,26 +67,25 @@ impl Model {
6767 ( utok:: MAX , utok:: MAX )
6868 } ;
6969
70- let service_manager = Arc :: new ( Model {
70+ let model = Arc :: new ( Model {
71+ max_tokens : max_tokens. unwrap_or ( 2 << 10 ) ,
7172 terminal : service. terminal ( ) . clone ( ) ,
72- max_steps : max_steps. unwrap_or ( 2 << 10 ) ,
7373 sessions : Mutex :: new ( sessions) ,
7474 cache_manager : Mutex :: new ( CacheManager :: new ( service. terminal ( ) . clone ( ) ) ) ,
7575 } ) ;
7676
77- let service_manager_for_recv = service_manager. clone ( ) ;
78-
77+ let model_ = model. clone ( ) ;
7978 let join_handle = tokio:: task:: spawn_blocking ( move || {
8079 loop {
8180 let Received { sessions, outputs } = service. recv ( Duration :: from_millis ( 10 ) ) ;
8281
82+ let mut sessions_guard = model_. sessions . lock ( ) . unwrap ( ) ;
8383 // Handle the outputs first
8484 for ( session_id, tokens) in outputs {
8585 if tokens. is_empty ( ) {
8686 continue ;
8787 }
8888
89- let mut sessions_guard = service_manager_for_recv. sessions . lock ( ) . unwrap ( ) ;
9089 let session_info = sessions_guard. get_mut ( & session_id) . unwrap ( ) ;
9192 // Update session_info
9291 session_info. tokens . extend ( & tokens) ;
@@ -111,12 +110,8 @@ impl Model {
111110 & [ ]
112111 } ;
113112
114- let think = service_manager_for_recv
115- . terminal
116- . decode ( think, & mut session_info. buf ) ;
117- let text = service_manager_for_recv
118- . terminal
119- . decode ( tokens, & mut session_info. buf ) ;
113+ let think = model_. terminal . decode ( think, & mut session_info. buf ) ;
114+ let text = model_. terminal . decode ( tokens, & mut session_info. buf ) ;
120115 debug ! ( "解码完成:{tokens:?} -> {think:?} | {text:?}" ) ;
121116
122117 let response = create_chat_completion_response (
@@ -131,16 +126,12 @@ impl Model {
131126
132127 if session_info. sender . send ( message) . is_err ( ) {
133128 info ! ( "{session_id:?} 客户端连接已关闭" ) ;
134- service_manager_for_recv . terminal . stop ( session_id) ;
129+ model_ . terminal . stop ( session_id) ;
135130 }
136131 }
137132
138133 // Handle sessions that have ended
139134 if !sessions. is_empty ( ) {
140- let mut sessions_guard = service_manager_for_recv. sessions . lock ( ) . unwrap ( ) ;
141- let mut cache_manager_guard =
142- service_manager_for_recv. cache_manager . lock ( ) . unwrap ( ) ;
143-
144135 for ( session, reason) in sessions {
145136 let SessionInfo {
146137 tokens,
@@ -152,7 +143,11 @@ impl Model {
152143 let reason = match reason {
153144 // Finished normally: insert back into the cache
154145 ReturnReason :: Finish => {
155- cache_manager_guard. insert ( tokens, session. cache ) ;
146+ model_
147+ . cache_manager
148+ . lock ( )
149+ . unwrap ( )
150+ . insert ( tokens, session. cache ) ;
156151 info ! ( "{:?} 正常完成" , session. id) ;
157152 FinishReason :: Stop
158153 }
@@ -177,12 +172,12 @@ impl Model {
177172 }
178173 } ) ;
179174
180- ( service_manager , join_handle)
175+ ( model , join_handle)
181176 }
182177
183178 pub fn complete_chat (
184179 & self ,
185- completions : CreateChatCompletionRequest ,
180+ req : CreateChatCompletionRequest ,
186181 ) -> Response < BoxBody < Bytes , hyper:: Error > > {
187182 let CreateChatCompletionRequest {
188183 model,
@@ -191,10 +186,10 @@ impl Model {
191186 temperature,
192187 top_p,
193188 ..
194- } = completions ;
189+ } = req ;
195190 let ( sender, receiver) = mpsc:: unbounded_channel ( ) ;
196191
197- let max_steps = max_tokens. map_or ( self . max_steps , |n| n as usize ) ;
192+ let max_tokens = max_tokens. map_or ( self . max_tokens , |n| n as _ ) ;
198193 let sample_args =
199194 SampleArgs :: new ( temperature. unwrap_or ( 0. ) , top_p. unwrap_or ( 1. ) , usize:: MAX ) . unwrap ( ) ;
200195
@@ -242,7 +237,7 @@ impl Model {
242237 . cache_manager
243238 . lock ( )
244239 . unwrap ( )
245- . send ( tokens, sample_args, max_steps ) ;
240+ . send ( tokens, sample_args, max_tokens ) ;
246241
247242 let session_info = SessionInfo {
248243 sender,
0 commit comments