Skip to content

Commit 357da40

Browse files
committed
Merge commit '9f5b3ca0e09d49338e9384672b841e856a5d6e84' into llama.cu
Signed-off-by: YdrMaster <ydrml@hotmail.com>
2 parents 4d6ea06 + 9f5b3ca commit 357da40

File tree

11 files changed

+486
-301
lines changed

11 files changed

+486
-301
lines changed

Cargo.lock

Lines changed: 34 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,26 @@ cargo service --help
8989
```plaintext
9090
web service
9191
92-
Usage: xtask service [OPTIONS] --port <PORT> <MODEL>
92+
Usage: xtask service [OPTIONS] --port <PORT> <FILE>
9393
9494
Arguments:
95-
<MODEL>
95+
<FILE>
9696
9797
Options:
98+
-p, --port <PORT>
99+
--no-cuda-graph
100+
--name <NAME>
98101
--gpus <GPUS>
99102
--max-steps <MAX_STEPS>
100-
-p, --port <PORT>
103+
--think
101104
-h, --help
102105
```
106+
107+
通过 TOML 配置文件可以配置多模型服务。示例格式:
108+
109+
```toml
110+
[model-name]
111+
path = "model-path"
112+
think = true
113+
max-steps = 2048
114+
```

xtask/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ ratatui = "0.29"
1717

1818
serde.workspace = true
1919
serde_json = "1.0"
20+
toml = "0.8"
2021
tokio = { version = "1.45", features = ["rt-multi-thread", "net"] }
2122
hyper = { version = "1.6", features = ["http1", "server"] }
2223
hyper-util = { version = "0.1", features = ["http1", "tokio", "server"] }

xtask/src/main.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,27 @@ struct BaseArgs {
5555

5656
impl BaseArgs {
5757
fn gpus(&self) -> Box<[c_int]> {
58-
self.gpus
59-
.as_ref()
60-
.map(|devices| {
61-
static NUM_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\d+").unwrap());
62-
NUM_REGEX
63-
.find_iter(devices)
64-
.map(|c| c.as_str().parse().unwrap())
65-
.collect()
66-
})
67-
.unwrap_or_else(|| [0].into())
58+
parse_gpus(self.gpus.as_deref())
6859
}
6960

7061
fn max_steps(&self) -> usize {
7162
self.max_steps.unwrap_or(1000)
7263
}
7364
}
7465

66+
fn parse_gpus(config: Option<&str>) -> Box<[c_int]> {
67+
config
68+
.as_ref()
69+
.map(|devices| {
70+
static NUM_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\d+").unwrap());
71+
NUM_REGEX
72+
.find_iter(devices)
73+
.map(|c| c.as_str().parse().unwrap())
74+
.collect()
75+
})
76+
.unwrap_or_else(|| [0].into())
77+
}
78+
7579
mod macros {
7680
macro_rules! print_now {
7781
($($arg:tt)*) => {{

xtask/src/service/cache_manager.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl CacheManager {
2323
&mut self,
2424
tokens: Vec<utok>,
2525
sample_args: SampleArgs,
26-
max_steps: usize,
26+
max_tokens: usize,
2727
) -> (SessionId, Vec<utok>) {
2828
static SESSION_ID: AtomicUsize = AtomicUsize::new(0);
2929
let id = SessionId(SESSION_ID.fetch_add(1, SeqCst));
@@ -51,7 +51,7 @@ impl CacheManager {
5151
cache,
5252
},
5353
&tokens[pos..],
54-
max_steps,
54+
max_tokens,
5555
);
5656
(id, tokens)
5757
}

xtask/src/service/client.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
use super::*;
1+
use super::openai::POST_CHAT_COMPLETIONS;
22
use log::{info, trace, warn};
3-
use openai_struct::CreateChatCompletionStreamResponse;
3+
use openai_struct::{
4+
ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionStreamResponse,
5+
};
46
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
57
use std::{env::VarError, time::Instant};
68
use tokio::time::Duration;
@@ -72,7 +74,10 @@ async fn send_single_request(
7274
}
7375

7476
let req = client
75-
.post(format!("http://localhost:{port}{V1_CHAT_COMPLETIONS}"))
77+
.post(format!(
78+
"http://localhost:{port}{}",
79+
POST_CHAT_COMPLETIONS.1
80+
))
7681
.headers(headers.clone())
7782
.body(req_body)
7883
.timeout(Duration::from_secs(100));

xtask/src/service/error.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
use hyper::{Method, StatusCode};
1+
use hyper::{Method, StatusCode};
22
use serde::Serialize;
3+
use std::fmt;
34

45
#[derive(Debug)]
56
pub(crate) enum Error {
67
WrongJson(serde_json::Error),
78
NotFound(NotFoundError),
89
MsgNotSupported(MsgNotSupportedError),
10+
ModelNotFound(String),
911
}
1012

1113
#[derive(Serialize, Debug)]
@@ -39,6 +41,7 @@ impl Error {
3941
Self::WrongJson(..) => StatusCode::BAD_REQUEST,
4042
Self::NotFound(..) => StatusCode::NOT_FOUND,
4143
Self::MsgNotSupported(..) => StatusCode::BAD_REQUEST,
44+
Self::ModelNotFound(..) => StatusCode::NOT_FOUND,
4245
}
4346
}
4447

@@ -48,6 +51,20 @@ impl Error {
4851
Self::WrongJson(e) => e.to_string(),
4952
Self::NotFound(e) => serde_json::to_string(&e).unwrap(),
5053
Self::MsgNotSupported(e) => serde_json::to_string(&e).unwrap(),
54+
Self::ModelNotFound(model) => format!("Model not found: {}", model),
5155
}
5256
}
5357
}
58+
59+
impl fmt::Display for Error {
60+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61+
match self {
62+
Error::WrongJson(e) => write!(f, "Invalid JSON: {}", e),
63+
Error::NotFound(e) => write!(f, "Not Found: {} {}", e.method, e.uri),
64+
Error::MsgNotSupported(e) => write!(f, "Message type not supported: {:?}", e.message),
65+
Error::ModelNotFound(model) => write!(f, "Model not found: {}", model),
66+
}
67+
}
68+
}
69+
70+
impl std::error::Error for Error {}

0 commit comments

Comments
 (0)