26 changes: 17 additions & 9 deletions src/llm/gemini.rs
@@ -181,18 +181,22 @@ impl LlmEmbeddingClient for AiStudioClient
         if let Some(task_type) = request.task_type {
             payload["taskType"] = serde_json::Value::String(task_type.into());
         }
+        if let Some(output_dimension) = request.output_dimension {
+            payload["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
+        }
         let resp = retryable::run(
-            || self.client.post(&url).json(&payload).send(),
+            || async {
+                self.client
+                    .post(&url)
+                    .json(&payload)
+                    .send()
+                    .await?
+                    .error_for_status()
+            },
             &retryable::HEAVY_LOADED_OPTIONS,
         )
-        .await?;
-        if !resp.status().is_success() {
-            bail!(
-                "Gemini API error: {:?}\n{}\n",
-                resp.status(),
-                resp.text().await?
-            );
-        }
+        .await
+        .context("Gemini API error")?;
         let embedding_resp: EmbedContentResponse = resp.json().await.context("Invalid JSON")?;
         Ok(super::LlmEmbeddingResponse {
             embedding: embedding_resp.embedding.values,
@@ -202,6 +206,10 @@ impl LlmEmbeddingClient for AiStudioClient {
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
         get_embedding_dimension(model)
     }
+
+    fn behavior_version(&self) -> Option<u32> {
+        Some(2)
+    }
 }
 
 pub struct VertexAiClient {
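The key change in gemini.rs is that HTTP status checking now happens inside the retried closure: error_for_status() turns non-success responses (e.g. 429 or 5xx) into errors that the retry layer can see, so an overloaded Gemini API is retried instead of failing on the first attempt, and the manual status check after the call goes away. Below is a minimal sketch of the same pattern using a hand-rolled retry loop in place of the crate's retryable module; embed_with_retry and max_attempts are illustrative names, not part of the codebase.

use anyhow::Context;

// Sketch only: a plain retry loop standing in for `retryable::run`.
async fn embed_with_retry(
    client: &reqwest::Client,
    url: &str,
    payload: &serde_json::Value,
    max_attempts: usize,
) -> anyhow::Result<serde_json::Value> {
    let mut last_err = None;
    for _ in 0..max_attempts {
        let result = client
            .post(url)
            .json(payload)
            .send()
            .await
            // Convert HTTP 4xx/5xx into `Err`, so a failed status counts as a
            // failed attempt and is retried just like a transport error.
            .and_then(|resp| resp.error_for_status());
        match result {
            Ok(resp) => return resp.json().await.context("Invalid JSON"),
            Err(e) => last_err = Some(e),
        }
    }
    Err(last_err.expect("max_attempts must be > 0")).context("Gemini API error")
}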
4 changes: 4 additions & 0 deletions src/llm/mod.rs
@@ -99,6 +99,10 @@ pub trait LlmEmbeddingClient: Send + Sync {
     ) -> Result<LlmEmbeddingResponse>;
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32>;
+
+    fn behavior_version(&self) -> Option<u32> {
+        Some(1)
+    }
 }
 
 mod anthropic;
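The new trait method gives every embedding client a cache behavior version with a default of Some(1), so a single client can bump its own version (as the Gemini client does with Some(2), now that it honors the requested output dimension) and invalidate only its own cached results. A small self-contained sketch of the default-method pattern, with placeholder type names rather than the crate's real ones:

// Sketch of the default-method pattern; `EmbeddingClient`, `LegacyClient`, and
// `GeminiLikeClient` are illustrative names.
trait EmbeddingClient {
    // Every implementor reports version 1 unless it opts into a bump.
    fn behavior_version(&self) -> Option<u32> {
        Some(1)
    }
}

struct LegacyClient;
impl EmbeddingClient for LegacyClient {} // keeps the default

struct GeminiLikeClient;
impl EmbeddingClient for GeminiLikeClient {
    // The output changed, so bump the version; cached results produced by the
    // old behavior are then treated as stale.
    fn behavior_version(&self) -> Option<u32> {
        Some(2)
    }
}

fn main() {
    assert_eq!(LegacyClient.behavior_version(), Some(1));
    assert_eq!(GeminiLikeClient.behavior_version(), Some(2));
}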
29 changes: 27 additions & 2 deletions src/ops/functions/embed_text.rs
@@ -18,6 +18,7 @@ struct Spec {
 struct Args {
     client: Box<dyn LlmEmbeddingClient>,
     text: ResolvedOpArg,
+    expected_output_dimension: usize,
 }
 
 struct Executor {
@@ -28,7 +29,7 @@ struct Executor {
 #[async_trait]
 impl SimpleFunctionExecutor for Executor {
     fn behavior_version(&self) -> Option<u32> {
-        Some(1)
+        self.args.client.behavior_version()
     }
 
     fn enable_cache(&self) -> bool {
@@ -48,6 +49,23 @@ impl SimpleFunctionExecutor for Executor {
                 .map(|s| Cow::Borrowed(s.as_str())),
         };
         let embedding = self.args.client.embed_text(req).await?;
+        if embedding.embedding.len() != self.args.expected_output_dimension {
+            if self.spec.output_dimension.is_some() {
+                api_bail!(
+                    "Expected output dimension {expected} but got {actual} from the embedding API. \
+                     Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
+                    expected = self.args.expected_output_dimension,
+                    actual = embedding.embedding.len()
+                );
+            } else {
+                bail!(
+                    "Expected output dimension {expected} but got {actual} from the embedding API. \
+                     Consider setting `output_dimension` to {actual} as a workaround.",
+                    expected = self.args.expected_output_dimension,
+                    actual = embedding.embedding.len()
+                )
+            }
+        }
         Ok(embedding.embedding.into())
     }
 }
@@ -87,7 +105,14 @@ impl SimpleFunctionFactoryBase for Factory {
             dimension: Some(output_dimension as usize),
             element_type: Box::new(BasicValueType::Float32),
         }));
-        Ok((Args { client, text }, output_schema))
+        Ok((
+            Args {
+                client,
+                text,
+                expected_output_dimension: output_dimension as usize,
+            },
+            output_schema,
+        ))
     }
 
     async fn build_executor(
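In embed_text.rs the executor now records the dimension declared in the output schema and checks every API response against it, and the error message differs depending on whether output_dimension was set explicitly in the spec. A rough standalone sketch of that check, outside the crate's error and schema machinery (validate_embedding and its String error type are illustrative, not the crate's API):

// Sketch only: mirrors the dimension check added to the executor.
fn validate_embedding(
    values: Vec<f32>,
    expected_dimension: usize,
    dimension_set_by_user: bool,
) -> Result<Vec<f32>, String> {
    if values.len() == expected_dimension {
        return Ok(values);
    }
    let hint = if dimension_set_by_user {
        // The user chose `output_dimension`; suggest matching what the API actually
        // returned, or dropping the setting to fall back to the model default.
        format!(
            "Consider setting `output_dimension` to {} or leaving it unset to use the default.",
            values.len()
        )
    } else {
        // The built-in default dimension disagrees with the API's actual output.
        format!(
            "Consider setting `output_dimension` to {} as a workaround.",
            values.len()
        )
    };
    Err(format!(
        "Expected output dimension {} but got {} from the embedding API. {}",
        expected_dimension,
        values.len(),
        hint
    ))
}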