Skip to content

Commit 0296692

Browse files
authored
fix(llm, llamacpp): Loosen model ID validation (#163)
The validation logic for model IDs was too strict to support model names often used in the Llama.cpp ecosystem. For example, a model ID like `llamacpp/bartowski_Qwen2.5-7B-Instruct-GGUF_Qwen2.5-7B-Instruct-Q4_K_M.gguf` was previously rejected, but can now be used with JP.

Additionally, the model listing implementation was inadvertently left out of the implementation of the Llamacpp provider. This commit fixes that oversight.

It should be noted that it is not possible (unless the llama.cpp server is run with the `--alias` flag[1]) to get the proper model name from the API (instead, it points to the model file on disk), so in reality this implementation isn't particularly useful, but at least it doesn't panic anymore.

[1]: https://github.com/ggml-org/llama.cpp/blob/8846aace4934ad29651ea61b8c7e3f6b0556e3d2/tools/server/README.md#get-v1models-openai-compatible-model-info-api

Signed-off-by: Jean Mertz <git@jeanmertz.com>
1 parent ff4ea97 commit 0296692

File tree

5 files changed

+115
-63
lines changed

5 files changed

+115
-63
lines changed

crates/jp_llm/src/provider/llamacpp.rs

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use jp_conversation::{
88
AssistantMessage, MessagePair, UserMessage,
99
};
1010
use jp_mcp::tool::{self, ToolChoice};
11-
use jp_model::ModelId;
11+
use jp_model::{ModelId, ProviderId};
1212
use jp_query::query::ChatQuery;
1313
use openai::{
1414
chat::{
@@ -22,25 +22,24 @@ use openai::{
2222
use serde::Serialize;
2323
use tracing::{debug, trace};
2424

25-
use super::{CompletionChunk, Delta, EventStream, ModelDetails, StreamEvent};
25+
use super::{
26+
openai::{ModelListResponse, ModelResponse},
27+
CompletionChunk, Delta, EventStream, ModelDetails, StreamEvent,
28+
};
2629
use crate::{
2730
error::{Error, Result},
2831
provider::{handle_delta, AccumulationState, Provider, ReasoningExtractor},
2932
};
3033

3134
#[derive(Debug, Clone)]
3235
pub struct Llamacpp {
36+
reqwest_client: reqwest::Client,
3337
credentials: Credentials,
38+
base_url: String,
3439
}
3540

3641
impl Llamacpp {
37-
fn new(base_url: String) -> Self {
38-
let credentials = Credentials::new("", base_url);
39-
40-
Self { credentials }
41-
}
42-
43-
/// Build request for Openai API.
42+
/// Build request for Llama.cpp API.
4443
fn build_request(
4544
&self,
4645
model_id: &ModelId,
@@ -63,7 +62,7 @@ impl Llamacpp {
6362
slug,
6463
messages_size = messages.len(),
6564
tools_size = tools.len(),
66-
"Built Openai request."
65+
"Built Llamacpp request."
6766
);
6867

6968
Ok(ChatCompletionDelta::builder(slug, messages)
@@ -76,7 +75,18 @@ impl Llamacpp {
7675
#[async_trait]
7776
impl Provider for Llamacpp {
7877
async fn models(&self) -> Result<Vec<ModelDetails>> {
79-
todo!()
78+
Ok(self
79+
.reqwest_client
80+
.get(format!("{}/v1/models", self.base_url))
81+
.send()
82+
.await?
83+
.error_for_status()?
84+
.json::<ModelListResponse>()
85+
.await?
86+
.data
87+
.iter()
88+
.map(map_model)
89+
.collect())
8090
}
8191

8292
async fn chat_completion_stream(
@@ -182,13 +192,34 @@ fn map_content(
182192
events
183193
}
184194

195+
fn map_model(model: &ModelResponse) -> ModelDetails {
196+
ModelDetails {
197+
provider: ProviderId::Llamacpp,
198+
slug: model
199+
.id
200+
.rsplit_once('/')
201+
.map_or(model.id.as_str(), |(_, v)| v)
202+
.to_string(),
203+
context_window: None,
204+
max_output_tokens: None,
205+
reasoning: None,
206+
knowledge_cutoff: None,
207+
}
208+
}
209+
185210
impl TryFrom<&assistant::provider::llamacpp::Llamacpp> for Llamacpp {
186211
type Error = Error;
187212

188213
fn try_from(config: &assistant::provider::llamacpp::Llamacpp) -> Result<Self> {
214+
let reqwest_client = reqwest::Client::builder().build()?;
189215
let base_url = config.base_url.clone();
216+
let credentials = Credentials::new("", &base_url);
190217

191-
Ok(Llamacpp::new(base_url))
218+
Ok(Llamacpp {
219+
reqwest_client,
220+
credentials,
221+
base_url,
222+
})
192223
}
193224
}
194225

@@ -461,41 +492,29 @@ mod tests {
461492
Vcr::new("http://127.0.0.1:8080", fixtures)
462493
}
463494

464-
// #[test(tokio::test)]
465-
// async fn test_llamacpp_models() -> std::result::Result<(), Box<dyn std::error::Error>> {
466-
// let mut config =
467-
// assistant::Assistant::from_partial(assistant::AssistantPartial::default_values())
468-
// .unwrap()
469-
// .provider
470-
// .openai;
471-
//
472-
// let vcr = vcr();
473-
// vcr.cassette(
474-
// function_name!(),
475-
// |rule| {
476-
// rule.filter(|when| {
477-
// when.any_request();
478-
// });
479-
// },
480-
// |recording, url| async move {
481-
// config.base_url = url;
482-
// if !recording {
483-
// // dummy api key value when replaying a cassette
484-
// config.api_key_env = "USER".to_owned();
485-
// }
486-
//
487-
// Openai::try_from(&config)
488-
// .unwrap()
489-
// .models()
490-
// .await
491-
// .map(|mut v| {
492-
// v.truncate(10);
493-
// v
494-
// })
495-
// },
496-
// )
497-
// .await
498-
// }
495+
#[test(tokio::test)]
496+
async fn test_llamacpp_models() -> std::result::Result<(), Box<dyn std::error::Error>> {
497+
let mut config =
498+
assistant::Assistant::from_partial(assistant::AssistantPartial::default_values())
499+
.unwrap()
500+
.provider
501+
.llamacpp;
502+
503+
let vcr = vcr();
504+
vcr.cassette(
505+
function_name!(),
506+
|rule| {
507+
rule.filter(|when| {
508+
when.any_request();
509+
});
510+
},
511+
|_, url| async move {
512+
config.base_url = url;
513+
Llamacpp::try_from(&config).unwrap().models().await
514+
},
515+
)
516+
.await
517+
}
499518

500519
#[test(tokio::test)]
501520
async fn test_llamacpp_chat_completion() -> std::result::Result<(), Box<dyn std::error::Error>>

crates/jp_llm/src/provider/openai.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,15 @@ impl Provider for Openai {
146146

147147
#[derive(Debug, Deserialize)]
148148
#[expect(dead_code)]
149-
struct ModelListResponse {
149+
pub(crate) struct ModelListResponse {
150150
object: String,
151-
data: Vec<ModelResponse>,
151+
pub data: Vec<ModelResponse>,
152152
}
153153

154154
#[derive(Debug, Deserialize)]
155155
#[expect(dead_code)]
156-
struct ModelResponse {
157-
id: String,
156+
pub(crate) struct ModelResponse {
157+
pub id: String,
158158
object: String,
159159
#[serde(with = "time::serde::timestamp")]
160160
created: OffsetDateTime,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
---
2+
source: crates/jp_test/src/mock.rs
3+
expression: expr
4+
---
5+
Ok(
6+
[
7+
ModelDetails {
8+
provider: Llamacpp,
9+
slug: "bartowski_Qwen2.5-7B-Instruct-GGUF_Qwen2.5-7B-Instruct-Q4_K_M.gguf",
10+
context_window: None,
11+
max_output_tokens: None,
12+
reasoning: None,
13+
knowledge_cutoff: None,
14+
},
15+
],
16+
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
when:
2+
path: /v1/models
3+
method: GET
4+
then:
5+
status: 200
6+
header:
7+
- name: content-type
8+
value: application/json; charset=utf-8
9+
body: "{\"models\":[{\"name\":\"/llama.cpp/bartowski_Qwen2.5-7B-Instruct-GGUF_Qwen2.5-7B-Instruct-Q4_K_M.gguf\",\"model\":\"/llama.cpp/bartowski_Qwen2.5-7B-Instruct-GGUF_Qwen2.5-7B-Instruct-Q4_K_M.gguf\",\"modified_at\":\"\",\"size\":\"\",\"digest\":\"\",\"type\":\"model\",\"description\":\"\",\"tags\":[\"\"],\"capabilities\":[\"completion\"],\"parameters\":\"\",\"details\":{\"parent_model\":\"\",\"format\":\"gguf\",\"family\":\"\",\"families\":[\"\"],\"parameter_size\":\"\",\"quantization_level\":\"\"}}],\"object\":\"list\",\"data\":[{\"id\":\"/llama.cpp/bartowski_Qwen2.5-7B-Instruct-GGUF_Qwen2.5-7B-Instruct-Q4_K_M.gguf\",\"object\":\"model\",\"created\":1751010042,\"owned_by\":\"llamacpp\",\"meta\":{\"vocab_type\":2,\"n_vocab\":152064,\"n_ctx_train\":32768,\"n_embd\":3584,\"n_params\":7615616512,\"size\":4677120000}}]}"

crates/jp_model/src/lib.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,26 @@ impl FromStr for ModelId {
115115
type Err = Error;
116116

117117
fn from_str(s: &str) -> Result<Self> {
118-
let (provider, name) = s.split_once('/').unwrap_or(("", s));
119-
120-
if name.chars().any(|c| {
121-
!(c.is_numeric()
122-
|| (c.is_ascii_alphabetic() && c.is_ascii_lowercase())
123-
|| c == '-'
124-
|| c == '_'
125-
|| c == '.'
126-
|| c == ':'
127-
|| c == '/')
128-
}) {
118+
let (provider, name) =
119+
s.split_once('/')
120+
.map(|(p, n)| (p.trim(), n.trim()))
121+
.ok_or(Error::InvalidIdFormat(
122+
"ID must match <provider>/<model>".to_owned(),
123+
))?;
124+
125+
if name.is_empty()
126+
|| name.chars().any(|c| {
127+
!(c.is_numeric()
128+
|| c.is_ascii_alphabetic()
129+
|| c == '-'
130+
|| c == '_'
131+
|| c == '.'
132+
|| c == ':'
133+
|| c == '/')
134+
})
135+
{
129136
return Err(Error::InvalidIdFormat(
130-
"Model ID must be [a-z0-9_-.:/]".to_string(),
137+
"Model ID must be [a-zA-Z0-9_-.:/]+".to_string(),
131138
));
132139
}
133140

@@ -197,6 +204,7 @@ impl FromStr for ProviderId {
197204
"openai" => Ok(Self::Openai),
198205
"openrouter" => Ok(Self::Openrouter),
199206
"ollama" => Ok(Self::Ollama),
207+
_ if s.is_empty() => Err(Error::InvalidProviderId("<empty>".to_owned())),
200208
_ => Err(Error::InvalidProviderId(s.to_owned())),
201209
}
202210
}

0 commit comments

Comments (0)