
Commit b809b38

Use existing impl
1 parent c121b45 commit b809b38


1 file changed (+10, -94 lines)


src/llm/litellm.rs

Lines changed: 10 additions & 94 deletions
@@ -1,103 +1,19 @@
-use super::LlmGenerationClient;
-use anyhow::Result;
-use async_openai::{
-    config::OpenAIConfig,
-    types::{
-        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
-        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
-        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
-        ResponseFormatJsonSchema,
-    },
-    Client as OpenAIClient,
-};
-use async_trait::async_trait;
+use async_openai::config::OpenAIConfig;
+use async_openai::Client as OpenAIClient;
 
-pub struct Client {
-    client: async_openai::Client<OpenAIConfig>,
-    model: String,
-}
+pub use super::openai::Client;
 
 impl Client {
-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        // For LiteLLM, use provided address or default to http://0.0.0.0:4000
-        let address = spec.address.clone().unwrap_or_else(|| "http://0.0.0.0:4000".to_string());
-        let api_key = std::env::var("LITELLM_API_KEY").unwrap_or_else(|_| "anything".to_string());
-        let config = OpenAIConfig::new().with_api_base(address).with_api_key(api_key);
+    pub async fn new(spec: super::LlmSpec) -> anyhow::Result<Self> {
+        let address = spec.address.clone().unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
+        let api_key = std::env::var("LITELLM_API_KEY").ok();
+        let mut config = OpenAIConfig::new().with_api_base(address);
+        if let Some(api_key) = api_key {
+            config = config.with_api_key(api_key);
+        }
         Ok(Self {
             client: OpenAIClient::with_config(config),
             model: spec.model,
         })
     }
 }
-
-#[async_trait]
-impl LlmGenerationClient for Client {
-    async fn generate<'req>(
-        &self,
-        request: super::LlmGenerateRequest<'req>,
-    ) -> Result<super::LlmGenerateResponse> {
-        let mut messages = Vec::new();
-
-        // Add system prompt if provided
-        if let Some(system) = request.system_prompt {
-            messages.push(ChatCompletionRequestMessage::System(
-                ChatCompletionRequestSystemMessage {
-                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
-                    ..Default::default()
-                },
-            ));
-        }
-
-        // Add user message
-        messages.push(ChatCompletionRequestMessage::User(
-            ChatCompletionRequestUserMessage {
-                content: ChatCompletionRequestUserMessageContent::Text(
-                    request.user_prompt.into_owned(),
-                ),
-                ..Default::default()
-            },
-        ));
-
-        // Create the chat completion request
-        let request = CreateChatCompletionRequest {
-            model: self.model.clone(),
-            messages,
-            response_format: match request.output_format {
-                Some(super::OutputFormat::JsonSchema { name, schema }) => {
-                    Some(ResponseFormat::JsonSchema {
-                        json_schema: ResponseFormatJsonSchema {
-                            name: name.into_owned(),
-                            description: None,
-                            schema: Some(serde_json::to_value(&schema)?),
-                            strict: Some(true),
-                        },
-                    })
-                }
-                None => None,
-            },
-            ..Default::default()
-        };
-
-        // Send request and get response
-        let response = self.client.chat().create(request).await?;
-
-        // Extract the response text from the first choice
-        let text = response
-            .choices
-            .into_iter()
-            .next()
-            .and_then(|choice| choice.message.content)
-            .ok_or_else(|| anyhow::anyhow!("No response from LiteLLM proxy"))?;
-
-        Ok(super::LlmGenerateResponse { text })
-    }
-
-    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
-        super::ToJsonSchemaOptions {
-            fields_always_required: true,
-            supports_format: false,
-            extract_descriptions: false,
-            top_level_must_be_object: true,
-        }
-    }
-}
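
For context, the configuration behavior added on the "+" side can be exercised on its own. Below is a minimal sketch, assuming async-openai's builder API (OpenAIConfig::new, with_api_base, with_api_key, Client::with_config); the helper name build_litellm_client and its signature are illustrative and not part of this repository:

use async_openai::{config::OpenAIConfig, Client};

// Hypothetical helper mirroring the Client::new logic in this diff.
fn build_litellm_client(address: Option<String>) -> Client<OpenAIConfig> {
    // Fall back to a local LiteLLM proxy when no address is supplied.
    let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
    let mut config = OpenAIConfig::new().with_api_base(address);
    // Attach an API key only if one is configured for the proxy.
    if let Ok(api_key) = std::env::var("LITELLM_API_KEY") {
        config = config.with_api_key(api_key);
    }
    Client::with_config(config)
}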
