diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 914cb71a1..b31666e18 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -145,6 +145,9 @@ pub async fn new_llm_embedding_client(
     api_config: Option<LlmApiConfig>,
 ) -> Result<Box<dyn LlmEmbeddingClient>> {
     let client = match api_type {
+        LlmApiType::Ollama => {
+            Box::new(ollama::Client::new(address).await?) as Box<dyn LlmEmbeddingClient>
+        }
         LlmApiType::Gemini => {
             Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient>
         }
@@ -156,11 +159,7 @@ pub async fn new_llm_embedding_client(
         }
         LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
             as Box<dyn LlmEmbeddingClient>,
-        LlmApiType::Ollama
-        | LlmApiType::OpenRouter
-        | LlmApiType::LiteLlm
-        | LlmApiType::Vllm
-        | LlmApiType::Anthropic => {
+        LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => {
             api_bail!("Embedding is not supported for API type {:?}", api_type)
         }
     };
diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs
index 6b22febb4..bc09c9386 100644
--- a/src/llm/ollama.rs
+++ b/src/llm/ollama.rs
@@ -1,11 +1,35 @@
 use crate::prelude::*;
 
-use super::LlmGenerationClient;
+use super::{LlmEmbeddingClient, LlmGenerationClient};
 use schemars::schema::SchemaObject;
 use serde_with::{base64::Base64, serde_as};
 
+fn get_embedding_dimension(model: &str) -> Option<u32> {
+    match model.to_ascii_lowercase().as_str() {
+        "mxbai-embed-large"
+        | "bge-m3"
+        | "bge-large"
+        | "snowflake-arctic-embed"
+        | "snowflake-arctic-embed2" => Some(1024),
+
+        "nomic-embed-text"
+        | "paraphrase-multilingual"
+        | "snowflake-arctic-embed:110m"
+        | "snowflake-arctic-embed:137m"
+        | "granite-embedding:278m" => Some(768),
+
+        "all-minilm"
+        | "snowflake-arctic-embed:22m"
+        | "snowflake-arctic-embed:33m"
+        | "granite-embedding" => Some(384),
+
+        _ => None,
+    }
+}
+
 pub struct Client {
     generate_url: String,
+    embed_url: String,
     reqwest_client: reqwest::Client,
 }
 
@@ -32,6 +56,17 @@ struct OllamaResponse {
     pub response: String,
 }
 
+#[derive(Debug, Serialize)]
+struct OllamaEmbeddingRequest<'a> {
+    pub model: &'a str,
+    pub input: &'a str,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaEmbeddingResponse {
+    pub embeddings: Vec<Vec<f32>>,
+}
+
 const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
 
 impl Client {
@@ -42,6 +77,7 @@ impl Client {
         };
         Ok(Self {
             generate_url: format!("{address}/api/generate"),
+            embed_url: format!("{address}/api/embed"),
             reqwest_client: reqwest::Client::new(),
         })
     }
@@ -97,3 +133,43 @@ impl LlmGenerationClient for Client {
         }
     }
 }
+
+#[async_trait]
+impl LlmEmbeddingClient for Client {
+    async fn embed_text<'req>(
+        &self,
+        request: super::LlmEmbeddingRequest<'req>,
+    ) -> Result<super::LlmEmbeddingResponse> {
+        let req = OllamaEmbeddingRequest {
+            model: request.model,
+            input: request.text.as_ref(),
+        };
+        let resp = retryable::run(
+            || {
+                self.reqwest_client
+                    .post(self.embed_url.as_str())
+                    .json(&req)
+                    .send()
+            },
+            &retryable::HEAVY_LOADED_OPTIONS,
+        )
+        .await?;
+        if !resp.status().is_success() {
+            bail!(
+                "Ollama API error: {:?}\n{}\n",
+                resp.status(),
+                resp.text().await?
+            );
+        }
+        let mut embedding_resp: OllamaEmbeddingResponse =
+            resp.json().await.context("Invalid JSON")?;
+        // `/api/embed` returns one embedding per input string; we send a single
+        // input, so take the first (and only) row.
+        let embedding = embedding_resp
+            .embeddings
+            .pop()
+            .ok_or_else(|| anyhow::anyhow!("Ollama API returned no embedding"))?;
+        Ok(super::LlmEmbeddingResponse { embedding })
+    }
+
+    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
+        get_embedding_dimension(model)
+    }
+}
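
Note on the wire format: Ollama's `/api/embed` endpoint takes `{"model": ..., "input": ...}` and responds with a batch, `{"embeddings": [[...]]}` (one embedding per input string), which is why the response struct deserializes `embeddings: Vec<Vec<f32>>` and the client pops the single row.

Below is a minimal usage sketch of the new path, not part of the diff. It assumes a local Ollama instance with `nomic-embed-text` already pulled, that the relevant items are re-exported from `crate::llm`, and that `LlmEmbeddingRequest` carries only the `model` and `text` fields the diff actually reads (with `text` as a `Cow<str>`, as `request.text.as_ref()` suggests); the real struct may have additional fields to fill in.

```rust
use std::borrow::Cow;

// Hypothetical import paths; adjust to wherever these are re-exported.
use crate::llm::{new_llm_embedding_client, LlmApiType, LlmEmbeddingRequest};

async fn embed_one() -> anyhow::Result<()> {
    // A `None` address falls back to OLLAMA_DEFAULT_ADDRESS
    // (http://localhost:11434); Ollama needs no api_config.
    let client = new_llm_embedding_client(LlmApiType::Ollama, None, None).await?;
    let resp = client
        .embed_text(LlmEmbeddingRequest {
            model: "nomic-embed-text",
            text: Cow::Borrowed("CocoIndex indexes documents for retrieval."),
        })
        .await?;
    // Matches get_embedding_dimension("nomic-embed-text") == Some(768).
    assert_eq!(resp.embedding.len(), 768);
    Ok(())
}
```

The assertion mirrors the lookup table: for the models listed in `get_embedding_dimension`, callers can size vector-store columns up front without probing the server.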