1 change: 1 addition & 0 deletions python/cocoindex/llm.py
@@ -9,6 +9,7 @@ class LlmApiType(Enum):
    OLLAMA = "Ollama"
    GEMINI = "Gemini"
    ANTHROPIC = "Anthropic"
    LITELLM = "LiteLlm"


@dataclass
103 changes: 103 additions & 0 deletions src/llm/litellm.rs
@@ -0,0 +1,103 @@
use super::LlmGenerationClient;
use anyhow::Result;
use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
        ResponseFormatJsonSchema,
    },
    Client as OpenAIClient,
};
use async_trait::async_trait;

pub struct Client {
    client: async_openai::Client<OpenAIConfig>,
    model: String,
}

impl Client {
    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
        // For LiteLLM, use provided address or default to http://0.0.0.0:4000
        let address = spec.address.clone().unwrap_or_else(|| "http://0.0.0.0:4000".to_string());
        let api_key = std::env::var("LITELLM_API_KEY").unwrap_or_else(|_| "anything".to_string());
        let config = OpenAIConfig::new().with_api_base(address).with_api_key(api_key);
        Ok(Self {
            client: OpenAIClient::with_config(config),
            model: spec.model,
        })
    }
}

#[async_trait]
impl LlmGenerationClient for Client {
    async fn generate<'req>(
        &self,
        request: super::LlmGenerateRequest<'req>,
    ) -> Result<super::LlmGenerateResponse> {
        let mut messages = Vec::new();

        // Add system prompt if provided
        if let Some(system) = request.system_prompt {
            messages.push(ChatCompletionRequestMessage::System(
                ChatCompletionRequestSystemMessage {
                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
                    ..Default::default()
                },
            ));
        }

        // Add user message
        messages.push(ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text(
                    request.user_prompt.into_owned(),
                ),
                ..Default::default()
            },
        ));

        // Create the chat completion request
        let request = CreateChatCompletionRequest {
            model: self.model.clone(),
            messages,
            response_format: match request.output_format {
                Some(super::OutputFormat::JsonSchema { name, schema }) => {
                    Some(ResponseFormat::JsonSchema {
                        json_schema: ResponseFormatJsonSchema {
                            name: name.into_owned(),
                            description: None,
                            schema: Some(serde_json::to_value(&schema)?),
                            strict: Some(true),
                        },
                    })
                }
                None => None,
            },
            ..Default::default()
        };

        // Send request and get response
        let response = self.client.chat().create(request).await?;

        // Extract the response text from the first choice
        let text = response
            .choices
            .into_iter()
            .next()
            .and_then(|choice| choice.message.content)
            .ok_or_else(|| anyhow::anyhow!("No response from LiteLLM proxy"))?;

        Ok(super::LlmGenerateResponse { text })
    }

    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
        super::ToJsonSchemaOptions {
            fields_always_required: true,
            supports_format: false,
            extract_descriptions: false,
            top_level_must_be_object: true,
        }
    }
}
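
A rough usage sketch of the new client for reviewers, not part of this PR. It assumes LlmGenerateRequest carries exactly the system_prompt, user_prompt, and output_format fields the client reads above, with Cow-backed strings as the into_owned() calls suggest; the helper name is hypothetical.

use std::borrow::Cow;

// Hypothetical helper (assumed request field names): exercise
// litellm::Client::generate with a plain-text request and no
// structured output format.
async fn smoke_test(client: &Client) -> anyhow::Result<String> {
    let response = client
        .generate(super::LlmGenerateRequest {
            system_prompt: Some(Cow::Borrowed("You are a terse assistant.")),
            user_prompt: Cow::Borrowed("Reply with a single word."),
            output_format: None,
        })
        .await?;
    Ok(response.text)
}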
5 changes: 5 additions & 0 deletions src/llm/mod.rs
@@ -13,6 +13,7 @@ pub enum LlmApiType {
    OpenAi,
    Gemini,
    Anthropic,
    LiteLlm,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -56,6 +57,7 @@ mod anthropic;
mod gemini;
mod ollama;
mod openai;
mod litellm;

pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
    let client = match spec.api_type {
@@ -71,6 +73,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGener
        LlmApiType::Anthropic => {
            Box::new(anthropic::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
        }
        LlmApiType::LiteLlm => {
            Box::new(litellm::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
        }
    };
    Ok(client)
}
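
A wiring sketch showing how a caller would select the new backend through the dispatcher. It assumes LlmSpec exposes the api_type, model, and address fields that litellm::Client::new reads; the real struct may carry more fields, so treat this as illustrative only.

// Hypothetical example, not part of this PR.
async fn litellm_example() -> Result<()> {
    let spec = LlmSpec {
        api_type: LlmApiType::LiteLlm,
        // Any model name the LiteLLM proxy is configured to route.
        model: "gpt-4o".to_string(),
        // Optional: None falls back to http://0.0.0.0:4000 inside litellm::Client::new.
        address: Some("http://localhost:4000".to_string()),
    };
    // Returns a Box<dyn LlmGenerationClient> backed by the new litellm::Client.
    let _client = new_llm_generation_client(spec).await?;
    Ok(())
}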