From 3223c8407c5fe1d7ed06f4c3df21914a904b0935 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Wed, 24 Sep 2025 18:43:58 -0700 Subject: [PATCH 1/2] fix(gemini): support `output_dimension` for Gemini --- src/llm/gemini.rs | 26 +++++++++++++++++--------- src/llm/mod.rs | 4 ++++ src/ops/functions/embed_text.rs | 2 +- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index d8cec6071..d246ad866 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -181,18 +181,22 @@ impl LlmEmbeddingClient for AiStudioClient { if let Some(task_type) = request.task_type { payload["taskType"] = serde_json::Value::String(task_type.into()); } + if let Some(output_dimension) = request.output_dimension { + payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into()); + } let resp = retryable::run( - || self.client.post(&url).json(&payload).send(), + || async { + self.client + .post(&url) + .json(&payload) + .send() + .await? + .error_for_status() + }, &retryable::HEAVY_LOADED_OPTIONS, ) - .await?; - if !resp.status().is_success() { - bail!( - "Gemini API error: {:?}\n{}\n", - resp.status(), - resp.text().await? - ); - } + .await + .context("Gemini API error")?; let embedding_resp: EmbedContentResponse = resp.json().await.context("Invalid JSON")?; Ok(super::LlmEmbeddingResponse { embedding: embedding_resp.embedding.values, @@ -202,6 +206,10 @@ impl LlmEmbeddingClient for AiStudioClient { fn get_default_embedding_dimension(&self, model: &str) -> Option { get_embedding_dimension(model) } + + fn behavior_version(&self) -> Option { + Some(2) + } } pub struct VertexAiClient { diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 184a9ab97..2145f2d12 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -99,6 +99,10 @@ pub trait LlmEmbeddingClient: Send + Sync { ) -> Result; fn get_default_embedding_dimension(&self, model: &str) -> Option; + + fn behavior_version(&self) -> Option { + Some(1) + } } mod anthropic; diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index 23da89f1b..95eee1791 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -28,7 +28,7 @@ struct Executor { #[async_trait] impl SimpleFunctionExecutor for Executor { fn behavior_version(&self) -> Option { - Some(1) + self.args.client.behavior_version() } fn enable_cache(&self) -> bool { From 2a73c6166b041a93bd0b5f5c2cab9c8a9f554fae Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Wed, 24 Sep 2025 19:04:32 -0700 Subject: [PATCH 2/2] fix: validate output vector dimension --- src/ops/functions/embed_text.rs | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index 95eee1791..bd8701581 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -18,6 +18,7 @@ struct Spec { struct Args { client: Box, text: ResolvedOpArg, + expected_output_dimension: usize, } struct Executor { @@ -48,6 +49,23 @@ impl SimpleFunctionExecutor for Executor { .map(|s| Cow::Borrowed(s.as_str())), }; let embedding = self.args.client.embed_text(req).await?; + if embedding.embedding.len() != self.args.expected_output_dimension { + if self.spec.output_dimension.is_some() { + api_bail!( + "Expected output dimension {expected} but got {actual} from the embedding API. \ + Consider setting `output_dimension` to {actual} or leave it unset to use the default.", + expected = self.args.expected_output_dimension, + actual = embedding.embedding.len() + ); + } else { + bail!( + "Expected output dimension {expected} but got {actual} from the embedding API. \ + Consider setting `output_dimension` to {actual} as a workaround.", + expected = self.args.expected_output_dimension, + actual = embedding.embedding.len() + ) + } + } Ok(embedding.embedding.into()) } } @@ -87,7 +105,14 @@ impl SimpleFunctionFactoryBase for Factory { dimension: Some(output_dimension as usize), element_type: Box::new(BasicValueType::Float32), })); - Ok((Args { client, text }, output_schema)) + Ok(( + Args { + client, + text, + expected_output_dimension: output_dimension as usize, + }, + output_schema, + )) } async fn build_executor(