From 488d254585ae17d78b9cd13c89c598f04cf5ffd0 Mon Sep 17 00:00:00 2001 From: lemorage Date: Thu, 17 Jul 2025 11:01:54 +0200 Subject: [PATCH 1/3] feat(embed-text): add `EmbedText` for Ollama --- src/llm/ollama.rs | 71 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs index 6b22febb4..984053984 100644 --- a/src/llm/ollama.rs +++ b/src/llm/ollama.rs @@ -1,11 +1,13 @@ use crate::prelude::*; -use super::LlmGenerationClient; +use super::{LlmEmbeddingClient, LlmGenerationClient}; +use phf::phf_map; use schemars::schema::SchemaObject; use serde_with::{base64::Base64, serde_as}; pub struct Client { generate_url: String, + embed_url: String, reqwest_client: reqwest::Client, } @@ -32,6 +34,34 @@ struct OllamaResponse { pub response: String, } +#[derive(Debug, Serialize)] +struct OllamaEmbeddingRequest<'a> { + pub model: &'a str, + pub input: &'a str, +} + +#[derive(Debug, Deserialize)] +struct OllamaEmbeddingResponse { + pub embedding: Vec<f32>, +} + +static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map!
{ + "nomic-embed-text" => 768, + "mxbai-embed-large" => 1024, + "bge-m3" => 1024, + "bge-large" => 1024, + "all-minilm" => 384, + "snowflake-arctic-embed:22m" => 384, + "snowflake-arctic-embed:33m" => 384, + "snowflake-arctic-embed:110m" => 768, + "snowflake-arctic-embed:137m" => 768, + "snowflake-arctic-embed" => 1024, + "snowflake-arctic-embed2" => 1024, + "paraphrase-multilingual" => 768, + "granite-embedding" => 384, + "granite-embedding:278m" => 768, +}; + const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434"; impl Client { @@ -42,6 +72,7 @@ impl Client { }; Ok(Self { generate_url: format!("{address}/api/generate"), + embed_url: format!("{address}/api/embed"), reqwest_client: reqwest::Client::new(), }) } @@ -97,3 +128,41 @@ impl LlmGenerationClient for Client { } } } + +#[async_trait] +impl LlmEmbeddingClient for Client { + async fn embed_text<'req>( + &self, + request: super::LlmEmbeddingRequest<'req>, + ) -> Result<super::LlmEmbeddingResponse> { + let req = OllamaEmbeddingRequest { + model: request.model, + input: request.text.as_ref(), + }; + let resp = retryable::run( + || { + self.reqwest_client + .post(self.embed_url.as_str()) + .json(&req) + .send() + }, + &retryable::HEAVY_LOADED_OPTIONS, + ) + .await?; + if !resp.status().is_success() { + bail!( + "Ollama API error: {:?}\n{}\n", + resp.status(), + resp.text().await?
+ ); + } + let embedding_resp: OllamaEmbeddingResponse = resp.json().await.context("Invalid JSON")?; + Ok(super::LlmEmbeddingResponse { + embedding: embedding_resp.embedding, + }) + } + + fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> { + DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied() + } +} From 129bc894d97271d5e3b80294ab4b17a75f0437fc Mon Sep 17 00:00:00 2001 From: lemorage Date: Thu, 17 Jul 2025 11:22:46 +0200 Subject: [PATCH 2/3] refactor(ollama): replace static map with dynamic dimension retrieval --- src/llm/ollama.rs | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs index 984053984..bc09c9386 100644 --- a/src/llm/ollama.rs +++ b/src/llm/ollama.rs @@ -1,10 +1,32 @@ use crate::prelude::*; use super::{LlmEmbeddingClient, LlmGenerationClient}; -use phf::phf_map; use schemars::schema::SchemaObject; use serde_with::{base64::Base64, serde_as}; +fn get_embedding_dimension(model: &str) -> Option<u32> { + match model.to_ascii_lowercase().as_str() { + "mxbai-embed-large" + | "bge-m3" + | "bge-large" + | "snowflake-arctic-embed" + | "snowflake-arctic-embed2" => Some(1024), + + "nomic-embed-text" + | "paraphrase-multilingual" + | "snowflake-arctic-embed:110m" + | "snowflake-arctic-embed:137m" + | "granite-embedding:278m" => Some(768), + + "all-minilm" + | "snowflake-arctic-embed:22m" + | "snowflake-arctic-embed:33m" + | "granite-embedding" => Some(384), + + _ => None, + } +} + pub struct Client { generate_url: String, embed_url: String, @@ -45,23 +67,6 @@ struct OllamaEmbeddingResponse { pub embedding: Vec<f32>, } -static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map!
{ - "nomic-embed-text" => 768, - "mxbai-embed-large" => 1024, - "bge-m3" => 1024, - "bge-large" => 1024, - "all-minilm" => 384, - "snowflake-arctic-embed:22m" => 384, - "snowflake-arctic-embed:33m" => 384, - "snowflake-arctic-embed:110m" => 768, - "snowflake-arctic-embed:137m" => 768, - "snowflake-arctic-embed" => 1024, - "snowflake-arctic-embed2" => 1024, - "paraphrase-multilingual" => 768, - "granite-embedding" => 384, - "granite-embedding:278m" => 768, -}; - const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434"; impl Client { @@ -163,6 +168,6 @@ impl LlmEmbeddingClient for Client { } fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> { - DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied() + get_embedding_dimension(model) } } From 227abee71f836f04780879e5c02272a10b7ebb17 Mon Sep 17 00:00:00 2001 From: lemorage Date: Tue, 17 Jun 2025 14:33:07 +0200 Subject: [PATCH 3/3] feat(llm): add embedding support for Ollama --- src/llm/mod.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 914cb71a1..b31666e18 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -145,6 +145,9 @@ pub async fn new_llm_embedding_client( api_config: Option<LlmApiConfig>, ) -> Result<Box<dyn LlmEmbeddingClient>> { let client = match api_type { + LlmApiType::Ollama => { + Box::new(ollama::Client::new(address).await?) as Box<dyn LlmEmbeddingClient> + } LlmApiType::Gemini => { Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient> } @@ -156,11 +159,7 @@ pub async fn new_llm_embedding_client( } LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) as Box<dyn LlmEmbeddingClient>, - LlmApiType::Ollama - | LlmApiType::OpenRouter - | LlmApiType::LiteLlm - | LlmApiType::Vllm - | LlmApiType::Anthropic => { + LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => { api_bail!("Embedding is not supported for API type {:?}", api_type) } };