Skip to content

Commit b9dfc7e

Browse files
committed
fixing issue of gemini-embedding-001 wrt outputDimensionality
1 parent 416962e commit b9dfc7e

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

src/llm/gemini.rs

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,30 @@ impl AiStudioClient {
7474
}
7575
}
7676

77+
fn build_embed_payload(
78+
model: &str,
79+
text: &str,
80+
task_type: Option<&str>,
81+
output_dimension: Option<u32>,
82+
) -> serde_json::Value {
83+
let mut payload = serde_json::json!({
84+
"model": model,
85+
"content": { "parts": [{ "text": text }] },
86+
});
87+
if let Some(task_type) = task_type {
88+
payload["taskType"] = serde_json::Value::String(task_type.to_string());
89+
}
90+
if let Some(output_dimension) = output_dimension {
91+
payload["outputDimensionality"] = serde_json::json!(output_dimension);
92+
if model.starts_with("gemini-embedding-") {
93+
payload["config"] = serde_json::json!({
94+
"outputDimensionality": output_dimension,
95+
});
96+
}
97+
}
98+
payload
99+
}
100+
77101
#[async_trait]
78102
impl LlmGenerationClient for AiStudioClient {
79103
async fn generate<'req>(
@@ -174,16 +198,12 @@ impl LlmEmbeddingClient for AiStudioClient {
174198
request: super::LlmEmbeddingRequest<'req>,
175199
) -> Result<super::LlmEmbeddingResponse> {
176200
let url = self.get_api_url(request.model, "embedContent");
177-
let mut payload = serde_json::json!({
178-
"model": request.model,
179-
"content": { "parts": [{ "text": request.text }] },
180-
});
181-
if let Some(task_type) = request.task_type {
182-
payload["taskType"] = serde_json::Value::String(task_type.into());
183-
}
184-
if let Some(output_dimension) = request.output_dimension {
185-
payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
186-
}
201+
let payload = build_embed_payload(
202+
request.model,
203+
request.text.as_ref(),
204+
request.task_type.as_deref(),
205+
request.output_dimension,
206+
);
187207
let resp = retryable::run(
188208
|| async {
189209
self.client

0 commit comments

Comments
 (0)