@@ -6,15 +6,23 @@ use crate::llm::{
66} ;
77use base64:: prelude:: * ;
88use google_cloud_aiplatform_v1 as vertexai;
9- use phf:: phf_map;
109use serde_json:: Value ;
1110use urlencoding:: encode;
1211
13- static DEFAULT_EMBEDDING_DIMENSIONS : phf:: Map < & str , u32 > = phf_map ! {
14- "gemini-embedding-exp-03-07" => 3072 ,
15- "text-embedding-004" => 768 ,
16- "embedding-001" => 768 ,
17- } ;
/// Default output dimension for a Google embedding model, inferred from the
/// model-family prefix of its (case-insensitive) name.
///
/// Returns `None` for names that do not match any known embedding family.
fn get_embedding_dimension(model: &str) -> Option<u32> {
    let model = model.to_ascii_lowercase();
    // Gemini embedding models are the only 3072-dimensional family.
    if model.starts_with("gemini-embedding-") {
        return Some(3072);
    }
    // Every other recognized family currently defaults to 768 dimensions.
    const FAMILIES_768: [&str; 3] = [
        "text-embedding-",
        "embedding-",
        "text-multilingual-embedding-",
    ];
    FAMILIES_768
        .iter()
        .any(|prefix| model.starts_with(prefix))
        .then_some(768)
}
1826
1927pub struct AiStudioClient {
2028 api_key : String ,
@@ -192,7 +200,7 @@ impl LlmEmbeddingClient for AiStudioClient {
192200 }
193201
194202 fn get_default_embedding_dimension ( & self , model : & str ) -> Option < u32 > {
195- DEFAULT_EMBEDDING_DIMENSIONS . get ( model) . copied ( )
203+ get_embedding_dimension ( model)
196204 }
197205}
198206
@@ -202,12 +210,30 @@ pub struct VertexAiClient {
202210}
203211
204212impl VertexAiClient {
205- pub async fn new ( config : super :: VertexAiConfig ) -> Result < Self > {
213+ pub async fn new (
214+ address : Option < String > ,
215+ api_config : Option < super :: LlmApiConfig > ,
216+ ) -> Result < Self > {
217+ if address. is_some ( ) {
218+ api_bail ! ( "VertexAi API address is not supported for VertexAi API type" ) ;
219+ }
220+ let Some ( super :: LlmApiConfig :: VertexAi ( config) ) = api_config else {
221+ api_bail ! ( "VertexAi API config is required for VertexAi API type" ) ;
222+ } ;
206223 let client = vertexai:: client:: PredictionService :: builder ( )
207224 . build ( )
208225 . await ?;
209226 Ok ( Self { client, config } )
210227 }
228+
229+ fn get_model_path ( & self , model : & str ) -> String {
230+ format ! (
231+ "projects/{}/locations/{}/publishers/google/models/{}" ,
232+ self . config. project,
233+ self . config. region. as_deref( ) . unwrap_or( "global" ) ,
234+ model
235+ )
236+ }
211237}
212238
213239#[ async_trait]
@@ -254,20 +280,10 @@ impl LlmGenerationClient for VertexAiClient {
254280 ) ;
255281 }
256282
257- // projects/{project_id}/locations/global/publishers/google/models/{MODEL}
258-
259- let model = format ! (
260- "projects/{}/locations/{}/publishers/google/models/{}" ,
261- self . config. project,
262- self . config. region. as_deref( ) . unwrap_or( "global" ) ,
263- request. model
264- ) ;
265-
266- // Build the request
267283 let mut req = self
268284 . client
269285 . generate_content ( )
270- . set_model ( model)
286+ . set_model ( self . get_model_path ( request . model ) )
271287 . set_contents ( contents) ;
272288 if let Some ( sys) = system_instruction {
273289 req = req. set_system_instruction ( sys) ;
@@ -301,3 +317,54 @@ impl LlmGenerationClient for VertexAiClient {
301317 }
302318 }
303319}
320+
321+ #[ async_trait]
322+ impl LlmEmbeddingClient for VertexAiClient {
323+ async fn embed_text < ' req > (
324+ & self ,
325+ request : super :: LlmEmbeddingRequest < ' req > ,
326+ ) -> Result < super :: LlmEmbeddingResponse > {
327+ // Create the instances for the request
328+ let mut instance = serde_json:: json!( {
329+ "content" : request. text
330+ } ) ;
331+ // Add task type if specified
332+ if let Some ( task_type) = & request. task_type {
333+ instance[ "task_type" ] = serde_json:: Value :: String ( task_type. to_string ( ) ) ;
334+ }
335+
336+ let instances = vec ! [ instance] ;
337+
338+ // Prepare the request parameters
339+ let mut parameters = serde_json:: json!( { } ) ;
340+ if let Some ( output_dimension) = request. output_dimension {
341+ parameters[ "outputDimensionality" ] = serde_json:: Value :: Number ( output_dimension. into ( ) ) ;
342+ }
343+
344+ // Build the prediction request using the raw predict builder
345+ let response = self
346+ . client
347+ . predict ( )
348+ . set_endpoint ( self . get_model_path ( request. model ) )
349+ . set_instances ( instances)
350+ . set_parameters ( parameters)
351+ . send ( )
352+ . await ?;
353+
354+ // Extract the embedding from the response
355+ let embeddings = response
356+ . predictions
357+ . into_iter ( )
358+ . next ( )
359+ . and_then ( |mut e| e. get_mut ( "embeddings" ) . map ( |v| v. take ( ) ) )
360+ . ok_or_else ( || anyhow:: anyhow!( "No embeddings in response" ) ) ?;
361+ let embedding: ContentEmbedding = serde_json:: from_value ( embeddings) ?;
362+ Ok ( super :: LlmEmbeddingResponse {
363+ embedding : embedding. values ,
364+ } )
365+ }
366+
367+ fn get_default_embedding_dimension ( & self , model : & str ) -> Option < u32 > {
368+ get_embedding_dimension ( model)
369+ }
370+ }
0 commit comments