diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index 1d1ffafb..0e6ce1bc 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -70,36 +70,44 @@ fn remove_additional_properties(value: &mut Value) {
 impl AiStudioClient {
     fn get_api_url(&self, model: &str, api_name: &str) -> String {
         format!(
-            "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}",
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:{}",
             encode(model),
-            api_name,
-            encode(&self.api_key)
+            api_name
         )
     }
 }
 
 fn build_embed_payload(
     model: &str,
-    text: &str,
+    texts: &[&str],
     task_type: Option<&str>,
     output_dimension: Option<u32>,
 ) -> serde_json::Value {
-    let mut payload = serde_json::json!({
-        "model": model,
-        "content": { "parts": [{ "text": text }] },
-    });
-    if let Some(task_type) = task_type {
-        payload["taskType"] = serde_json::Value::String(task_type.to_string());
-    }
-    if let Some(output_dimension) = output_dimension {
-        payload["outputDimensionality"] = serde_json::json!(output_dimension);
-        if model.starts_with("gemini-embedding-") {
-            payload["config"] = serde_json::json!({
-                "outputDimensionality": output_dimension,
+    let requests: Vec<_> = texts
+        .iter()
+        .map(|text| {
+            let mut req = serde_json::json!({
+                "model": format!("models/{}", model),
+                "content": { "parts": [{ "text": text }] },
             });
-        }
-    }
-    payload
+            if let Some(task_type) = task_type {
+                req["taskType"] = serde_json::Value::String(task_type.to_string());
+            }
+            if let Some(output_dimension) = output_dimension {
+                req["outputDimensionality"] = serde_json::json!(output_dimension);
+                if model.starts_with("gemini-embedding-") {
+                    req["config"] = serde_json::json!({
+                        "outputDimensionality": output_dimension,
+                    });
+                }
+            }
+            req
+        })
+        .collect();
+
+    serde_json::json!({
+        "requests": requests,
+    })
 }
 
 #[async_trait]
@@ -182,8 +190,8 @@ struct ContentEmbedding {
     values: Vec<f32>,
 }
 
 #[derive(Deserialize)]
-struct EmbedContentResponse {
-    embedding: ContentEmbedding,
+struct BatchEmbedContentResponse {
+    embeddings: Vec<ContentEmbedding>,
 }
 #[async_trait]
@@ -192,29 +200,30 @@ impl LlmEmbeddingClient for AiStudioClient {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
-        let url = self.get_api_url(request.model, "embedContent");
+        let url = self.get_api_url(request.model, "batchEmbedContents");
+        let texts: Vec<&str> = request.texts.iter().map(|t| t.as_ref()).collect();
         let payload = build_embed_payload(
             request.model,
-            request.text.as_ref(),
+            &texts,
             request.task_type.as_deref(),
             request.output_dimension,
         );
-        let resp = retryable::run(
-            || async {
-                self.client
-                    .post(&url)
-                    .json(&payload)
-                    .send()
-                    .await?
-                    .error_for_status()
-            },
-            &retryable::HEAVY_LOADED_OPTIONS,
-        )
+        let resp = http::request(|| {
+            self.client
+                .post(&url)
+                .header("x-goog-api-key", &self.api_key)
+                .json(&payload)
+        })
         .await
         .context("Gemini API error")?;
-        let embedding_resp: EmbedContentResponse = resp.json().await.context("Invalid JSON")?;
+        let embedding_resp: BatchEmbedContentResponse =
+            resp.json().await.context("Invalid JSON")?;
         Ok(super::LlmEmbeddingResponse {
-            embedding: embedding_resp.embedding.values,
+            embeddings: embedding_resp
+                .embeddings
+                .into_iter()
+                .map(|e| e.values)
+                .collect(),
         })
     }
 
@@ -381,15 +390,20 @@ impl LlmEmbeddingClient for VertexAiClient {
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
         // Create the instances for the request
-        let mut instance = serde_json::json!({
-            "content": request.text
-        });
-        // Add task type if specified
-        if let Some(task_type) = &request.task_type {
-            instance["task_type"] = serde_json::Value::String(task_type.to_string());
-        }
-
-        let instances = vec![instance];
+        let instances: Vec<_> = request
+            .texts
+            .iter()
+            .map(|text| {
+                let mut instance = serde_json::json!({
+                    "content": text
+                });
+                // Add task type if specified
+                if let Some(task_type) = &request.task_type {
+                    instance["task_type"] = serde_json::Value::String(task_type.to_string());
+                }
+                instance
+            })
+            .collect();
 
         // Prepare the request parameters
         let mut parameters = serde_json::json!({});
@@ -408,17 +422,20 @@ impl LlmEmbeddingClient for VertexAiClient {
             .send()
             .await?;
 
-        // Extract the embedding from the response
-        let embeddings = response
+        // Extract the embeddings from the response
+        let embeddings: Vec<Vec<f32>> = response
             .predictions
             .into_iter()
-            .next()
-            .and_then(|mut e| e.get_mut("embeddings").map(|v| v.take()))
-            .ok_or_else(|| anyhow::anyhow!("No embeddings in response"))?;
-        let embedding: ContentEmbedding = utils::deser::from_json_value(embeddings)?;
-        Ok(super::LlmEmbeddingResponse {
-            embedding: embedding.values,
-        })
+            .map(|mut prediction| {
+                let embeddings = prediction
+                    .get_mut("embeddings")
+                    .map(|v| v.take())
+                    .ok_or_else(|| anyhow::anyhow!("No embeddings in prediction"))?;
+                let embedding: ContentEmbedding = utils::deser::from_json_value(embeddings)?;
+                Ok(embedding.values)
+            })
+            .collect::<Result<_>>()?;
+        Ok(super::LlmEmbeddingResponse { embeddings })
     }
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 12eda662..00df51a2 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -83,13 +83,13 @@ pub trait LlmGenerationClient: Send + Sync {
 #[derive(Debug)]
 pub struct LlmEmbeddingRequest<'a> {
     pub model: &'a str,
-    pub text: Cow<'a, str>,
+    pub texts: Vec<Cow<'a, str>>,
     pub output_dimension: Option<u32>,
     pub task_type: Option<Cow<'a, str>>,
 }
 
 pub struct LlmEmbeddingResponse {
-    pub embedding: Vec<f32>,
+    pub embeddings: Vec<Vec<f32>>,
 }
 
 #[async_trait]
diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs
index 4b3c6c0f..b02a6ddc 100644
--- a/src/llm/ollama.rs
+++ b/src/llm/ollama.rs
@@ -59,7 +59,7 @@ struct OllamaResponse {
 #[derive(Debug, Serialize)]
 struct OllamaEmbeddingRequest<'a> {
     pub model: &'a str,
-    pub input: &'a str,
+    pub input: Vec<&'a str>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -130,9 +130,10 @@ impl LlmEmbeddingClient for Client {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
+        let texts: Vec<&str> = request.texts.iter().map(|t| t.as_ref()).collect();
         let req = OllamaEmbeddingRequest {
             model: request.model,
-            input: request.text.as_ref(),
+            input: texts,
         };
         let resp = http::request(|| self.reqwest_client.post(self.embed_url.as_str()).json(&req))
             .await
@@ -140,14 +141,9 @@ impl LlmEmbeddingClient for Client {
         let embedding_resp: OllamaEmbeddingResponse =
             resp.json().await.context("Invalid JSON")?;
 
-        // Extract the first embedding (index 0)
-        let embedding = embedding_resp
-            .embeddings
-            .into_iter()
-            .next()
-            .context("Ollama API returned no embeddings")?;
-
-        Ok(super::LlmEmbeddingResponse { embedding })
+        Ok(super::LlmEmbeddingResponse {
+            embeddings: embedding_resp.embeddings,
+        })
     }
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
diff --git a/src/llm/openai.rs b/src/llm/openai.rs
index 68ec6421..cf22c54d 100644
--- a/src/llm/openai.rs
+++ b/src/llm/openai.rs
@@ -184,11 +184,12 @@ impl LlmEmbeddingClient for Client {
     ) -> Result<super::LlmEmbeddingResponse> {
         let response = retryable::run(
             || async {
+                let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
                 self.client
                     .embeddings()
                     .create(CreateEmbeddingRequest {
                         model: request.model.to_string(),
-                        input: EmbeddingInput::String(request.text.to_string()),
+                        input: EmbeddingInput::StringArray(texts),
                         dimensions: request.output_dimension,
                         ..Default::default()
                     })
@@ -198,12 +199,7 @@ impl LlmEmbeddingClient for Client {
         )
         .await?;
         Ok(super::LlmEmbeddingResponse {
-            embedding: response
-                .data
-                .into_iter()
-                .next()
-                .ok_or_else(|| anyhow::anyhow!("No embedding returned from OpenAI"))?
-                .embedding,
+            embeddings: response.data.into_iter().map(|e| e.embedding).collect(),
         })
     }
 
diff --git a/src/llm/voyage.rs b/src/llm/voyage.rs
index 40972d8b..304ec41a 100644
--- a/src/llm/voyage.rs
+++ b/src/llm/voyage.rs
@@ -66,8 +66,9 @@ impl LlmEmbeddingClient for Client {
     ) -> Result<LlmEmbeddingResponse> {
         let url = "https://api.voyageai.com/v1/embeddings";
 
+        let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
         let mut payload = serde_json::json!({
-            "input": request.text,
+            "input": texts,
             "model": request.model,
         });
 
@@ -86,12 +87,12 @@ impl LlmEmbeddingClient for Client {
         let embedding_resp: EmbedResponse = resp.json().await.context("Invalid JSON")?;
 
-        if embedding_resp.data.is_empty() {
-            bail!("No embedding data in response");
-        }
-
         Ok(LlmEmbeddingResponse {
-            embedding: embedding_resp.data[0].embedding.clone(),
+            embeddings: embedding_resp
+                .data
+                .into_iter()
+                .map(|d| d.embedding)
+                .collect(),
         })
     }
 
diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs
index b6e5426e..acaa3e8d 100644
--- a/src/ops/factory_bases.rs
+++ b/src/ops/factory_bases.rs
@@ -377,8 +377,8 @@ pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
         None
     }
 
-    fn into_fn_executor(self) -> Box<dyn SimpleFunctionExecutor> {
-        Box::new(BatchedFunctionExecutorWrapper::new(self))
+    fn into_fn_executor(self) -> impl SimpleFunctionExecutor {
+        BatchedFunctionExecutorWrapper::new(self)
     }
 }
 
diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs
index bd870158..2efb3fa5 100644
--- a/src/ops/functions/embed_text.rs
+++ b/src/ops/functions/embed_text.rs
@@ -27,7 +27,7 @@ struct Executor {
 }
 
 #[async_trait]
-impl SimpleFunctionExecutor for Executor {
+impl BatchedFunctionExecutor for Executor {
     fn behavior_version(&self) -> Option<u32> {
         self.args.client.behavior_version()
     }
@@ -36,11 +36,18 @@ impl SimpleFunctionExecutor for Executor {
         true
     }
 
-    async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
-        let text = self.args.text.value(&input)?.as_str()?;
+    async fn evaluate_batch(&self, args: Vec<Vec<Value>>) -> Result<Vec<Value>> {
+        let texts = args
+            .iter()
+            .map(|arg| {
+                Ok(Cow::Borrowed(
+                    self.args.text.value(&arg)?.as_str()?.as_ref(),
+                ))
+            })
+            .collect::<Result<_>>()?;
         let req = LlmEmbeddingRequest {
             model: &self.spec.model,
-            text: Cow::Borrowed(text),
+            texts,
             output_dimension: self.spec.output_dimension,
             task_type: self
                 .spec
@@ -48,25 +55,37 @@
                 .as_ref()
                 .map(|s| Cow::Borrowed(s.as_str())),
         };
-        let embedding = self.args.client.embed_text(req).await?;
-        if embedding.embedding.len() != self.args.expected_output_dimension {
-            if self.spec.output_dimension.is_some() {
-                api_bail!(
-                    "Expected output dimension {expected} but got {actual} from the embedding API. \
-                    Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
-                    expected = self.args.expected_output_dimension,
-                    actual = embedding.embedding.len()
-                );
-            } else {
-                bail!(
-                    "Expected output dimension {expected} but got {actual} from the embedding API. \
-                    Consider setting `output_dimension` to {actual} as a workaround.",
-                    expected = self.args.expected_output_dimension,
-                    actual = embedding.embedding.len()
-                )
-            }
+        let resp = self.args.client.embed_text(req).await?;
+        if resp.embeddings.len() != args.len() {
+            api_bail!(
+                "Expected {expected} embeddings but got {actual} from the embedding API.",
+                expected = args.len(),
+                actual = resp.embeddings.len()
+            );
         }
-        Ok(embedding.embedding.into())
+        resp.embeddings
+            .into_iter()
+            .map(|embedding| {
+                if embedding.len() != self.args.expected_output_dimension {
+                    if self.spec.output_dimension.is_some() {
+                        api_bail!(
+                            "Expected output dimension {expected} but got {actual} from the embedding API. \
+                            Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
+                            expected = self.args.expected_output_dimension,
+                            actual = embedding.len(),
+                        );
+                    } else {
+                        bail!(
+                            "Expected output dimension {expected} but got {actual} from the embedding API. \
+                            Consider setting `output_dimension` to {actual} as a workaround.",
+                            expected = self.args.expected_output_dimension,
+                            actual = embedding.len(),
+                        );
+                    }
+                };
+                Ok(embedding.into())
+            })
+            .collect::<Result<Vec<Value>>>()
     }
 }
 
@@ -121,7 +140,7 @@ impl SimpleFunctionFactoryBase for Factory {
         args: Args,
         _context: Arc<FlowInstanceContext>,
     ) -> Result<impl SimpleFunctionExecutor> {
-        Ok(Executor { spec, args })
+        Ok(Executor { spec, args }.into_fn_executor())
     }
 }
 
diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs
index be16b5fb..885797f9 100644
--- a/src/ops/py_factory.rs
+++ b/src/ops/py_factory.rs
@@ -260,15 +260,17 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
             Ok((prepare_fut, enable_cache, behavior_version))
         })?;
         prepare_fut.await?;
-        let executor = if self.batching {
-            PyBatchedFunctionExecutor {
-                py_function_executor: executor,
-                py_exec_ctx,
-                result_type,
-                enable_cache,
-                behavior_version,
-            }
-            .into_fn_executor()
+        let executor: Box<dyn SimpleFunctionExecutor> = if self.batching {
+            Box::new(
+                PyBatchedFunctionExecutor {
+                    py_function_executor: executor,
+                    py_exec_ctx,
+                    result_type,
+                    enable_cache,
+                    behavior_version,
+                }
+                .into_fn_executor(),
+            )
         } else {
             Box::new(Arc::new(PyFunctionExecutor {
                 py_function_executor: executor,
@@ -278,7 +280,7 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
                 result_type,
                 enable_cache,
                 behavior_version,
-            })) as Box<dyn SimpleFunctionExecutor>
+            }))
         };
         Ok(executor)
     }
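
Usage note (illustrative, not part of the patch): a minimal caller-side sketch of the batched interface from src/llm/mod.rs above. The helper name and model string are hypothetical; any `LlmEmbeddingClient` implementation (Gemini, Ollama, OpenAI, Voyage) is used the same way.

    use std::borrow::Cow;

    // Hypothetical helper; assumes `LlmEmbeddingClient`, `LlmEmbeddingRequest`,
    // and `LlmEmbeddingResponse` are in scope.
    async fn embed_chunks(client: &dyn LlmEmbeddingClient) -> anyhow::Result<()> {
        let request = LlmEmbeddingRequest {
            model: "some-embedding-model", // hypothetical model name
            texts: vec![Cow::Borrowed("first chunk"), Cow::Borrowed("second chunk")],
            output_dimension: None,
            task_type: None,
        };
        let response = client.embed_text(request).await?;
        // The response carries one embedding per input text, in the same order.
        assert_eq!(response.embeddings.len(), 2);
        Ok(())
    }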
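For reference, a sketch of the body the reworked build_embed_payload assembles for Gemini's batchEmbedContents endpoint, written as a serde_json::json! literal. Values are illustrative: two texts, task type "RETRIEVAL_DOCUMENT", and output_dimension 768 on a gemini-embedding-* model ("taskType" and "outputDimensionality" appear only when the corresponding options are set, and "config" only for gemini-embedding-* models).

    // Illustrative shape only; mirrors build_embed_payload above.
    let _expected_payload = serde_json::json!({
        "requests": [
            {
                "model": "models/gemini-embedding-001",
                "content": { "parts": [{ "text": "first chunk" }] },
                "taskType": "RETRIEVAL_DOCUMENT",
                "outputDimensionality": 768,
                "config": { "outputDimensionality": 768 }
            },
            // ...one request object per input text
        ]
    });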