-use super::LlmGenerationClient;
-use anyhow::Result;
-use async_openai::{
-    config::OpenAIConfig,
-    types::{
-        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
-        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
-        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
-        ResponseFormatJsonSchema,
-    },
-    Client as OpenAIClient,
-};
-use async_trait::async_trait;
+use async_openai::config::OpenAIConfig;
+use async_openai::Client as OpenAIClient;
 
-pub struct Client {
-    client: async_openai::Client<OpenAIConfig>,
-    model: String,
-}
+pub use super::openai::Client;
 
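Review note on the hunk above: `pub use super::openai::Client;` makes the LiteLLM client literally the same type as the OpenAI one, and the `impl Client` block that follows attaches a LiteLLM-specific constructor from a different module. Rust permits this because inherent impls may live in any module of the crate that defines the type. A self-contained sketch of the pattern (every name in it is hypothetical, not from this crate):

    // The litellm module re-exports openai's Client and adds its own constructor.
    mod openai {
        pub struct Client {
            pub(crate) api_base: String,
        }
    }

    mod litellm {
        pub use super::openai::Client;

        impl Client {
            // Inherent impls can live anywhere in the defining crate.
            pub fn for_proxy(address: &str) -> Self {
                Client { api_base: address.to_string() }
            }
        }
    }

    fn main() {
        let client = litellm::Client::for_proxy("http://127.0.0.1:4000");
        assert_eq!(client.api_base, "http://127.0.0.1:4000");
    }

One caveat: this compiles only if the re-exported type does not already define an associated function with the same name, and only if the fields the constructor builds are visible to the other module (pub(crate) in the sketch).
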
 impl Client {
-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        // For LiteLLM, use provided address or default to http://0.0.0.0:4000
-        let address = spec.address.clone().unwrap_or_else(|| "http://0.0.0.0:4000".to_string());
-        let api_key = std::env::var("LITELLM_API_KEY").unwrap_or_else(|_| "anything".to_string());
-        let config = OpenAIConfig::new().with_api_base(address).with_api_key(api_key);
+    pub async fn new(spec: super::LlmSpec) -> anyhow::Result<Self> {
+        let address = spec.address.clone().unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
+        let api_key = std::env::var("LITELLM_API_KEY").ok();
+        let mut config = OpenAIConfig::new().with_api_base(address);
+        if let Some(api_key) = api_key {
+            config = config.with_api_key(api_key);
+        }
         Ok(Self {
             client: OpenAIClient::with_config(config),
             model: spec.model,
         })
     }
 }
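
The constructor above boils down to the pattern in this standalone sketch against async-openai (the free function and its signature are mine, not the crate's): point the client at the proxy's base URL, and attach an API key only when LITELLM_API_KEY is actually set, rather than sending the old placeholder "anything".

    use async_openai::{config::OpenAIConfig, Client};

    // Hypothetical helper showing the optional-key construction pattern.
    fn build_proxy_client(address: Option<String>) -> Client<OpenAIConfig> {
        // Grounded in the diff: the default proxy address is http://127.0.0.1:4000.
        let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
        let mut config = OpenAIConfig::new().with_api_base(address);
        if let Ok(key) = std::env::var("LITELLM_API_KEY") {
            // async-openai sends the key as the Authorization bearer token.
            config = config.with_api_key(key);
        }
        Client::with_config(config)
    }

This suits LiteLLM deployments that run locally without authentication: the client no longer fabricates a dummy key, and when the proxy does enforce auth, setting LITELLM_API_KEY restores it.
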
-
-#[async_trait]
-impl LlmGenerationClient for Client {
-    async fn generate<'req>(
-        &self,
-        request: super::LlmGenerateRequest<'req>,
-    ) -> Result<super::LlmGenerateResponse> {
-        let mut messages = Vec::new();
-
-        // Add system prompt if provided
-        if let Some(system) = request.system_prompt {
-            messages.push(ChatCompletionRequestMessage::System(
-                ChatCompletionRequestSystemMessage {
-                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
-                    ..Default::default()
-                },
-            ));
-        }
-
-        // Add user message
-        messages.push(ChatCompletionRequestMessage::User(
-            ChatCompletionRequestUserMessage {
-                content: ChatCompletionRequestUserMessageContent::Text(
-                    request.user_prompt.into_owned(),
-                ),
-                ..Default::default()
-            },
-        ));
-
-        // Create the chat completion request
-        let request = CreateChatCompletionRequest {
-            model: self.model.clone(),
-            messages,
-            response_format: match request.output_format {
-                Some(super::OutputFormat::JsonSchema { name, schema }) => {
-                    Some(ResponseFormat::JsonSchema {
-                        json_schema: ResponseFormatJsonSchema {
-                            name: name.into_owned(),
-                            description: None,
-                            schema: Some(serde_json::to_value(&schema)?),
-                            strict: Some(true),
-                        },
-                    })
-                }
-                None => None,
-            },
-            ..Default::default()
-        };
-
-        // Send request and get response
-        let response = self.client.chat().create(request).await?;
-
-        // Extract the response text from the first choice
-        let text = response
-            .choices
-            .into_iter()
-            .next()
-            .and_then(|choice| choice.message.content)
-            .ok_or_else(|| anyhow::anyhow!("No response from LiteLLM proxy"))?;
-
-        Ok(super::LlmGenerateResponse { text })
-    }
-
-    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
-        super::ToJsonSchemaOptions {
-            fields_always_required: true,
-            supports_format: false,
-            extract_descriptions: false,
-            top_level_must_be_object: true,
-        }
-    }
-}