Skip to content

Commit 129bc89

Browse files
committed
refactor(ollama): replace static map with dynamic dimension retrieval
1 parent 488d254 commit 129bc89

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

src/llm/ollama.rs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,32 @@
11
use crate::prelude::*;
22

33
use super::{LlmEmbeddingClient, LlmGenerationClient};
4-
use phf::phf_map;
54
use schemars::schema::SchemaObject;
65
use serde_with::{base64::Base64, serde_as};
76

7+
fn get_embedding_dimension(model: &str) -> Option<u32> {
8+
match model.to_ascii_lowercase().as_str() {
9+
"mxbai-embed-large"
10+
| "bge-m3"
11+
| "bge-large"
12+
| "snowflake-arctic-embed"
13+
| "snowflake-arctic-embed2" => Some(1024),
14+
15+
"nomic-embed-text"
16+
| "paraphrase-multilingual"
17+
| "snowflake-arctic-embed:110m"
18+
| "snowflake-arctic-embed:137m"
19+
| "granite-embedding:278m" => Some(768),
20+
21+
"all-minilm"
22+
| "snowflake-arctic-embed:22m"
23+
| "snowflake-arctic-embed:33m"
24+
| "granite-embedding" => Some(384),
25+
26+
_ => None,
27+
}
28+
}
29+
830
pub struct Client {
931
generate_url: String,
1032
embed_url: String,
@@ -45,23 +67,6 @@ struct OllamaEmbeddingResponse {
4567
pub embedding: Vec<f32>,
4668
}
4769

48-
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
49-
"nomic-embed-text" => 768,
50-
"mxbai-embed-large" => 1024,
51-
"bge-m3" => 1024,
52-
"bge-large" => 1024,
53-
"all-minilm" => 384,
54-
"snowflake-arctic-embed:22m" => 384,
55-
"snowflake-arctic-embed:33m" => 384,
56-
"snowflake-arctic-embed:110m" => 768,
57-
"snowflake-arctic-embed:137m" => 768,
58-
"snowflake-arctic-embed" => 1024,
59-
"snowflake-arctic-embed2" => 1024,
60-
"paraphrase-multilingual" => 768,
61-
"granite-embedding" => 384,
62-
"granite-embedding:278m" => 768,
63-
};
64-
6570
const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
6671

6772
impl Client {
@@ -163,6 +168,6 @@ impl LlmEmbeddingClient for Client {
163168
}
164169

165170
fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
166-
DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
171+
get_embedding_dimension(model)
167172
}
168173
}

0 commit comments

Comments
 (0)