Commit c331a1c

feat(xtask): implement the endpoint that returns the model list

Signed-off-by: YdrMaster <ydrml@hotmail.com>

1 parent b0be217 commit c331a1c

File tree

6 files changed: +91 −49 lines changed


xtask/src/service/cache_manager.rs

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ impl CacheManager {
         &mut self,
         tokens: Vec<utok>,
         sample_args: SampleArgs,
-        max_steps: usize,
+        max_tokens: usize,
     ) -> (SessionId, Vec<utok>) {
         static SESSION_ID: AtomicUsize = AtomicUsize::new(0);
         let id = SessionId(SESSION_ID.fetch_add(1, SeqCst));
@@ -51,7 +51,7 @@ impl CacheManager {
                 cache,
             },
             &tokens[pos..],
-            max_steps,
+            max_tokens,
         );
         (id, tokens)
     }
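This is a mechanical rename: the generation budget threaded through CacheManager is renamed from max_steps to max_tokens so it matches the max_tokens field of the OpenAI chat-completions request that ultimately supplies it (see model.rs below).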

xtask/src/service/client.rs

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,4 @@
-use super::V1_CHAT_COMPLETIONS;
+use super::openai::POST_CHAT_COMPLETIONS;
 use log::{info, trace, warn};
 use openai_struct::{
     ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionStreamResponse,
@@ -74,7 +74,10 @@ async fn send_single_request(
     }

     let req = client
-        .post(format!("http://localhost:{port}{V1_CHAT_COMPLETIONS}"))
+        .post(format!(
+            "http://localhost:{port}{}",
+            POST_CHAT_COMPLETIONS.1
+        ))
         .headers(headers.clone())
         .body(req_body)
         .timeout(Duration::from_secs(100));
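With this change the test client derives its URL from the shared route constant: POST_CHAT_COMPLETIONS.1 is "/chat/completions", so requests go to http://localhost:{port}/chat/completions instead of the previously hard-coded /v1/chat/completions.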

xtask/src/service/mod.rs

Lines changed: 26 additions & 18 deletions
@@ -4,20 +4,22 @@ mod model;
 mod openai;
 mod response;

-use crate::{parse_gpus, service::model::Model};
+use crate::parse_gpus;
 use error::*;
 use http_body_util::{BodyExt, combinators::BoxBody};
 use hyper::{
-    Method, Request, Response,
+    Request, Response,
     body::{Bytes, Incoming},
     server::conn::http1,
     service::Service as HyperService,
 };
 use hyper_util::rt::TokioIo;
 use log::{info, warn};
-use openai::V1_CHAT_COMPLETIONS;
+use model::Model;
+use openai::create_models;
 use openai_struct::CreateChatCompletionRequest;
 use response::error;
+use response::json;
 use std::collections::HashMap;
 use std::{ffi::c_int, fs::read_to_string, path::Path};
 use std::{
@@ -42,7 +44,7 @@ pub struct ServiceArgs {
     #[clap(long)]
     gpus: Option<String>,
     #[clap(long)]
-    max_steps: Option<usize>,
+    max_tokens: Option<usize>,
     #[clap(long)]
     think: bool,
 }
@@ -51,7 +53,7 @@ pub struct ServiceArgs {
 pub struct ModelConfig {
     pub path: String,
     pub gpus: Option<Box<[c_int]>>,
-    pub max_steps: Option<usize>,
+    pub max_tokens: Option<usize>,
     pub think: Option<bool>,
 }

@@ -63,7 +65,7 @@ impl ServiceArgs {
             no_cuda_graph,
             name,
             gpus,
-            max_steps,
+            max_tokens,
             think,
         } = self;

@@ -77,7 +79,7 @@ impl ServiceArgs {
             ModelConfig {
                 path: file.clone(),
                 gpus: Some(parse_gpus(gpus.as_deref())),
-                max_steps,
+                max_tokens,
                 think: Some(think),
             },
         )]
@@ -139,19 +141,25 @@ impl HyperService<Request<Incoming>> for App {
     type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

     fn call(&self, req: Request<Incoming>) -> Self::Future {
-        let models = self.0.clone();
         match (req.method(), req.uri().path()) {
-            (&Method::POST, V1_CHAT_COMPLETIONS) => Box::pin(async move {
-                let whole_body = req.collect().await?.to_bytes();
-                let req = serde_json::from_slice::<CreateChatCompletionRequest>(&whole_body);
-                Ok(match req {
-                    Ok(req) => match models.get(&req.model) {
-                        Some(model) => model.complete_chat(req),
-                        None => error(Error::ModelNotFound(req.model)),
-                    },
-                    Err(e) => error(Error::WrongJson(e)),
+            openai::GET_MODELS => {
+                let json = json(create_models(self.0.keys().cloned()));
+                Box::pin(async move { Ok(json) })
+            }
+            openai::POST_CHAT_COMPLETIONS => {
+                let models = self.0.clone();
+                Box::pin(async move {
+                    let whole_body = req.collect().await?.to_bytes();
+                    let req = serde_json::from_slice::<CreateChatCompletionRequest>(&whole_body);
+                    Ok(match req {
+                        Ok(req) => match models.get(&req.model) {
+                            Some(model) => model.complete_chat(req),
+                            None => error(Error::ModelNotFound(req.model)),
+                        },
+                        Err(e) => error(Error::WrongJson(e)),
+                    })
                 })
-            }),
+            }
             // Return 404 Not Found for other routes.
             (method, uri) => {
                 let msg = Error::not_found(method, uri);
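The handler now dispatches on (req.method(), req.uri().path()) against the (&Method, &str) tuple constants declared in openai.rs, so each endpoint is named exactly once and shared between the server and the test client. Below is a minimal, self-contained sketch of this matching pattern; the routes GET_HEALTH and POST_ECHO are invented for illustration and are not part of the commit:

use hyper::Method;

// Hypothetical route constants, in the style of openai.rs.
const GET_HEALTH: (&Method, &str) = (&Method::GET, "/health");
const POST_ECHO: (&Method, &str) = (&Method::POST, "/echo");

fn route(method: &Method, path: &str) -> &'static str {
    // A tuple of references is matched against tuple constants,
    // comparing method and path in a single arm.
    match (method, path) {
        GET_HEALTH => "health",
        POST_ECHO => "echo",
        _ => "not found",
    }
}

fn main() {
    assert_eq!(route(&Method::GET, "/health"), "health");
    assert_eq!(route(&Method::POST, "/echo"), "echo");
    assert_eq!(route(&Method::GET, "/missing"), "not found");
}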

xtask/src/service/model.rs

Lines changed: 19 additions & 24 deletions
@@ -28,8 +28,8 @@ use tokio::{
 use tokio_stream::wrappers::UnboundedReceiverStream;

 pub(super) struct Model {
+    max_tokens: usize,
     terminal: Terminal,
-    max_steps: usize,
     sessions: Mutex<BTreeMap<SessionId, SessionInfo>>,
     cache_manager: Mutex<CacheManager>,
 }
@@ -48,7 +48,7 @@ impl Model {
         let ModelConfig {
             path,
             gpus,
-            max_steps,
+            max_tokens,
             think,
         } = config;

@@ -67,26 +67,25 @@ impl Model {
             (utok::MAX, utok::MAX)
         };

-        let service_manager = Arc::new(Model {
+        let model = Arc::new(Model {
+            max_tokens: max_tokens.unwrap_or(2 << 10),
             terminal: service.terminal().clone(),
-            max_steps: max_steps.unwrap_or(2 << 10),
             sessions: Mutex::new(sessions),
             cache_manager: Mutex::new(CacheManager::new(service.terminal().clone())),
         });

-        let service_manager_for_recv = service_manager.clone();
-
+        let model_ = model.clone();
         let join_handle = tokio::task::spawn_blocking(move || {
             loop {
                 let Received { sessions, outputs } = service.recv(Duration::from_millis(10));

+                let mut sessions_guard = model_.sessions.lock().unwrap();
                 // Process outputs first
                 for (session_id, tokens) in outputs {
                     if tokens.is_empty() {
                         continue;
                     }

-                    let mut sessions_guard = service_manager_for_recv.sessions.lock().unwrap();
                     let session_info = sessions_guard.get_mut(&session_id).unwrap();
                     // Update session_info
                     session_info.tokens.extend(&tokens);
@@ -111,12 +110,8 @@ impl Model {
                         &[]
                     };

-                    let think = service_manager_for_recv
-                        .terminal
-                        .decode(think, &mut session_info.buf);
-                    let text = service_manager_for_recv
-                        .terminal
-                        .decode(tokens, &mut session_info.buf);
+                    let think = model_.terminal.decode(think, &mut session_info.buf);
+                    let text = model_.terminal.decode(tokens, &mut session_info.buf);
                     debug!("decoding finished: {tokens:?} -> {think:?} | {text:?}");

                     let response = create_chat_completion_response(
@@ -131,16 +126,12 @@ impl Model {

                     if session_info.sender.send(message).is_err() {
                         info!("{session_id:?} client connection closed");
-                        service_manager_for_recv.terminal.stop(session_id);
+                        model_.terminal.stop(session_id);
                     }
                 }

                 // Handle session termination
                 if !sessions.is_empty() {
-                    let mut sessions_guard = service_manager_for_recv.sessions.lock().unwrap();
-                    let mut cache_manager_guard =
-                        service_manager_for_recv.cache_manager.lock().unwrap();
-
                     for (session, reason) in sessions {
                         let SessionInfo {
                             tokens,
@@ -152,7 +143,11 @@ impl Model {
                         let reason = match reason {
                             // Finished normally; insert the cache back
                             ReturnReason::Finish => {
-                                cache_manager_guard.insert(tokens, session.cache);
+                                model_
+                                    .cache_manager
+                                    .lock()
+                                    .unwrap()
+                                    .insert(tokens, session.cache);
                                 info!("{:?} finished normally", session.id);
                                 FinishReason::Stop
                             }
@@ -177,12 +172,12 @@ impl Model {
             }
         });

-        (service_manager, join_handle)
+        (model, join_handle)
     }

     pub fn complete_chat(
         &self,
-        completions: CreateChatCompletionRequest,
+        req: CreateChatCompletionRequest,
     ) -> Response<BoxBody<Bytes, hyper::Error>> {
         let CreateChatCompletionRequest {
             model,
@@ -191,10 +186,10 @@ impl Model {
             temperature,
             top_p,
             ..
-        } = completions;
+        } = req;
         let (sender, receiver) = mpsc::unbounded_channel();

-        let max_steps = max_tokens.map_or(self.max_steps, |n| n as usize);
+        let max_tokens = max_tokens.map_or(self.max_tokens, |n| n as _);
         let sample_args =
             SampleArgs::new(temperature.unwrap_or(0.), top_p.unwrap_or(1.), usize::MAX).unwrap();

@@ -242,7 +237,7 @@ impl Model {
             .cache_manager
             .lock()
             .unwrap()
-            .send(tokens, sample_args, max_steps);
+            .send(tokens, sample_args, max_tokens);

         let session_info = SessionInfo {
             sender,

xtask/src/service/openai.rs

Lines changed: 27 additions & 3 deletions
@@ -1,11 +1,35 @@
-use llama_cu::SessionId;
+use hyper::Method;
+use llama_cu::SessionId;
 use openai_struct::{
     ChatCompletionStreamResponseDelta, CreateChatCompletionStreamResponse,
-    CreateChatCompletionStreamResponseChoices, FinishReason,
+    CreateChatCompletionStreamResponseChoices, FinishReason, Model,
 };
+use serde::Serialize;

 const CHAT_COMPLETION_OBJECT: &str = "chat.completion.chunk";
-pub(crate) const V1_CHAT_COMPLETIONS: &str = "/v1/chat/completions";
+pub(crate) const GET_MODELS: (&Method, &str) = (&Method::GET, "models");
+pub(crate) const POST_CHAT_COMPLETIONS: (&Method, &str) = (&Method::POST, "/chat/completions");
+
+pub(crate) fn create_models(models: impl IntoIterator<Item = String>) -> impl Serialize {
+    #[derive(Serialize)]
+    struct Response {
+        object: &'static str,
+        data: Vec<Model>,
+    }
+
+    Response {
+        object: "list",
+        data: models
+            .into_iter()
+            .map(|id| Model {
+                id,
+                object: "model".into(),
+                owned_by: "QYLab".into(),
+                created: 0,
+            })
+            .collect(),
+    }
+}

 pub(crate) fn create_chat_completion_response(
     id: SessionId,
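For reference, create_models produces an OpenAI-style model list. A minimal sketch of serializing it, assuming openai_struct::Model serializes its fields under their own names; the model id "tinyllama" is invented:

fn main() {
    // Serialize the list for one hypothetical model id.
    let body = create_models(["tinyllama".to_string()]);
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
    // Expected shape (field order may differ):
    // {
    //   "object": "list",
    //   "data": [
    //     { "id": "tinyllama", "object": "model", "owned_by": "QYLab", "created": 0 }
    //   ]
    // }
}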

xtask/src/service/response.rs

Lines changed: 12 additions & 0 deletions
@@ -10,6 +10,7 @@ use hyper::{
         CACHE_CONTROL, CONNECTION, CONTENT_TYPE,
     },
 };
+use serde::Serialize;
 use tokio_stream::{Stream, StreamExt};

 pub fn text_stream(
@@ -27,6 +28,17 @@ pub fn text_stream(
         .unwrap()
 }

+pub fn json(json: impl Serialize) -> Response<BoxBody<Bytes, hyper::Error>> {
+    Response::builder()
+        .status(StatusCode::OK)
+        .header(CONTENT_TYPE, "application/json")
+        .header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
+        .header(ACCESS_CONTROL_ALLOW_METHODS, "GET,POST")
+        .header(ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type")
+        .body(full(serde_json::to_string(&json).unwrap()))
+        .unwrap()
+}
+
 pub fn error(e: Error) -> Response<BoxBody<Bytes, hyper::Error>> {
     Response::builder()
         .status(e.status())
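The new json helper mirrors text_stream: any Serialize value becomes a 200 response with a JSON body and permissive CORS headers. A usage sketch, with a hypothetical Health type:

use http_body_util::combinators::BoxBody;
use hyper::{Response, body::Bytes};

// Hypothetical payload type, not part of the commit.
#[derive(serde::Serialize)]
struct Health {
    status: &'static str,
}

// A handler body built with the new helper.
fn health() -> Response<BoxBody<Bytes, hyper::Error>> {
    json(Health { status: "ok" })
}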
