
Commit c8cb893

refactor(llm): rename to LlmClient and pass model in request (#642)
1 parent 5370f75 commit c8cb893

8 files changed: +57 −72 lines changed

src/llm/anthropic.rs

Lines changed: 7 additions & 7 deletions
@@ -1,6 +1,5 @@
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
+    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
@@ -11,27 +10,28 @@ use crate::api_bail;
 use urlencoding::encode;
 
 pub struct Client {
-    model: String,
     api_key: String,
     client: reqwest::Client,
 }
 
 impl Client {
-    pub async fn new(spec: LlmSpec) -> Result<Self> {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if address.is_some() {
+            api_bail!("Anthropic doesn't support custom API address");
+        }
         let api_key = match std::env::var("ANTHROPIC_API_KEY") {
             Ok(val) => val,
             Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),
         };
         Ok(Self {
-            model: spec.model,
             api_key,
             client: reqwest::Client::new(),
         })
     }
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -42,7 +42,7 @@ impl LlmGenerationClient for Client {
         })];
 
         let mut payload = serde_json::json!({
-            "model": self.model,
+            "model": request.model,
             "messages": messages,
             "max_tokens": 4096
         });
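The payload change above is the crux of the refactor for Anthropic: the client no longer remembers a model, so every call must name one. Below is a minimal, runnable sketch of the new payload construction; `Request` is a hypothetical stand-in for the crate's `LlmGenerateRequest`, and the model name is illustrative only.

```rust
use serde_json::json;

// Hypothetical stand-in for the crate's LlmGenerateRequest; only the fields
// needed for the payload are shown.
struct Request<'a> {
    model: &'a str,
    user_prompt: &'a str,
}

// Build the Anthropic messages payload from a per-call request, mirroring
// the `"model": request.model` change in the diff.
fn build_payload(request: &Request<'_>) -> serde_json::Value {
    json!({
        "model": request.model,
        "messages": [{ "role": "user", "content": request.user_prompt }],
        "max_tokens": 4096
    })
}

fn main() {
    // Illustrative model name only.
    let req = Request { model: "claude-3-5-sonnet-latest", user_prompt: "Say hello." };
    println!("{}", build_payload(&req));
}
```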

src/llm/gemini.rs

Lines changed: 7 additions & 7 deletions
@@ -1,27 +1,27 @@
 use crate::api_bail;
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
+    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
 use serde_json::Value;
 use urlencoding::encode;
 
 pub struct Client {
-    model: String,
     api_key: String,
     client: reqwest::Client,
 }
 
 impl Client {
-    pub async fn new(spec: LlmSpec) -> Result<Self> {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if address.is_some() {
+            api_bail!("Gemini doesn't support custom API address");
+        }
         let api_key = match std::env::var("GEMINI_API_KEY") {
             Ok(val) => val,
             Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
         };
         Ok(Self {
-            model: spec.model,
             api_key,
             client: reqwest::Client::new(),
         })
@@ -47,7 +47,7 @@ fn remove_additional_properties(value: &mut Value) {
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -79,7 +79,7 @@ impl LlmGenerationClient for Client {
         let api_key = &self.api_key;
         let url = format!(
             "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
-            encode(&self.model),
+            encode(request.model),
             encode(api_key)
         );
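For Gemini the same idea shows up in the endpoint URL: the model path segment is percent-encoded from the per-request value rather than a stored field. A small sketch of that formatting step, assuming the `urlencoding` crate as in the diff; the model name and key are illustrative values only.

```rust
use urlencoding::encode;

// Format the Gemini generateContent URL from a per-request model name,
// percent-encoding both the model and the key as the diff does.
fn generate_url(model: &str, api_key: &str) -> String {
    format!(
        "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
        encode(model),
        encode(api_key)
    )
}

fn main() {
    // Illustrative values only.
    println!("{}", generate_url("gemini-1.5-flash", "MY_API_KEY"));
}
```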

src/llm/litellm.rs

Lines changed: 3 additions & 9 deletions
@@ -4,19 +4,13 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_litellm(spec: super::LlmSpec) -> anyhow::Result<Self> {
-        let address = spec
-            .address
-            .clone()
-            .unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
+    pub async fn new_litellm(address: Option<String>) -> anyhow::Result<Self> {
+        let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
         let api_key = std::env::var("LITELLM_API_KEY").ok();
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
         }
-        Ok(Client::from_parts(
-            OpenAIClient::with_config(config),
-            spec.model,
-        ))
+        Ok(Client::from_parts(OpenAIClient::with_config(config)))
     }
 }
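The LiteLLM constructor now only resolves an optional proxy address and an optional API key; no model is threaded through. A standard-library-only sketch of that resolution step, assuming the same default address and environment variable as the diff:

```rust
/// Resolve the LiteLLM proxy address and optional API key the way the new
/// constructor does: default to the local proxy when no address is given,
/// and read LITELLM_API_KEY from the environment if present.
fn resolve_litellm_config(address: Option<String>) -> (String, Option<String>) {
    let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
    let api_key = std::env::var("LITELLM_API_KEY").ok();
    (address, api_key)
}

fn main() {
    let (address, api_key) = resolve_litellm_config(None);
    println!("address = {address}, api key set = {}", api_key.is_some());
}
```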

src/llm/mod.rs

Lines changed: 18 additions & 19 deletions
@@ -19,9 +19,9 @@ pub enum LlmApiType {
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct LlmSpec {
-    api_type: LlmApiType,
-    address: Option<String>,
-    model: String,
+    pub api_type: LlmApiType,
+    pub address: Option<String>,
+    pub model: String,
 }
 
 #[derive(Debug)]
@@ -34,6 +34,7 @@ pub enum OutputFormat<'a> {
 
 #[derive(Debug)]
 pub struct LlmGenerateRequest<'a> {
+    pub model: &'a str,
     pub system_prompt: Option<Cow<'a, str>>,
     pub user_prompt: Cow<'a, str>,
     pub output_format: Option<OutputFormat<'a>>,
@@ -45,7 +46,7 @@ pub struct LlmGenerateResponse {
 }
 
 #[async_trait]
-pub trait LlmGenerationClient: Send + Sync {
+pub trait LlmClient: Send + Sync {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -61,25 +62,23 @@ mod ollama;
 mod openai;
 mod openrouter;
 
-pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
-    let client = match spec.api_type {
-        LlmApiType::Ollama => {
-            Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::OpenAi => {
-            Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::Gemini => {
-            Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
+pub async fn new_llm_generation_client(
+    api_type: LlmApiType,
+    address: Option<String>,
+) -> Result<Box<dyn LlmClient>> {
+    let client = match api_type {
+        LlmApiType::Ollama => Box::new(ollama::Client::new(address).await?) as Box<dyn LlmClient>,
+        LlmApiType::OpenAi => Box::new(openai::Client::new(address).await?) as Box<dyn LlmClient>,
+        LlmApiType::Gemini => Box::new(gemini::Client::new(address).await?) as Box<dyn LlmClient>,
         LlmApiType::Anthropic => {
-            Box::new(anthropic::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmClient>
         }
         LlmApiType::LiteLlm => {
-            Box::new(litellm::Client::new_litellm(spec).await?) as Box<dyn LlmGenerationClient>
+            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmClient>
+        }
+        LlmApiType::OpenRouter => {
+            Box::new(openrouter::Client::new_openrouter(address).await?) as Box<dyn LlmClient>
         }
-        LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(spec).await?)
-            as Box<dyn LlmGenerationClient>,
     };
     Ok(client)
 }
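The net effect of the mod.rs changes: a client is keyed only by API type and address, while the model travels with each request. The sketch below shows that split in a self-contained form; every type here is a simplified, hypothetical stand-in (the real trait is async and returns a structured response), and the model name is illustrative.

```rust
// Simplified stand-ins for the crate's types.
#[allow(dead_code)]
enum LlmApiType {
    OpenAi,
    Ollama,
}

struct LlmSpec {
    api_type: LlmApiType,
    address: Option<String>,
    model: String,
}

struct LlmGenerateRequest<'a> {
    model: &'a str,
    user_prompt: &'a str,
}

trait LlmClient {
    fn generate(&self, request: LlmGenerateRequest<'_>) -> String;
}

struct EchoClient;

impl LlmClient for EchoClient {
    fn generate(&self, request: LlmGenerateRequest<'_>) -> String {
        // A real client would call the provider; this only shows that the
        // model arrives with each request instead of living on the client.
        format!("[{}] {}", request.model, request.user_prompt)
    }
}

// The factory is keyed by API type and address only.
fn new_llm_generation_client(_api_type: LlmApiType, _address: Option<String>) -> Box<dyn LlmClient> {
    Box::new(EchoClient)
}

fn main() {
    let spec = LlmSpec {
        api_type: LlmApiType::OpenAi,
        address: None,
        model: "gpt-4o-mini".to_string(), // illustrative
    };
    // Connection details go into the client; the model stays with the caller
    // and is lent to every request.
    let client = new_llm_generation_client(spec.api_type, spec.address);
    let reply = client.generate(LlmGenerateRequest {
        model: &spec.model,
        user_prompt: "hello",
    });
    println!("{reply}");
}
```

One client per (API type, address) pair can now serve many models, which is why the per-provider structs drop their `model` field throughout this commit.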

src/llm/ollama.rs

Lines changed: 5 additions & 7 deletions
@@ -1,12 +1,11 @@
-use super::LlmGenerationClient;
+use super::LlmClient;
 use anyhow::Result;
 use async_trait::async_trait;
 use schemars::schema::SchemaObject;
 use serde::{Deserialize, Serialize};
 
 pub struct Client {
     generate_url: String,
-    model: String,
     reqwest_client: reqwest::Client,
 }
 
@@ -33,27 +32,26 @@ struct OllamaResponse {
 const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
 
 impl Client {
-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        let address = match &spec.address {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        let address = match &address {
             Some(addr) => addr.trim_end_matches('/'),
             None => OLLAMA_DEFAULT_ADDRESS,
         };
         Ok(Self {
             generate_url: format!("{}/api/generate", address),
-            model: spec.model,
             reqwest_client: reqwest::Client::new(),
         })
     }
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
     ) -> Result<super::LlmGenerateResponse> {
         let req = OllamaRequest {
-            model: &self.model,
+            model: request.model,
             prompt: request.user_prompt.as_ref(),
             format: request.output_format.as_ref().map(
                 |super::OutputFormat::JsonSchema { schema, .. }| {
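Ollama follows the same pattern: the client keeps only the generate URL, and the request body names the model. A sketch of the wire shape with a stand-in request struct; `model`, `prompt`, and `stream` are documented `/api/generate` fields, but this struct is a simplified assumption, not the crate's `OllamaRequest` (which also carries a JSON-schema `format`).

```rust
use serde::Serialize;

// Minimal body for Ollama's /api/generate endpoint (simplified stand-in).
#[derive(Serialize)]
struct OllamaRequest<'a> {
    model: &'a str,
    prompt: &'a str,
    stream: bool,
}

fn main() -> Result<(), serde_json::Error> {
    // Because the model is part of the request body, a single client (one
    // generate_url) can serve any number of models.
    let req = OllamaRequest {
        model: "llama3.2", // illustrative model name
        prompt: "Summarize this paragraph.",
        stream: false,
    };
    println!("{}", serde_json::to_string_pretty(&req)?);
    Ok(())
}
```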

src/llm/openai.rs

Lines changed: 7 additions & 9 deletions
@@ -1,6 +1,6 @@
 use crate::api_bail;
 
-use super::LlmGenerationClient;
+use super::LlmClient;
 use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
@@ -16,16 +16,15 @@ use async_trait::async_trait;
 
 pub struct Client {
     client: async_openai::Client<OpenAIConfig>,
-    model: String,
 }
 
 impl Client {
-    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>, model: String) -> Self {
-        Self { client, model }
+    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>) -> Self {
+        Self { client }
     }
 
-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        if let Some(address) = spec.address {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if let Some(address) = address {
             api_bail!("OpenAI doesn't support custom API address: {address}");
         }
         // Verify API key is set
@@ -35,13 +34,12 @@ impl Client {
         Ok(Self {
             // OpenAI client will use OPENAI_API_KEY env variable by default
             client: OpenAIClient::new(),
-            model: spec.model,
         })
     }
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
@@ -70,7 +68,7 @@ impl LlmGenerationClient for Client {
 
         // Create the chat completion request
         let request = CreateChatCompletionRequest {
-            model: self.model.clone(),
+            model: request.model.to_string(),
             messages,
             response_format: match request.output_format {
                 Some(super::OutputFormat::JsonSchema { name, schema }) => {
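The one subtlety in the OpenAI path is ownership: the generate request borrows the model as a `&str`, while the outgoing chat-completion request owns its `model: String`, hence the `.to_string()` above. A std-only sketch of that borrow-to-owned hand-off; both structs are hypothetical stand-ins, not the real async_openai types.

```rust
// Stand-ins: the generate request borrows its model, while the outgoing
// chat request (like async_openai's CreateChatCompletionRequest) owns it.
struct IncomingRequest<'a> {
    model: &'a str,
    user_prompt: &'a str,
}

struct OutgoingChatRequest {
    model: String,
    messages: Vec<String>,
}

fn to_chat_request(request: &IncomingRequest<'_>) -> OutgoingChatRequest {
    OutgoingChatRequest {
        // Mirrors `model: request.model.to_string()` in the diff: the
        // borrowed &str is copied into an owned String.
        model: request.model.to_string(),
        messages: vec![request.user_prompt.to_string()],
    }
}

fn main() {
    let model = String::from("gpt-4o-mini"); // illustrative
    let incoming = IncomingRequest { model: &model, user_prompt: "hello" };
    let outgoing = to_chat_request(&incoming);
    println!("{} -> {:?}", outgoing.model, outgoing.messages);
}
```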

src/llm/openrouter.rs

Lines changed: 3 additions & 9 deletions
@@ -4,19 +4,13 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_openrouter(spec: super::LlmSpec) -> anyhow::Result<Self> {
-        let address = spec
-            .address
-            .clone()
-            .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
+    pub async fn new_openrouter(address: Option<String>) -> anyhow::Result<Self> {
+        let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
         let api_key = std::env::var("OPENROUTER_API_KEY").ok();
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
         }
-        Ok(Client::from_parts(
-            OpenAIClient::with_config(config),
-            spec.model,
-        ))
+        Ok(Client::from_parts(OpenAIClient::with_config(config)))
     }
 }
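OpenRouter mirrors the LiteLLM constructor: both reuse the OpenAI-compatible client with a different base URL and an optional key from the environment. A sketch of just the configuration step, using the same `OpenAIConfig` calls the diff relies on; constructing the client itself is omitted.

```rust
use async_openai::config::OpenAIConfig;

// Build an OpenAI-compatible config pointed at OpenRouter, mirroring the
// body of new_openrouter.
fn openrouter_config(address: Option<String>) -> OpenAIConfig {
    let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
    let mut config = OpenAIConfig::new().with_api_base(address);
    if let Ok(api_key) = std::env::var("OPENROUTER_API_KEY") {
        config = config.with_api_key(api_key);
    }
    config
}

fn main() {
    // With no custom address this targets the public OpenRouter endpoint.
    let _config = openrouter_config(None);
}
```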

src/ops/functions/extract_by_llm.rs

Lines changed: 7 additions & 5 deletions
@@ -1,8 +1,6 @@
 use crate::prelude::*;
 
-use crate::llm::{
-    LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat, new_llm_generation_client,
-};
+use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat, new_llm_generation_client};
 use crate::ops::sdk::*;
 use base::json_schema::build_json_schema;
 use schemars::schema::SchemaObject;
@@ -21,7 +19,8 @@ pub struct Args {
 
 struct Executor {
     args: Args,
-    client: Box<dyn LlmGenerationClient>,
+    client: Box<dyn LlmClient>,
+    model: String,
     output_json_schema: SchemaObject,
     system_prompt: String,
     value_extractor: base::json_schema::ValueExtractor,
@@ -50,11 +49,13 @@ Output only the JSON without any additional messages or explanations."
 
 impl Executor {
     async fn new(spec: Spec, args: Args) -> Result<Self> {
-        let client = new_llm_generation_client(spec.llm_spec).await?;
+        let client =
+            new_llm_generation_client(spec.llm_spec.api_type, spec.llm_spec.address).await?;
         let schema_output = build_json_schema(spec.output_type, client.json_schema_options())?;
         Ok(Self {
             args,
             client,
+            model: spec.llm_spec.model,
             output_json_schema: schema_output.schema,
             system_prompt: get_system_prompt(&spec.instruction, schema_output.extra_instructions),
             value_extractor: schema_output.value_extractor,
@@ -75,6 +76,7 @@ impl SimpleFunctionExecutor for Executor {
     async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
         let text = self.args.text.value(&input)?.as_str()?;
         let req = LlmGenerateRequest {
+            model: &self.model,
            system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
            user_prompt: Cow::Borrowed(text),
            output_format: Some(OutputFormat::JsonSchema {
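extract_by_llm is the one call site that actually holds a model: it keeps `spec.llm_spec.model` on the executor and lends it to every request. A condensed, self-contained sketch of that ownership pattern; the types are stand-ins reduced to the model-handling path, and the real executor also carries the client, output schema, and value extractor.

```rust
// Stand-ins reduced to the model-handling path.
struct LlmGenerateRequest<'a> {
    model: &'a str,
    system_prompt: &'a str,
    user_prompt: &'a str,
}

struct Executor {
    // The executor owns the model string for its whole lifetime...
    model: String,
    system_prompt: String,
}

impl Executor {
    fn build_request<'a>(&'a self, text: &'a str) -> LlmGenerateRequest<'a> {
        // ...and lends it to each request, so the LlmClient behind it never
        // needs to know which model it is serving.
        LlmGenerateRequest {
            model: &self.model,
            system_prompt: &self.system_prompt,
            user_prompt: text,
        }
    }
}

fn main() {
    let executor = Executor {
        model: "gpt-4o-mini".to_string(), // illustrative
        system_prompt: "Extract the requested fields as JSON.".to_string(),
    };
    let req = executor.build_request("Some input document text");
    println!("model={} system={} prompt={}", req.model, req.system_prompt, req.user_prompt);
}
```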
