125 changes: 71 additions & 54 deletions src/llm/gemini.rs
@@ -70,36 +70,44 @@ fn remove_additional_properties(value: &mut Value) {
 impl AiStudioClient {
     fn get_api_url(&self, model: &str, api_name: &str) -> String {
         format!(
-            "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}",
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:{}",
             encode(model),
-            api_name,
-            encode(&self.api_key)
+            api_name
         )
     }
 }
 
 fn build_embed_payload(
     model: &str,
-    text: &str,
+    texts: &[&str],
     task_type: Option<&str>,
     output_dimension: Option<u32>,
 ) -> serde_json::Value {
-    let mut payload = serde_json::json!({
-        "model": model,
-        "content": { "parts": [{ "text": text }] },
-    });
-    if let Some(task_type) = task_type {
-        payload["taskType"] = serde_json::Value::String(task_type.to_string());
-    }
-    if let Some(output_dimension) = output_dimension {
-        payload["outputDimensionality"] = serde_json::json!(output_dimension);
-        if model.starts_with("gemini-embedding-") {
-            payload["config"] = serde_json::json!({
-                "outputDimensionality": output_dimension,
-            });
-        }
-    }
-    payload
+    let requests: Vec<_> = texts
+        .iter()
+        .map(|text| {
+            let mut req = serde_json::json!({
+                "model": format!("models/{}", model),
+                "content": { "parts": [{ "text": text }] },
+            });
+            if let Some(task_type) = task_type {
+                req["taskType"] = serde_json::Value::String(task_type.to_string());
+            }
+            if let Some(output_dimension) = output_dimension {
+                req["outputDimensionality"] = serde_json::json!(output_dimension);
+                if model.starts_with("gemini-embedding-") {
+                    req["config"] = serde_json::json!({
+                        "outputDimensionality": output_dimension,
+                    });
+                }
+            }
+            req
+        })
+        .collect();
+
+    serde_json::json!({
+        "requests": requests,
+    })
 }
 
 #[async_trait]
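
For reference, a sketch of the request body that the new `build_embed_payload` assembles for `batchEmbedContents`. This is stand-alone illustration code, not part of the PR; it assumes `serde_json`, and the model name and texts are placeholders:

```rust
use serde_json::json;

fn main() {
    let model = "gemini-embedding-001"; // example model name
    let texts = ["first chunk", "second chunk"];
    let requests: Vec<_> = texts
        .iter()
        .map(|text| {
            json!({
                // Each per-text request carries the fully qualified model name.
                "model": format!("models/{}", model),
                "content": { "parts": [{ "text": text }] },
            })
        })
        .collect();
    // The batch endpoint expects the per-text requests under one envelope.
    let payload = json!({ "requests": requests });
    println!("{}", serde_json::to_string_pretty(&payload).unwrap());
}
```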
@@ -182,8 +190,8 @@ struct ContentEmbedding {
     values: Vec<f32>,
 }
 #[derive(Deserialize)]
-struct EmbedContentResponse {
-    embedding: ContentEmbedding,
+struct BatchEmbedContentResponse {
+    embeddings: Vec<ContentEmbedding>,
 }
 
 #[async_trait]
@@ -192,29 +200,30 @@ impl LlmEmbeddingClient for AiStudioClient {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
-        let url = self.get_api_url(request.model, "embedContent");
+        let url = self.get_api_url(request.model, "batchEmbedContents");
+        let texts: Vec<&str> = request.texts.iter().map(|t| t.as_ref()).collect();
         let payload = build_embed_payload(
             request.model,
-            request.text.as_ref(),
+            &texts,
             request.task_type.as_deref(),
             request.output_dimension,
         );
-        let resp = retryable::run(
-            || async {
-                self.client
-                    .post(&url)
-                    .json(&payload)
-                    .send()
-                    .await?
-                    .error_for_status()
-            },
-            &retryable::HEAVY_LOADED_OPTIONS,
-        )
+        let resp = http::request(|| {
+            self.client
+                .post(&url)
+                .header("x-goog-api-key", &self.api_key)
+                .json(&payload)
+        })
         .await
         .context("Gemini API error")?;
-        let embedding_resp: EmbedContentResponse = resp.json().await.context("Invalid JSON")?;
+        let embedding_resp: BatchEmbedContentResponse =
+            resp.json().await.context("Invalid JSON")?;
         Ok(super::LlmEmbeddingResponse {
-            embedding: embedding_resp.embedding.values,
+            embeddings: embedding_resp
+                .embeddings
+                .into_iter()
+                .map(|e| e.values)
+                .collect(),
         })
     }
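
Note the auth change folded into this hunk: the API key moves from a `?key=` query parameter into the `x-goog-api-key` header, so it no longer shows up in URLs (and in whatever logs them). A minimal sketch of the resulting request construction, assuming `reqwest` with its `json` feature; `build_embed_request` is a hypothetical helper, not code from this PR:

```rust
use serde_json::Value;

// Hypothetical helper mirroring the diff's auth handling.
fn build_embed_request(
    client: &reqwest::Client,
    url: &str,
    api_key: &str,
    payload: &Value,
) -> reqwest::RequestBuilder {
    client
        .post(url)
        // The key travels in a header rather than the URL query string.
        .header("x-goog-api-key", api_key)
        .json(payload)
}
```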

@@ -381,15 +390,20 @@ impl LlmEmbeddingClient for VertexAiClient {
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
         // Create the instances for the request
-        let mut instance = serde_json::json!({
-            "content": request.text
-        });
-        // Add task type if specified
-        if let Some(task_type) = &request.task_type {
-            instance["task_type"] = serde_json::Value::String(task_type.to_string());
-        }
-
-        let instances = vec![instance];
+        let instances: Vec<_> = request
+            .texts
+            .iter()
+            .map(|text| {
+                let mut instance = serde_json::json!({
+                    "content": text
+                });
+                // Add task type if specified
+                if let Some(task_type) = &request.task_type {
+                    instance["task_type"] = serde_json::Value::String(task_type.to_string());
+                }
+                instance
+            })
+            .collect();
 
         // Prepare the request parameters
         let mut parameters = serde_json::json!({});
@@ -408,17 +422,20 @@
             .send()
             .await?;
 
-        // Extract the embedding from the response
-        let embeddings = response
+        // Extract the embeddings from the response
+        let embeddings: Vec<Vec<f32>> = response
             .predictions
             .into_iter()
-            .next()
-            .and_then(|mut e| e.get_mut("embeddings").map(|v| v.take()))
-            .ok_or_else(|| anyhow::anyhow!("No embeddings in response"))?;
-        let embedding: ContentEmbedding = utils::deser::from_json_value(embeddings)?;
-        Ok(super::LlmEmbeddingResponse {
-            embedding: embedding.values,
-        })
+            .map(|mut prediction| {
+                let embeddings = prediction
+                    .get_mut("embeddings")
+                    .map(|v| v.take())
+                    .ok_or_else(|| anyhow::anyhow!("No embeddings in prediction"))?;
+                let embedding: ContentEmbedding = utils::deser::from_json_value(embeddings)?;
+                Ok(embedding.values)
+            })
+            .collect::<Result<_>>()?;
+        Ok(super::LlmEmbeddingResponse { embeddings })
     }
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
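
The Vertex AI path takes the same shape: one instance per text, and one prediction expected back per instance, in order. A stand-alone sketch of the instances payload, assuming `serde_json`; the task type value is an example:

```rust
use serde_json::json;

fn main() {
    let texts = ["first chunk", "second chunk"];
    let task_type: Option<&str> = Some("RETRIEVAL_DOCUMENT"); // example value
    let instances: Vec<_> = texts
        .iter()
        .map(|text| {
            let mut instance = json!({ "content": text });
            if let Some(task_type) = task_type {
                instance["task_type"] = json!(task_type);
            }
            instance
        })
        .collect();
    // The response is expected to hold one prediction per instance.
    let body = json!({ "instances": instances });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```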
4 changes: 2 additions & 2 deletions src/llm/mod.rs
@@ -83,13 +83,13 @@ pub trait LlmGenerationClient: Send + Sync {
 #[derive(Debug)]
 pub struct LlmEmbeddingRequest<'a> {
     pub model: &'a str,
-    pub text: Cow<'a, str>,
+    pub texts: Vec<Cow<'a, str>>,
     pub output_dimension: Option<u32>,
     pub task_type: Option<Cow<'a, str>>,
 }
 
 pub struct LlmEmbeddingResponse {
-    pub embedding: Vec<f32>,
+    pub embeddings: Vec<Vec<f32>>,
 }
 
 #[async_trait]
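
The contract these types imply: callers pass any number of texts, and every backend must return exactly one embedding per text, in input order. A stand-alone sketch with stand-ins for the crate's types:

```rust
use std::borrow::Cow;

// Stand-ins for the types in src/llm/mod.rs, shown to illustrate the
// expected shape; not the crate's actual definitions.
struct LlmEmbeddingRequest<'a> {
    model: &'a str,
    texts: Vec<Cow<'a, str>>,
}

struct LlmEmbeddingResponse {
    embeddings: Vec<Vec<f32>>,
}

fn main() {
    let request = LlmEmbeddingRequest {
        model: "example-model",
        texts: vec![Cow::Borrowed("a"), Cow::Borrowed("b")],
    };
    // A conforming backend returns embeddings[i] for texts[i].
    let response = LlmEmbeddingResponse {
        embeddings: vec![vec![0.0; 8]; request.texts.len()],
    };
    assert_eq!(response.embeddings.len(), request.texts.len());
    println!("{}: ok", request.model);
}
```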
16 changes: 6 additions & 10 deletions src/llm/ollama.rs
@@ -59,7 +59,7 @@ struct OllamaResponse {
 #[derive(Debug, Serialize)]
 struct OllamaEmbeddingRequest<'a> {
     pub model: &'a str,
-    pub input: &'a str,
+    pub input: Vec<&'a str>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -130,24 +130,20 @@ impl LlmEmbeddingClient for Client {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
+        let texts: Vec<&str> = request.texts.iter().map(|t| t.as_ref()).collect();
         let req = OllamaEmbeddingRequest {
             model: request.model,
-            input: request.text.as_ref(),
+            input: texts,
         };
         let resp = http::request(|| self.reqwest_client.post(self.embed_url.as_str()).json(&req))
             .await
             .context("Ollama API error")?;
 
         let embedding_resp: OllamaEmbeddingResponse = resp.json().await.context("Invalid JSON")?;
 
-        // Extract the first embedding (index 0)
-        let embedding = embedding_resp
-            .embeddings
-            .into_iter()
-            .next()
-            .context("Ollama API returned no embeddings")?;
-
-        Ok(super::LlmEmbeddingResponse { embedding })
+        Ok(super::LlmEmbeddingResponse {
+            embeddings: embedding_resp.embeddings,
+        })
     }
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
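
Ollama's `/api/embed` accepts either a single string or an array for `input`; always sending an array means the response's `embeddings` list can be passed through unchanged, and the old take-the-first-element fallback disappears. A sketch of the serialized request, assuming `serde`/`serde_json`; the model name is an example:

```rust
use serde::Serialize;

#[derive(Serialize)]
struct OllamaEmbeddingRequest<'a> {
    model: &'a str,
    input: Vec<&'a str>,
}

fn main() {
    let req = OllamaEmbeddingRequest {
        model: "nomic-embed-text", // example model
        input: vec!["first chunk", "second chunk"],
    };
    // Serializes to {"model":"nomic-embed-text","input":["first chunk","second chunk"]}
    println!("{}", serde_json::to_string(&req).unwrap());
}
```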
10 changes: 3 additions & 7 deletions src/llm/openai.rs
@@ -184,11 +184,12 @@ impl LlmEmbeddingClient for Client {
     ) -> Result<super::LlmEmbeddingResponse> {
         let response = retryable::run(
             || async {
+                let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
                 self.client
                     .embeddings()
                     .create(CreateEmbeddingRequest {
                         model: request.model.to_string(),
-                        input: EmbeddingInput::String(request.text.to_string()),
+                        input: EmbeddingInput::StringArray(texts),
                         dimensions: request.output_dimension,
                         ..Default::default()
                     })
@@ -198,12 +199,7 @@
         )
         .await?;
         Ok(super::LlmEmbeddingResponse {
-            embedding: response
-                .data
-                .into_iter()
-                .next()
-                .ok_or_else(|| anyhow::anyhow!("No embedding returned from OpenAI"))?
-                .embedding,
+            embeddings: response.data.into_iter().map(|e| e.embedding).collect(),
         })
     }
 
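
`EmbeddingInput::StringArray` sends the whole batch in one call, and the response's `data` entries correspond to the inputs, so a plain `map` preserves order. A sketch of the request construction using the `async-openai` types the file already imports:

```rust
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};

// Builds a batched embedding request; `dimensions` is passed through
// from the caller, as in the diff.
fn build_request(model: &str, texts: &[&str], dimensions: Option<u32>) -> CreateEmbeddingRequest {
    CreateEmbeddingRequest {
        model: model.to_string(),
        input: EmbeddingInput::StringArray(texts.iter().map(|t| t.to_string()).collect()),
        dimensions,
        ..Default::default()
    }
}
```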
13 changes: 7 additions & 6 deletions src/llm/voyage.rs
@@ -66,8 +66,9 @@ impl LlmEmbeddingClient for Client {
     ) -> Result<LlmEmbeddingResponse> {
         let url = "https://api.voyageai.com/v1/embeddings";
 
+        let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
         let mut payload = serde_json::json!({
-            "input": request.text,
+            "input": texts,
             "model": request.model,
         });
 
@@ -86,12 +87,12 @@
 
         let embedding_resp: EmbedResponse = resp.json().await.context("Invalid JSON")?;
 
-        if embedding_resp.data.is_empty() {
-            bail!("No embedding data in response");
-        }
-
         Ok(LlmEmbeddingResponse {
-            embedding: embedding_resp.data[0].embedding.clone(),
+            embeddings: embedding_resp
+                .data
+                .into_iter()
+                .map(|d| d.embedding)
+                .collect(),
         })
     }
 
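
Voyage's `/v1/embeddings` likewise takes a list under `input`. Mapping `data` with `into_iter` also drops the per-call `clone()` that the old `data[0].embedding.clone()` paid. A sketch of the batched payload, assuming `serde_json`; the model name is an example:

```rust
use serde_json::json;

fn main() {
    let payload = json!({
        "input": ["first chunk", "second chunk"],
        "model": "voyage-3", // example model name
    });
    println!("{}", serde_json::to_string_pretty(&payload).unwrap());
}
```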
4 changes: 2 additions & 2 deletions src/ops/factory_bases.rs
@@ -377,8 +377,8 @@ pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
         None
     }
 
-    fn into_fn_executor(self) -> Box<dyn SimpleFunctionExecutor> {
-        Box::new(BatchedFunctionExecutorWrapper::new(self))
+    fn into_fn_executor(self) -> impl SimpleFunctionExecutor {
+        BatchedFunctionExecutorWrapper::new(self)
     }
 }
 
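
Returning `impl SimpleFunctionExecutor` (stable for trait methods since Rust 1.75) avoids the forced heap allocation and lets the caller keep the concrete type, boxing only where dynamic dispatch is actually needed. A self-contained sketch with stand-in traits, not the crate's real definitions:

```rust
trait SimpleFunctionExecutor {}

struct BatchedFunctionExecutorWrapper<E>(E);
impl<E> SimpleFunctionExecutor for BatchedFunctionExecutorWrapper<E> {}

trait BatchedFunctionExecutor: Sized {
    // Before: `-> Box<dyn SimpleFunctionExecutor>`, allocating on every call.
    // After: the wrapper is returned directly; callers that still need a
    // trait object can `Box::new` the result themselves.
    fn into_fn_executor(self) -> impl SimpleFunctionExecutor {
        BatchedFunctionExecutorWrapper(self)
    }
}
```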
65 changes: 42 additions & 23 deletions src/ops/functions/embed_text.rs
@@ -27,7 +27,7 @@ struct Executor {
 }
 
 #[async_trait]
-impl SimpleFunctionExecutor for Executor {
+impl BatchedFunctionExecutor for Executor {
     fn behavior_version(&self) -> Option<u32> {
         self.args.client.behavior_version()
     }
@@ -36,37 +36,56 @@ impl SimpleFunctionExecutor for Executor {
         true
     }
 
-    async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
-        let text = self.args.text.value(&input)?.as_str()?;
+    async fn evaluate_batch(&self, args: Vec<Vec<Value>>) -> Result<Vec<Value>> {
+        let texts = args
+            .iter()
+            .map(|arg| {
+                Ok(Cow::Borrowed(
+                    self.args.text.value(&arg)?.as_str()?.as_ref(),
+                ))
+            })
+            .collect::<Result<_>>()?;
         let req = LlmEmbeddingRequest {
             model: &self.spec.model,
-            text: Cow::Borrowed(text),
+            texts,
             output_dimension: self.spec.output_dimension,
             task_type: self
                 .spec
                 .task_type
                 .as_ref()
                 .map(|s| Cow::Borrowed(s.as_str())),
         };
-        let embedding = self.args.client.embed_text(req).await?;
-        if embedding.embedding.len() != self.args.expected_output_dimension {
-            if self.spec.output_dimension.is_some() {
-                api_bail!(
-                    "Expected output dimension {expected} but got {actual} from the embedding API. \
-                    Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
-                    expected = self.args.expected_output_dimension,
-                    actual = embedding.embedding.len()
-                );
-            } else {
-                bail!(
-                    "Expected output dimension {expected} but got {actual} from the embedding API. \
-                    Consider setting `output_dimension` to {actual} as a workaround.",
-                    expected = self.args.expected_output_dimension,
-                    actual = embedding.embedding.len()
-                )
-            }
-        }
-        Ok(embedding.embedding.into())
+        let resp = self.args.client.embed_text(req).await?;
+        if resp.embeddings.len() != args.len() {
+            api_bail!(
+                "Expected {expected} embeddings but got {actual} from the embedding API.",
+                expected = args.len(),
+                actual = resp.embeddings.len()
+            );
+        }
+        resp.embeddings
+            .into_iter()
+            .map(|embedding| {
+                if embedding.len() != self.args.expected_output_dimension {
+                    if self.spec.output_dimension.is_some() {
+                        api_bail!(
+                            "Expected output dimension {expected} but got {actual} from the embedding API. \
+                            Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
+                            expected = self.args.expected_output_dimension,
+                            actual = embedding.len(),
+                        );
+                    } else {
+                        bail!(
+                            "Expected output dimension {expected} but got {actual} from the embedding API. \
+                            Consider setting `output_dimension` to {actual} as a workaround.",
+                            expected = self.args.expected_output_dimension,
+                            actual = embedding.len(),
+                        );
+                    }
+                };
+                Ok(embedding.into())
+            })
+            .collect::<Result<Vec<value::Value>>>()
     }
 }

@@ -121,7 +140,7 @@ impl SimpleFunctionFactoryBase for Factory {
         args: Args,
         _context: Arc<FlowInstanceContext>,
     ) -> Result<impl SimpleFunctionExecutor> {
-        Ok(Executor { spec, args })
+        Ok(Executor { spec, args }.into_fn_executor())
     }
 }
 
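
With the executor moved onto `BatchedFunctionExecutor`, one provider round-trip now serves a whole batch. `evaluate_batch` validates the batch size first, then each embedding's dimension; a stand-alone sketch of that validation order, assuming `anyhow` (error wording differs from the PR):

```rust
use anyhow::{bail, Result};

// Mirrors evaluate_batch's checks: batch size first, then per-embedding
// dimension. Illustrative only.
fn validate(num_inputs: usize, expected_dim: usize, embeddings: &[Vec<f32>]) -> Result<()> {
    if embeddings.len() != num_inputs {
        bail!(
            "Expected {num_inputs} embeddings but got {} from the embedding API.",
            embeddings.len()
        );
    }
    for (i, e) in embeddings.iter().enumerate() {
        if e.len() != expected_dim {
            bail!("Embedding {i}: expected dimension {expected_dim}, got {}.", e.len());
        }
    }
    Ok(())
}
```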