@@ -59,6 +59,8 @@ pub struct ServiceArgs {
5959 #[ clap( long) ]
6060 max_tokens : Option < usize > ,
6161 #[ clap( long) ]
62+ max_sessions : Option < usize > ,
63+ #[ clap( long) ]
6264 temperature : Option < f32 > ,
6365 #[ clap( long) ]
6466 top_p : Option < f32 > ,
@@ -74,6 +76,8 @@ pub struct ModelConfig {
7476 pub gpus : Option < Box < [ c_int ] > > ,
7577 #[ serde( rename = "max-tokens" ) ]
7678 pub max_tokens : Option < usize > ,
79+ #[ serde( rename = "max-sessions" ) ]
80+ pub max_sessions : Option < usize > ,
7781 pub temperature : Option < f32 > ,
7882 #[ serde( rename = "top-p" ) ]
7983 pub top_p : Option < f32 > ,
@@ -92,6 +96,7 @@ impl ServiceArgs {
9296 name,
9397 gpus,
9498 max_tokens,
99+ max_sessions,
95100 temperature,
96101 top_p,
97102 repetition_penalty,
@@ -109,6 +114,7 @@ impl ServiceArgs {
109114 path : file. clone ( ) ,
110115 gpus : Some ( parse_gpus ( gpus. as_deref ( ) ) ) ,
111116 max_tokens,
117+ max_sessions,
112118 temperature,
113119 top_p,
114120 repetition_penalty,
@@ -146,7 +152,7 @@ async fn start_infer_service(
146152 handles : Vec < ( Arc < Model > , Service ) > ,
147153 port : u16 ,
148154) -> std:: io:: Result < ( ) > {
149- let app = App ( Arc :: new ( models) ) ;
155+ let app = App :: new ( models) ;
150156
151157 let _handles = handles
152158 . into_iter ( )
@@ -174,22 +180,50 @@ async fn start_infer_service(
174180}
175181
176182#[ derive( Clone ) ]
177- struct App ( Arc < HashMap < String , Arc < Model > > > ) ;
183+ struct App ( Arc < HashMap < String , Arc < Model > > > , Arc < AtomicUsize > ) ;
184+
185+ impl App {
186+ fn new ( models : HashMap < String , Arc < Model > > ) -> Self {
187+ App ( Arc :: new ( models) , Arc :: new ( AtomicUsize :: new ( 0 ) ) )
188+ }
189+
190+ fn try_acquire_connection ( & self ) -> bool {
191+ const MAX_CONCURRENT_CONNECTIONS : usize = 32 ; // Set a reasonable limit
192+ let current = self . 1 . fetch_add ( 1 , SeqCst ) ;
193+ if current >= MAX_CONCURRENT_CONNECTIONS {
194+ self . 1 . fetch_sub ( 1 , SeqCst ) ;
195+ false
196+ } else {
197+ true
198+ }
199+ }
200+
201+ fn release_connection ( & self ) {
202+ self . 1 . fetch_sub ( 1 , SeqCst ) ;
203+ }
204+ }
178205
179206impl HyperService < Request < Incoming > > for App {
180207 type Response = Response < BoxBody < Bytes , hyper:: Error > > ;
181208 type Error = hyper:: Error ;
182209 type Future = Pin < Box < dyn Future < Output = Result < Self :: Response , Self :: Error > > + Send > > ;
183210
184211 fn call ( & self , req : Request < Incoming > ) -> Self :: Future {
185- match ( req. method ( ) , req. uri ( ) . path ( ) ) {
186- openai:: GET_MODELS => {
187- let json = json ( create_models ( self . 0 . keys ( ) . cloned ( ) ) ) ;
188- Box :: pin ( async move { Ok ( json) } )
189- }
190- openai:: POST_COMPLETIONS => {
191- let models = self . 0 . clone ( ) ;
192- Box :: pin ( async move {
212+ // Try to acquire a connection slot
213+ if !self . try_acquire_connection ( ) {
214+ let response = error ( Error :: TooManyConnections ) ;
215+ return Box :: pin ( async move { Ok ( response) } ) ;
216+ }
217+
218+ let app_clone = self . clone ( ) ;
219+ Box :: pin ( async move {
220+ let result = match ( req. method ( ) , req. uri ( ) . path ( ) ) {
221+ openai:: GET_MODELS => {
222+ let json = json ( create_models ( app_clone. 0 . keys ( ) . cloned ( ) ) ) ;
223+ Ok ( json)
224+ }
225+ openai:: POST_COMPLETIONS => {
226+ let models = app_clone. 0 . clone ( ) ;
193227 let whole_body = req. collect ( ) . await ?. to_bytes ( ) ;
194228 let req: CreateCompletionRequest = match serde_json:: from_slice ( & whole_body) {
195229 Ok ( req) => req,
@@ -261,11 +295,9 @@ impl HyperService<Request<Incoming>> for App {
261295
262296 let response = completion_response ( id, created, model_name, content_, reason_) ;
263297 Ok ( json ( response) )
264- } )
265- }
266- openai:: POST_CHAT_COMPLETIONS => {
267- let models = self . 0 . clone ( ) ;
268- Box :: pin ( async move {
298+ }
299+ openai:: POST_CHAT_COMPLETIONS => {
300+ let models = app_clone. 0 . clone ( ) ;
269301 let whole_body = req. collect ( ) . await ?. to_bytes ( ) ;
270302
271303 let req: CreateChatCompletionRequest = match serde_json:: from_slice ( & whole_body)
@@ -352,14 +384,17 @@ impl HyperService<Request<Incoming>> for App {
352384 reason_,
353385 ) ;
354386 Ok ( json ( response) )
355- } )
356- }
357- // Return 404 Not Found for other routes.
358- ( method, uri) => {
359- let msg = Error :: not_found ( method, uri) ;
360- Box :: pin ( async move { Ok ( error ( msg) ) } )
361- }
362- }
387+ }
388+ ( method, uri) => {
389+ let msg = Error :: not_found ( method, uri) ;
390+ Ok ( error ( msg) )
391+ }
392+ } ;
393+
394+ // Always release the connection when done
395+ app_clone. release_connection ( ) ;
396+ result
397+ } )
363398 }
364399}
365400
0 commit comments