@@ -70,36 +70,44 @@ fn remove_additional_properties(value: &mut Value) {
7070impl AiStudioClient {
7171 fn get_api_url ( & self , model : & str , api_name : & str ) -> String {
7272 format ! (
73- "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={} " ,
73+ "https://generativelanguage.googleapis.com/v1beta/models/{}:{}" ,
7474 encode( model) ,
75- api_name,
76- encode( & self . api_key)
75+ api_name
7776 )
7877 }
7978}
8079
8180fn build_embed_payload (
8281 model : & str ,
83- text : & str ,
82+ texts : & [ & str ] ,
8483 task_type : Option < & str > ,
8584 output_dimension : Option < u32 > ,
8685) -> serde_json:: Value {
87- let mut payload = serde_json:: json!( {
88- "model" : model,
89- "content" : { "parts" : [ { "text" : text } ] } ,
90- } ) ;
91- if let Some ( task_type) = task_type {
92- payload[ "taskType" ] = serde_json:: Value :: String ( task_type. to_string ( ) ) ;
93- }
94- if let Some ( output_dimension) = output_dimension {
95- payload[ "outputDimensionality" ] = serde_json:: json!( output_dimension) ;
96- if model. starts_with ( "gemini-embedding-" ) {
97- payload[ "config" ] = serde_json:: json!( {
98- "outputDimensionality" : output_dimension,
86+ let requests: Vec < _ > = texts
87+ . iter ( )
88+ . map ( |text| {
89+ let mut req = serde_json:: json!( {
90+ "model" : format!( "models/{}" , model) ,
91+ "content" : { "parts" : [ { "text" : text } ] } ,
9992 } ) ;
100- }
101- }
102- payload
93+ if let Some ( task_type) = task_type {
94+ req[ "taskType" ] = serde_json:: Value :: String ( task_type. to_string ( ) ) ;
95+ }
96+ if let Some ( output_dimension) = output_dimension {
97+ req[ "outputDimensionality" ] = serde_json:: json!( output_dimension) ;
98+ if model. starts_with ( "gemini-embedding-" ) {
99+ req[ "config" ] = serde_json:: json!( {
100+ "outputDimensionality" : output_dimension,
101+ } ) ;
102+ }
103+ }
104+ req
105+ } )
106+ . collect ( ) ;
107+
108+ serde_json:: json!( {
109+ "requests" : requests,
110+ } )
103111}
104112
105113#[ async_trait]
@@ -182,8 +190,8 @@ struct ContentEmbedding {
182190 values : Vec < f32 > ,
183191}
184192#[ derive( Deserialize ) ]
185- struct EmbedContentResponse {
186- embedding : ContentEmbedding ,
193+ struct BatchEmbedContentResponse {
194+ embeddings : Vec < ContentEmbedding > ,
187195}
188196
189197#[ async_trait]
@@ -192,29 +200,30 @@ impl LlmEmbeddingClient for AiStudioClient {
192200 & self ,
193201 request : super :: LlmEmbeddingRequest < ' req > ,
194202 ) -> Result < super :: LlmEmbeddingResponse > {
195- let url = self . get_api_url ( request. model , "embedContent" ) ;
203+ let url = self . get_api_url ( request. model , "batchEmbedContents" ) ;
204+ let texts: Vec < & str > = request. texts . iter ( ) . map ( |t| t. as_ref ( ) ) . collect ( ) ;
196205 let payload = build_embed_payload (
197206 request. model ,
198- request . text . as_ref ( ) ,
207+ & texts ,
199208 request. task_type . as_deref ( ) ,
200209 request. output_dimension ,
201210 ) ;
202- let resp = retryable:: run (
203- || async {
204- self . client
205- . post ( & url)
206- . json ( & payload)
207- . send ( )
208- . await ?
209- . error_for_status ( )
210- } ,
211- & retryable:: HEAVY_LOADED_OPTIONS ,
212- )
211+ let resp = http:: request ( || {
212+ self . client
213+ . post ( & url)
214+ . header ( "x-goog-api-key" , & self . api_key )
215+ . json ( & payload)
216+ } )
213217 . await
214218 . context ( "Gemini API error" ) ?;
215- let embedding_resp: EmbedContentResponse = resp. json ( ) . await . context ( "Invalid JSON" ) ?;
219+ let embedding_resp: BatchEmbedContentResponse =
220+ resp. json ( ) . await . context ( "Invalid JSON" ) ?;
216221 Ok ( super :: LlmEmbeddingResponse {
217- embedding : embedding_resp. embedding . values ,
222+ embeddings : embedding_resp
223+ . embeddings
224+ . into_iter ( )
225+ . map ( |e| e. values )
226+ . collect ( ) ,
218227 } )
219228 }
220229
@@ -381,15 +390,20 @@ impl LlmEmbeddingClient for VertexAiClient {
381390 request : super :: LlmEmbeddingRequest < ' req > ,
382391 ) -> Result < super :: LlmEmbeddingResponse > {
383392 // Create the instances for the request
384- let mut instance = serde_json:: json!( {
385- "content" : request. text
386- } ) ;
387- // Add task type if specified
388- if let Some ( task_type) = & request. task_type {
389- instance[ "task_type" ] = serde_json:: Value :: String ( task_type. to_string ( ) ) ;
390- }
391-
392- let instances = vec ! [ instance] ;
393+ let instances: Vec < _ > = request
394+ . texts
395+ . iter ( )
396+ . map ( |text| {
397+ let mut instance = serde_json:: json!( {
398+ "content" : text
399+ } ) ;
400+ // Add task type if specified
401+ if let Some ( task_type) = & request. task_type {
402+ instance[ "task_type" ] = serde_json:: Value :: String ( task_type. to_string ( ) ) ;
403+ }
404+ instance
405+ } )
406+ . collect ( ) ;
393407
394408 // Prepare the request parameters
395409 let mut parameters = serde_json:: json!( { } ) ;
@@ -408,17 +422,20 @@ impl LlmEmbeddingClient for VertexAiClient {
408422 . send ( )
409423 . await ?;
410424
411- // Extract the embedding from the response
412- let embeddings = response
425+ // Extract the embeddings from the response
426+ let embeddings: Vec < Vec < f32 > > = response
413427 . predictions
414428 . into_iter ( )
415- . next ( )
416- . and_then ( |mut e| e. get_mut ( "embeddings" ) . map ( |v| v. take ( ) ) )
417- . ok_or_else ( || anyhow:: anyhow!( "No embeddings in response" ) ) ?;
418- let embedding: ContentEmbedding = utils:: deser:: from_json_value ( embeddings) ?;
419- Ok ( super :: LlmEmbeddingResponse {
420- embedding : embedding. values ,
421- } )
429+ . map ( |mut prediction| {
430+ let embeddings = prediction
431+ . get_mut ( "embeddings" )
432+ . map ( |v| v. take ( ) )
433+ . ok_or_else ( || anyhow:: anyhow!( "No embeddings in prediction" ) ) ?;
434+ let embedding: ContentEmbedding = utils:: deser:: from_json_value ( embeddings) ?;
435+ Ok ( embedding. values )
436+ } )
437+ . collect :: < Result < _ > > ( ) ?;
438+ Ok ( super :: LlmEmbeddingResponse { embeddings } )
422439 }
423440
424441 fn get_default_embedding_dimension ( & self , model : & str ) -> Option < u32 > {
0 commit comments