9 changes: 4 additions & 5 deletions src/llm/mod.rs
@@ -145,6 +145,9 @@ pub async fn new_llm_embedding_client(
     api_config: Option<LlmApiConfig>,
 ) -> Result<Box<dyn LlmEmbeddingClient>> {
     let client = match api_type {
+        LlmApiType::Ollama => {
+            Box::new(ollama::Client::new(address).await?) as Box<dyn LlmEmbeddingClient>
+        }
         LlmApiType::Gemini => {
             Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient>
         }
@@ -156,11 +159,7 @@ pub async fn new_llm_embedding_client(
         }
         LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
             as Box<dyn LlmEmbeddingClient>,
-        LlmApiType::Ollama
-        | LlmApiType::OpenRouter
-        | LlmApiType::LiteLlm
-        | LlmApiType::Vllm
-        | LlmApiType::Anthropic => {
+        LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => {
             api_bail!("Embedding is not supported for API type {:?}", api_type)
         }
     };
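With the new match arm above, `new_llm_embedding_client` returns an Ollama-backed client instead of falling through to `api_bail!`. A minimal caller sketch, assuming the signature visible in the hunk (an `LlmApiType`, an optional address, and an optional `LlmApiConfig`):

// Hypothetical caller; the address/config parameter types are assumptions
// inferred from the context lines above.
async fn make_ollama_embedder() -> Result<Box<dyn LlmEmbeddingClient>> {
    // Passing None for the address lets the Ollama client fall back to
    // OLLAMA_DEFAULT_ADDRESS (http://localhost:11434, see ollama.rs below).
    new_llm_embedding_client(LlmApiType::Ollama, None, None).await
}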
76 changes: 75 additions & 1 deletion src/llm/ollama.rs
@@ -1,11 +1,35 @@
 use crate::prelude::*;
 
-use super::LlmGenerationClient;
+use super::{LlmEmbeddingClient, LlmGenerationClient};
 use schemars::schema::SchemaObject;
 use serde_with::{base64::Base64, serde_as};
 
+fn get_embedding_dimension(model: &str) -> Option<u32> {
+    match model.to_ascii_lowercase().as_str() {
+        "mxbai-embed-large"
+        | "bge-m3"
+        | "bge-large"
+        | "snowflake-arctic-embed"
+        | "snowflake-arctic-embed2" => Some(1024),
+
+        "nomic-embed-text"
+        | "paraphrase-multilingual"
+        | "snowflake-arctic-embed:110m"
+        | "snowflake-arctic-embed:137m"
+        | "granite-embedding:278m" => Some(768),
+
+        "all-minilm"
+        | "snowflake-arctic-embed:22m"
+        | "snowflake-arctic-embed:33m"
+        | "granite-embedding" => Some(384),
+
+        _ => None,
+    }
+}
+
 pub struct Client {
     generate_url: String,
+    embed_url: String,
     reqwest_client: reqwest::Client,
 }

@@ -32,6 +56,17 @@ struct OllamaResponse {
     pub response: String,
 }
 
+#[derive(Debug, Serialize)]
+struct OllamaEmbeddingRequest<'a> {
+    pub model: &'a str,
+    pub input: &'a str,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaEmbeddingResponse {
+    pub embedding: Vec<f32>,
+}
+
 const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
 
 impl Client {
@@ -42,6 +77,7 @@ impl Client {
         };
         Ok(Self {
             generate_url: format!("{address}/api/generate"),
+            embed_url: format!("{address}/api/embed"),
             reqwest_client: reqwest::Client::new(),
         })
     }
@@ -97,3 +133,41 @@ impl LlmGenerationClient for Client {
         }
     }
 }
+
+#[async_trait]
+impl LlmEmbeddingClient for Client {
+    async fn embed_text<'req>(
+        &self,
+        request: super::LlmEmbeddingRequest<'req>,
+    ) -> Result<super::LlmEmbeddingResponse> {
+        let req = OllamaEmbeddingRequest {
+            model: request.model,
+            input: request.text.as_ref(),
+        };
+        let resp = retryable::run(
+            || {
+                self.reqwest_client
+                    .post(self.embed_url.as_str())
+                    .json(&req)
+                    .send()
+            },
+            &retryable::HEAVY_LOADED_OPTIONS,
+        )
+        .await?;
+        if !resp.status().is_success() {
+            bail!(
+                "Ollama API error: {:?}\n{}\n",
+                resp.status(),
+                resp.text().await?
+            );
+        }
+        let embedding_resp: OllamaEmbeddingResponse = resp.json().await.context("Invalid JSON")?;
+        Ok(super::LlmEmbeddingResponse {
+            embedding: embedding_resp.embedding,
+        })
+    }
+
+    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
+        get_embedding_dimension(model)
+    }
+}
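The new trait impl reuses the client's shared `reqwest` instance and routes the request through `retryable::run` with `HEAVY_LOADED_OPTIONS`, so transient errors from a busy local Ollama server are retried before surfacing. A short usage sketch, assuming `LlmEmbeddingRequest` carries just the `model` and `text` fields that `embed_text` reads above:

// Hypothetical caller inside the same module; any LlmEmbeddingRequest
// fields beyond `model` and `text` are an unknown and elided here.
async fn embed_one(client: &Client) -> Result<Vec<f32>> {
    let response = client
        .embed_text(super::LlmEmbeddingRequest {
            model: "nomic-embed-text",
            text: "some text to embed".into(),
        })
        .await?;
    // The lookup table at the top of this file maps nomic-embed-text to 768.
    assert_eq!(
        client.get_default_embedding_dimension("nomic-embed-text"),
        Some(768)
    );
    Ok(response.embedding)
}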