11use crate :: prelude:: * ;
22
3- use super :: LlmGenerationClient ;
3+ use super :: { LlmEmbeddingClient , LlmGenerationClient } ;
44use schemars:: schema:: SchemaObject ;
55use serde_with:: { base64:: Base64 , serde_as} ;
66
/// Returns the default embedding dimension for well-known Ollama embedding models.
///
/// Matching is ASCII case-insensitive. An explicit `:latest` tag is treated the same as
/// the bare model name, because Ollama resolves an untagged model name to its `latest`
/// tag. Any other tag must match one of the listed variants exactly.
///
/// Returns `None` when the model is not in the table.
fn get_embedding_dimension(model: &str) -> Option<u32> {
    let normalized = model.to_ascii_lowercase();
    // `name:latest` and `name` refer to the same model, so look both up under `name`.
    let name = normalized
        .strip_suffix(":latest")
        .unwrap_or(normalized.as_str());

    match name {
        "mxbai-embed-large"
        | "bge-m3"
        | "bge-large"
        | "snowflake-arctic-embed"
        | "snowflake-arctic-embed2" => Some(1024),

        "nomic-embed-text"
        | "paraphrase-multilingual"
        | "snowflake-arctic-embed:110m"
        | "snowflake-arctic-embed:137m"
        | "granite-embedding:278m" => Some(768),

        "all-minilm"
        | "snowflake-arctic-embed:22m"
        | "snowflake-arctic-embed:33m"
        | "granite-embedding" => Some(384),

        _ => None,
    }
}
29+
730pub struct Client {
831 generate_url : String ,
32+ embed_url : String ,
933 reqwest_client : reqwest:: Client ,
1034}
1135
@@ -32,6 +56,17 @@ struct OllamaResponse {
3256 pub response : String ,
3357}
3458
59+ #[ derive( Debug , Serialize ) ]
60+ struct OllamaEmbeddingRequest < ' a > {
61+ pub model : & ' a str ,
62+ pub input : & ' a str ,
63+ }
64+
65+ #[ derive( Debug , Deserialize ) ]
66+ struct OllamaEmbeddingResponse {
67+ pub embedding : Vec < f32 > ,
68+ }
69+
/// Base URL of an Ollama server on the local machine, port 11434.
// NOTE(review): presumably the fallback when no address is configured — confirm in `Client::new`.
const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
3671
3772impl Client {
@@ -42,6 +77,7 @@ impl Client {
4277 } ;
4378 Ok ( Self {
4479 generate_url : format ! ( "{address}/api/generate" ) ,
80+ embed_url : format ! ( "{address}/api/embed" ) ,
4581 reqwest_client : reqwest:: Client :: new ( ) ,
4682 } )
4783 }
@@ -97,3 +133,41 @@ impl LlmGenerationClient for Client {
97133 }
98134 }
99135}
136+
137+ #[ async_trait]
138+ impl LlmEmbeddingClient for Client {
139+ async fn embed_text < ' req > (
140+ & self ,
141+ request : super :: LlmEmbeddingRequest < ' req > ,
142+ ) -> Result < super :: LlmEmbeddingResponse > {
143+ let req = OllamaEmbeddingRequest {
144+ model : request. model ,
145+ input : request. text . as_ref ( ) ,
146+ } ;
147+ let resp = retryable:: run (
148+ || {
149+ self . reqwest_client
150+ . post ( self . embed_url . as_str ( ) )
151+ . json ( & req)
152+ . send ( )
153+ } ,
154+ & retryable:: HEAVY_LOADED_OPTIONS ,
155+ )
156+ . await ?;
157+ if !resp. status ( ) . is_success ( ) {
158+ bail ! (
159+ "Ollama API error: {:?}\n {}\n " ,
160+ resp. status( ) ,
161+ resp. text( ) . await ?
162+ ) ;
163+ }
164+ let embedding_resp: OllamaEmbeddingResponse = resp. json ( ) . await . context ( "Invalid JSON" ) ?;
165+ Ok ( super :: LlmEmbeddingResponse {
166+ embedding : embedding_resp. embedding ,
167+ } )
168+ }
169+
170+ fn get_default_embedding_dimension ( & self , model : & str ) -> Option < u32 > {
171+ get_embedding_dimension ( model)
172+ }
173+ }
0 commit comments