
Commit c121b45

Add LiteLLM
1 parent 1cec966 commit c121b45

3 files changed: +109, -0 lines changed

python/cocoindex/llm.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ class LlmApiType(Enum):
     OLLAMA = "Ollama"
     GEMINI = "Gemini"
     ANTHROPIC = "Anthropic"
+    LITELLM = "LiteLlm"


 @dataclass

src/llm/litellm.rs

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+use super::LlmGenerationClient;
+use anyhow::Result;
+use async_openai::{
+    config::OpenAIConfig,
+    types::{
+        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
+        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
+        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
+        ResponseFormatJsonSchema,
+    },
+    Client as OpenAIClient,
+};
+use async_trait::async_trait;
+
+pub struct Client {
+    client: async_openai::Client<OpenAIConfig>,
+    model: String,
+}
+
+impl Client {
+    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
+        // For LiteLLM, use provided address or default to http://0.0.0.0:4000
+        let address = spec.address.clone().unwrap_or_else(|| "http://0.0.0.0:4000".to_string());
+        let api_key = std::env::var("LITELLM_API_KEY").unwrap_or_else(|_| "anything".to_string());
+        let config = OpenAIConfig::new().with_api_base(address).with_api_key(api_key);
+        Ok(Self {
+            client: OpenAIClient::with_config(config),
+            model: spec.model,
+        })
+    }
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: super::LlmGenerateRequest<'req>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let mut messages = Vec::new();
+
+        // Add system prompt if provided
+        if let Some(system) = request.system_prompt {
+            messages.push(ChatCompletionRequestMessage::System(
+                ChatCompletionRequestSystemMessage {
+                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
+                    ..Default::default()
+                },
+            ));
+        }
+
+        // Add user message
+        messages.push(ChatCompletionRequestMessage::User(
+            ChatCompletionRequestUserMessage {
+                content: ChatCompletionRequestUserMessageContent::Text(
+                    request.user_prompt.into_owned(),
+                ),
+                ..Default::default()
+            },
+        ));
+
+        // Create the chat completion request
+        let request = CreateChatCompletionRequest {
+            model: self.model.clone(),
+            messages,
+            response_format: match request.output_format {
+                Some(super::OutputFormat::JsonSchema { name, schema }) => {
+                    Some(ResponseFormat::JsonSchema {
+                        json_schema: ResponseFormatJsonSchema {
+                            name: name.into_owned(),
+                            description: None,
+                            schema: Some(serde_json::to_value(&schema)?),
+                            strict: Some(true),
+                        },
+                    })
+                }
+                None => None,
+            },
+            ..Default::default()
+        };
+
+        // Send request and get response
+        let response = self.client.chat().create(request).await?;
+
+        // Extract the response text from the first choice
+        let text = response
+            .choices
+            .into_iter()
+            .next()
+            .and_then(|choice| choice.message.content)
+            .ok_or_else(|| anyhow::anyhow!("No response from LiteLLM proxy"))?;
+
+        Ok(super::LlmGenerateResponse { text })
+    }
+
+    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
+        super::ToJsonSchemaOptions {
+            fields_always_required: true,
+            supports_format: false,
+            extract_descriptions: false,
+            top_level_must_be_object: true,
+        }
+    }
+}

src/llm/mod.rs

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,7 @@ pub enum LlmApiType {
     OpenAi,
     Gemini,
     Anthropic,
+    LiteLlm,
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -56,6 +57,7 @@ mod anthropic;
 mod gemini;
 mod ollama;
 mod openai;
+mod litellm;

 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {
@@ -71,6 +73,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGener
         LlmApiType::Anthropic => {
            Box::new(anthropic::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::LiteLlm => {
+            Box::new(litellm::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+        }
     };
     Ok(client)
 }
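
For reference, a minimal usage sketch of the new LiteLlm path through new_llm_generation_client (not part of this commit). It assumes LlmSpec and LlmGenerateRequest expose exactly the fields referenced in this diff (their full definitions live elsewhere in src/llm/mod.rs and may differ), that a LiteLLM proxy is reachable at the default http://0.0.0.0:4000, and that the model name is one the proxy is configured to route.

use std::borrow::Cow;

use crate::llm::{new_llm_generation_client, LlmApiType, LlmGenerateRequest, LlmSpec};

// Sketch only: field sets beyond what this diff shows are assumptions.
async fn generate_via_litellm() -> anyhow::Result<String> {
    let spec = LlmSpec {
        api_type: LlmApiType::LiteLlm,
        model: "gpt-4o-mini".to_string(), // any model the LiteLLM proxy routes
        address: None,                    // None falls back to http://0.0.0.0:4000
    };
    let client = new_llm_generation_client(spec).await?;
    let response = client
        .generate(LlmGenerateRequest {
            system_prompt: Some(Cow::Borrowed("You are a terse assistant.")),
            user_prompt: Cow::Borrowed("Say hello."),
            output_format: None, // or Some(OutputFormat::JsonSchema { .. }) for structured output
        })
        .await?;
    Ok(response.text)
}

If the proxy enforces authentication, the client reads LITELLM_API_KEY from the environment; otherwise it falls back to a placeholder key, as shown in litellm.rs above.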
