
Commit 783262e

feat(xtask): add max-session control to handle concurrency

Signed-off-by: Ceng23333 <[email protected]>
Parent: 9fe0ec1

3 files changed, +81 −23 lines

xtask/src/service/error.rs

Lines changed: 9 additions & 0 deletions
@@ -8,6 +8,7 @@ pub(crate) enum Error {
     NotFound(NotFoundError),
     MsgNotSupported(MsgNotSupportedError),
     ModelNotFound(String),
+    TooManyConnections,
 }
 
 #[derive(Serialize, Debug)]
@@ -42,6 +43,7 @@ impl Error {
             Self::NotFound(..) => StatusCode::NOT_FOUND,
             Self::MsgNotSupported(..) => StatusCode::BAD_REQUEST,
             Self::ModelNotFound(..) => StatusCode::NOT_FOUND,
+            Self::TooManyConnections => StatusCode::TOO_MANY_REQUESTS,
         }
     }
 
@@ -52,6 +54,9 @@ impl Error {
             Self::NotFound(e) => serde_json::to_string(&e).unwrap(),
             Self::MsgNotSupported(e) => serde_json::to_string(&e).unwrap(),
             Self::ModelNotFound(model) => format!("Model not found: {model}"),
+            Self::TooManyConnections => {
+                "Too many concurrent connections. Please try again later.".to_string()
+            }
         }
     }
 }
@@ -63,6 +68,10 @@ impl fmt::Display for Error {
             Error::NotFound(e) => write!(f, "Not Found: {} {}", e.method, e.uri),
             Error::MsgNotSupported(e) => write!(f, "Message type not supported: {:?}", e.message),
             Error::ModelNotFound(model) => write!(f, "Model not found: {model}"),
+            Error::TooManyConnections => write!(
+                f,
+                "Too many concurrent connections. Please try again later."
+            ),
         }
     }
 }
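
The new variant follows the file's existing pattern: one arm maps it to an HTTP status (429 Too Many Requests), one supplies the response body, and one supplies the Display message. A small hedged check, not part of the commit, pinning down the user-visible text via the fmt::Display impl above (to_string() comes from std's blanket ToString impl):

#[cfg(test)]
mod tests {
    use super::*;

    // Sketch only: asserts the message written by the new Display arm.
    #[test]
    fn too_many_connections_message() {
        assert_eq!(
            Error::TooManyConnections.to_string(),
            "Too many concurrent connections. Please try again later."
        );
    }
}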

xtask/src/service/mod.rs

Lines changed: 58 additions & 23 deletions
@@ -59,6 +59,8 @@ pub struct ServiceArgs {
     #[clap(long)]
     max_tokens: Option<usize>,
     #[clap(long)]
+    max_sessions: Option<usize>,
+    #[clap(long)]
     temperature: Option<f32>,
     #[clap(long)]
     top_p: Option<f32>,
@@ -74,6 +76,8 @@ pub struct ModelConfig {
     pub gpus: Option<Box<[c_int]>>,
     #[serde(rename = "max-tokens")]
     pub max_tokens: Option<usize>,
+    #[serde(rename = "max-sessions")]
+    pub max_sessions: Option<usize>,
     pub temperature: Option<f32>,
     #[serde(rename = "top-p")]
     pub top_p: Option<f32>,
@@ -92,6 +96,7 @@ impl ServiceArgs {
             name,
             gpus,
             max_tokens,
+            max_sessions,
             temperature,
             top_p,
             repetition_penalty,
@@ -109,6 +114,7 @@ impl ServiceArgs {
             path: file.clone(),
             gpus: Some(parse_gpus(gpus.as_deref())),
             max_tokens,
+            max_sessions,
             temperature,
             top_p,
             repetition_penalty,
@@ -146,7 +152,7 @@ async fn start_infer_service(
     handles: Vec<(Arc<Model>, Service)>,
     port: u16,
 ) -> std::io::Result<()> {
-    let app = App(Arc::new(models));
+    let app = App::new(models);
 
     let _handles = handles
         .into_iter()
@@ -174,22 +180,50 @@ async fn start_infer_service(
 }
 
 #[derive(Clone)]
-struct App(Arc<HashMap<String, Arc<Model>>>);
+struct App(Arc<HashMap<String, Arc<Model>>>, Arc<AtomicUsize>);
+
+impl App {
+    fn new(models: HashMap<String, Arc<Model>>) -> Self {
+        App(Arc::new(models), Arc::new(AtomicUsize::new(0)))
+    }
+
+    fn try_acquire_connection(&self) -> bool {
+        const MAX_CONCURRENT_CONNECTIONS: usize = 32; // Set a reasonable limit
+        let current = self.1.fetch_add(1, SeqCst);
+        if current >= MAX_CONCURRENT_CONNECTIONS {
+            self.1.fetch_sub(1, SeqCst);
+            false
+        } else {
+            true
+        }
+    }
+
+    fn release_connection(&self) {
+        self.1.fetch_sub(1, SeqCst);
+    }
+}
 
 impl HyperService<Request<Incoming>> for App {
     type Response = Response<BoxBody<Bytes, hyper::Error>>;
     type Error = hyper::Error;
     type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
 
     fn call(&self, req: Request<Incoming>) -> Self::Future {
-        match (req.method(), req.uri().path()) {
-            openai::GET_MODELS => {
-                let json = json(create_models(self.0.keys().cloned()));
-                Box::pin(async move { Ok(json) })
-            }
-            openai::POST_COMPLETIONS => {
-                let models = self.0.clone();
-                Box::pin(async move {
+        // Try to acquire a connection slot
+        if !self.try_acquire_connection() {
+            let response = error(Error::TooManyConnections);
+            return Box::pin(async move { Ok(response) });
+        }
+
+        let app_clone = self.clone();
+        Box::pin(async move {
+            let result = match (req.method(), req.uri().path()) {
+                openai::GET_MODELS => {
+                    let json = json(create_models(app_clone.0.keys().cloned()));
+                    Ok(json)
+                }
+                openai::POST_COMPLETIONS => {
+                    let models = app_clone.0.clone();
                     let whole_body = req.collect().await?.to_bytes();
                     let req: CreateCompletionRequest = match serde_json::from_slice(&whole_body) {
                         Ok(req) => req,
@@ -261,11 +295,9 @@ impl HyperService<Request<Incoming>> for App {
 
                     let response = completion_response(id, created, model_name, content_, reason_);
                     Ok(json(response))
-                })
-            }
-            openai::POST_CHAT_COMPLETIONS => {
-                let models = self.0.clone();
-                Box::pin(async move {
+                }
+                openai::POST_CHAT_COMPLETIONS => {
+                    let models = app_clone.0.clone();
                     let whole_body = req.collect().await?.to_bytes();
 
                     let req: CreateChatCompletionRequest = match serde_json::from_slice(&whole_body)
@@ -352,14 +384,17 @@ impl HyperService<Request<Incoming>> for App {
                         reason_,
                     );
                     Ok(json(response))
-                })
-            }
-            // Return 404 Not Found for other routes.
-            (method, uri) => {
-                let msg = Error::not_found(method, uri);
-                Box::pin(async move { Ok(error(msg)) })
-            }
-        }
+                }
+                (method, uri) => {
+                    let msg = Error::not_found(method, uri);
+                    Ok(error(msg))
+                }
+            };
+
+            // Always release the connection when done
+            app_clone.release_connection();
+            result
+        })
     }
 }
 
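
Two details of this hunk are worth noting. First, the HTTP-layer guard uses a hardcoded MAX_CONCURRENT_CONNECTIONS of 32, separate from the per-model max-sessions knob added above (given the clap and serde derives, that knob should surface as a --max-sessions flag and a max-sessions config key; the exact invocation is not shown in this diff). Second, release_connection() only runs if the async block reaches its end: the ? on req.collect().await? propagates errors out of the block before the release runs, and a dropped future (client disconnect) never executes it at all, so both paths can leak a slot. A drop-based permit closes both holes; the following is a minimal sketch under those assumptions, not part of the commit:

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

/// Sketch: an RAII permit whose Drop releases the slot, so early returns
/// via `?` and cancelled futures cannot leave the counter incremented.
struct ConnectionPermit(Arc<AtomicUsize>);

impl ConnectionPermit {
    fn try_acquire(counter: &Arc<AtomicUsize>, max: usize) -> Option<Self> {
        // Optimistically reserve a slot, then roll back if over the limit,
        // mirroring try_acquire_connection() in the diff.
        if counter.fetch_add(1, SeqCst) >= max {
            counter.fetch_sub(1, SeqCst);
            None
        } else {
            Some(Self(counter.clone()))
        }
    }
}

impl Drop for ConnectionPermit {
    fn drop(&mut self) {
        // Runs on normal completion, on `?` early returns, and when the
        // request future is dropped mid-flight.
        self.0.fetch_sub(1, SeqCst);
    }
}

Holding such a permit inside the async block (let _permit = ...;) would release the slot on every exit path; tokio's Semaphore with try_acquire_owned() offers the same behavior off the shelf.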

xtask/src/service/model.rs

Lines changed: 14 additions & 0 deletions
@@ -20,6 +20,7 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 
 pub(super) struct Model {
     max_tokens: usize,
+    max_sessions: usize,
     sampling: SampleArgs,
     think: [utok; 2],
     enable_thinking: bool,
@@ -68,6 +69,7 @@ impl Model {
             path,
             gpus,
             max_tokens,
+            max_sessions,
             temperature,
             top_p,
             repetition_penalty,
@@ -99,6 +101,7 @@ impl Model {
 
         let model = Model {
             max_tokens: max_tokens.unwrap_or(2 << 10),
+            max_sessions: max_sessions.unwrap_or(16), // Default to 16 if not specified
             sampling: SampleArgs::new(
                 temperature.unwrap_or(0.),
                 top_p.unwrap_or(1.),
@@ -121,6 +124,15 @@ impl Model {
         (model, service)
     }
 
+    fn check_session_limit(&self) -> Result<(), Error> {
+        let sessions = self.sessions.lock().unwrap();
+        if sessions.len() >= self.max_sessions {
+            Err(Error::TooManyConnections)
+        } else {
+            Ok(())
+        }
+    }
+
     pub fn serve(&self, service: &mut Service) {
         let [think, _think] = self.think;
         loop {
@@ -263,6 +275,7 @@ impl Model {
         &self,
         req: CreateChatCompletionRequest,
     ) -> Result<UnboundedReceiver<Output>, Error> {
+        self.check_session_limit()?;
        let CreateChatCompletionRequest {
            messages,
            max_tokens,
@@ -347,6 +360,7 @@ impl Model {
        &self,
        req: CreateCompletionRequest,
    ) -> Result<UnboundedReceiver<Output>, Error> {
+        self.check_session_limit()?;
        let CreateCompletionRequest {
            prompt,
            max_tokens,
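
One caveat on check_session_limit: it releases the sessions lock as soon as the guard drops, while the session itself is presumably registered later in the handler (that code is not in this diff), so two requests racing for the last free slot could both pass the len() check. If that matters, the check and the registration can share a single lock acquisition. A minimal sketch with stand-in types, since the diff does not show how sessions is populated:

use std::collections::HashMap;
use std::sync::Mutex;

// Stand-ins for items the diff does not show; the real code would use the
// crate's Error enum and whatever `sessions` actually maps.
enum Error {
    TooManyConnections,
}
type SessionId = u64;
struct SessionState;

struct Model {
    max_sessions: usize,
    sessions: Mutex<HashMap<SessionId, SessionState>>,
}

impl Model {
    // Sketch: capacity check and registration under one lock acquisition,
    // so concurrent requests cannot both claim the last free slot.
    fn try_register_session(&self, id: SessionId) -> Result<(), Error> {
        let mut sessions = self.sessions.lock().unwrap();
        if sessions.len() >= self.max_sessions {
            return Err(Error::TooManyConnections);
        }
        sessions.insert(id, SessionState);
        Ok(())
    }
}

For reference, the defaults chosen in the hunk above are 16 sessions and 2 << 10 (that is, 2048) max tokens.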
