|
1 | 1 | use crate::prelude::*; |
2 | 2 |
|
3 | 3 | use super::{LlmEmbeddingClient, LlmGenerationClient}; |
4 | | -use phf::phf_map; |
5 | 4 | use schemars::schema::SchemaObject; |
6 | 5 | use serde_with::{base64::Base64, serde_as}; |
7 | 6 |
|
/// Returns the default embedding dimension for a known Ollama embedding model.
///
/// The lookup is ASCII case-insensitive. An explicit `:latest` tag is treated
/// the same as the untagged model name, so `"bge-m3"` and `"bge-m3:latest"`
/// resolve to the same entry. Any other tag must appear in the table verbatim.
///
/// Returns `None` when the model is not in the table.
fn get_embedding_dimension(model: &str) -> Option<u32> {
    let normalized = model.to_ascii_lowercase();
    // An untagged Ollama model name refers to its `latest` tag, so drop an
    // explicit `:latest` suffix before consulting the table.
    let name = normalized
        .strip_suffix(":latest")
        .unwrap_or(normalized.as_str());

    match name {
        "mxbai-embed-large"
        | "bge-m3"
        | "bge-large"
        | "snowflake-arctic-embed"
        | "snowflake-arctic-embed2" => Some(1024),

        "nomic-embed-text"
        | "paraphrase-multilingual"
        | "snowflake-arctic-embed:110m"
        | "snowflake-arctic-embed:137m"
        | "granite-embedding:278m" => Some(768),

        "all-minilm"
        | "snowflake-arctic-embed:22m"
        | "snowflake-arctic-embed:33m"
        | "granite-embedding" => Some(384),

        _ => None,
    }
}
| 29 | + |
8 | 30 | pub struct Client { |
9 | 31 | generate_url: String, |
10 | 32 | embed_url: String, |
@@ -45,23 +67,6 @@ struct OllamaEmbeddingResponse { |
45 | 67 | pub embedding: Vec<f32>, |
46 | 68 | } |
47 | 69 |
|
48 | | -static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! { |
49 | | - "nomic-embed-text" => 768, |
50 | | - "mxbai-embed-large" => 1024, |
51 | | - "bge-m3" => 1024, |
52 | | - "bge-large" => 1024, |
53 | | - "all-minilm" => 384, |
54 | | - "snowflake-arctic-embed:22m" => 384, |
55 | | - "snowflake-arctic-embed:33m" => 384, |
56 | | - "snowflake-arctic-embed:110m" => 768, |
57 | | - "snowflake-arctic-embed:137m" => 768, |
58 | | - "snowflake-arctic-embed" => 1024, |
59 | | - "snowflake-arctic-embed2" => 1024, |
60 | | - "paraphrase-multilingual" => 768, |
61 | | - "granite-embedding" => 384, |
62 | | - "granite-embedding:278m" => 768, |
63 | | -}; |
64 | | - |
65 | 70 | const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434"; |
66 | 71 |
|
67 | 72 | impl Client { |
@@ -163,6 +168,6 @@ impl LlmEmbeddingClient for Client { |
163 | 168 | } |
164 | 169 |
|
165 | 170 | fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> { |
166 | | - DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied() |
| 171 | + get_embedding_dimension(model) |
167 | 172 | } |
168 | 173 | } |
0 commit comments