
Commit 194dd90

feat(batching): support batching for embedding APIs (#1236)

1 parent: b6a8ce3

8 files changed: +145 -114 lines

src/llm/gemini.rs (71 additions, 54 deletions)

```diff
@@ -70,36 +70,44 @@ fn remove_additional_properties(value: &mut Value) {
 impl AiStudioClient {
     fn get_api_url(&self, model: &str, api_name: &str) -> String {
         format!(
-            "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}",
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:{}",
             encode(model),
-            api_name,
-            encode(&self.api_key)
+            api_name
         )
     }
 }
 
 fn build_embed_payload(
     model: &str,
-    text: &str,
+    texts: &[&str],
     task_type: Option<&str>,
     output_dimension: Option<u32>,
 ) -> serde_json::Value {
-    let mut payload = serde_json::json!({
-        "model": model,
-        "content": { "parts": [{ "text": text }] },
-    });
-    if let Some(task_type) = task_type {
-        payload["taskType"] = serde_json::Value::String(task_type.to_string());
-    }
-    if let Some(output_dimension) = output_dimension {
-        payload["outputDimensionality"] = serde_json::json!(output_dimension);
-        if model.starts_with("gemini-embedding-") {
-            payload["config"] = serde_json::json!({
-                "outputDimensionality": output_dimension,
-            });
-        }
-    }
-    payload
+    let requests: Vec<_> = texts
+        .iter()
+        .map(|text| {
+            let mut req = serde_json::json!({
+                "model": format!("models/{}", model),
+                "content": { "parts": [{ "text": text }] },
+            });
+            if let Some(task_type) = task_type {
+                req["taskType"] = serde_json::Value::String(task_type.to_string());
+            }
+            if let Some(output_dimension) = output_dimension {
+                req["outputDimensionality"] = serde_json::json!(output_dimension);
+                if model.starts_with("gemini-embedding-") {
+                    req["config"] = serde_json::json!({
+                        "outputDimensionality": output_dimension,
+                    });
+                }
+            }
+            req
+        })
+        .collect();
+
+    serde_json::json!({
+        "requests": requests,
+    })
 }
 
 #[async_trait]
```
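For reference, a minimal sketch of what the batched builder now produces; the model name, texts, task type, and dimension below are illustrative, not from this commit:

```rust
let payload = build_embed_payload(
    "gemini-embedding-001",           // illustrative model name
    &["first chunk", "second chunk"], // one entry per text to embed
    Some("RETRIEVAL_DOCUMENT"),       // illustrative task type
    Some(768),                        // illustrative output dimension
);
// One entry in "requests" per input text, roughly:
// {
//   "requests": [
//     {
//       "model": "models/gemini-embedding-001",
//       "content": { "parts": [{ "text": "first chunk" }] },
//       "taskType": "RETRIEVAL_DOCUMENT",
//       "outputDimensionality": 768,
//       "config": { "outputDimensionality": 768 }
//     },
//     { /* same fields for "second chunk" */ }
//   ]
// }
```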
```diff
@@ -182,8 +190,8 @@ struct ContentEmbedding {
     values: Vec<f32>,
 }
 #[derive(Deserialize)]
-struct EmbedContentResponse {
-    embedding: ContentEmbedding,
+struct BatchEmbedContentResponse {
+    embeddings: Vec<ContentEmbedding>,
 }
 
 #[async_trait]
@@ -192,29 +200,30 @@ impl LlmEmbeddingClient for AiStudioClient {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
-        let url = self.get_api_url(request.model, "embedContent");
+        let url = self.get_api_url(request.model, "batchEmbedContents");
+        let texts: Vec<&str> = request.texts.iter().map(|t| t.as_ref()).collect();
         let payload = build_embed_payload(
             request.model,
-            request.text.as_ref(),
+            &texts,
             request.task_type.as_deref(),
             request.output_dimension,
         );
-        let resp = retryable::run(
-            || async {
-                self.client
-                    .post(&url)
-                    .json(&payload)
-                    .send()
-                    .await?
-                    .error_for_status()
-            },
-            &retryable::HEAVY_LOADED_OPTIONS,
-        )
+        let resp = http::request(|| {
+            self.client
+                .post(&url)
+                .header("x-goog-api-key", &self.api_key)
+                .json(&payload)
+        })
         .await
         .context("Gemini API error")?;
-        let embedding_resp: EmbedContentResponse = resp.json().await.context("Invalid JSON")?;
+        let embedding_resp: BatchEmbedContentResponse =
+            resp.json().await.context("Invalid JSON")?;
         Ok(super::LlmEmbeddingResponse {
-            embedding: embedding_resp.embedding.values,
+            embeddings: embedding_resp
+                .embeddings
+                .into_iter()
+                .map(|e| e.values)
+                .collect(),
         })
     }
```
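The endpoint changes from `embedContent` to `batchEmbedContents`, and the API key moves from a `?key=` query parameter into the `x-goog-api-key` header. A minimal sketch (values made up) of the batch response shape that `BatchEmbedContentResponse` is expected to parse:

```rust
let raw = r#"{
    "embeddings": [
        { "values": [0.1, 0.2, 0.3] },
        { "values": [0.4, 0.5, 0.6] }
    ]
}"#;
let parsed: BatchEmbedContentResponse = serde_json::from_str(raw)?;
assert_eq!(parsed.embeddings.len(), 2); // one embedding per request, in order
```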

```diff
@@ -381,15 +390,20 @@ impl LlmEmbeddingClient for VertexAiClient {
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
         // Create the instances for the request
-        let mut instance = serde_json::json!({
-            "content": request.text
-        });
-        // Add task type if specified
-        if let Some(task_type) = &request.task_type {
-            instance["task_type"] = serde_json::Value::String(task_type.to_string());
-        }
-
-        let instances = vec![instance];
+        let instances: Vec<_> = request
+            .texts
+            .iter()
+            .map(|text| {
+                let mut instance = serde_json::json!({
+                    "content": text
+                });
+                // Add task type if specified
+                if let Some(task_type) = &request.task_type {
+                    instance["task_type"] = serde_json::Value::String(task_type.to_string());
+                }
+                instance
+            })
+            .collect();
 
         // Prepare the request parameters
         let mut parameters = serde_json::json!({});
@@ -408,17 +422,20 @@ impl LlmEmbeddingClient for VertexAiClient {
             .send()
             .await?;
 
-        // Extract the embedding from the response
-        let embeddings = response
+        // Extract the embeddings from the response
+        let embeddings: Vec<Vec<f32>> = response
             .predictions
             .into_iter()
-            .next()
-            .and_then(|mut e| e.get_mut("embeddings").map(|v| v.take()))
-            .ok_or_else(|| anyhow::anyhow!("No embeddings in response"))?;
-        let embedding: ContentEmbedding = utils::deser::from_json_value(embeddings)?;
-        Ok(super::LlmEmbeddingResponse {
-            embedding: embedding.values,
-        })
+            .map(|mut prediction| {
+                let embeddings = prediction
+                    .get_mut("embeddings")
+                    .map(|v| v.take())
+                    .ok_or_else(|| anyhow::anyhow!("No embeddings in prediction"))?;
+                let embedding: ContentEmbedding = utils::deser::from_json_value(embeddings)?;
+                Ok(embedding.values)
+            })
+            .collect::<Result<_>>()?;
+        Ok(super::LlmEmbeddingResponse { embeddings })
     }
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
```

src/llm/mod.rs (2 additions, 2 deletions)

```diff
@@ -83,13 +83,13 @@ pub trait LlmGenerationClient: Send + Sync {
 #[derive(Debug)]
 pub struct LlmEmbeddingRequest<'a> {
     pub model: &'a str,
-    pub text: Cow<'a, str>,
+    pub texts: Vec<Cow<'a, str>>,
     pub output_dimension: Option<u32>,
     pub task_type: Option<Cow<'a, str>>,
 }
 
 pub struct LlmEmbeddingResponse {
-    pub embedding: Vec<f32>,
+    pub embeddings: Vec<Vec<f32>>,
 }
 
 #[async_trait]
```

src/llm/ollama.rs (6 additions, 10 deletions)

```diff
@@ -59,7 +59,7 @@ struct OllamaResponse {
 #[derive(Debug, Serialize)]
 struct OllamaEmbeddingRequest<'a> {
     pub model: &'a str,
-    pub input: &'a str,
+    pub input: Vec<&'a str>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -130,24 +130,20 @@ impl LlmEmbeddingClient for Client {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
+        let texts: Vec<&str> = request.texts.iter().map(|t| t.as_ref()).collect();
         let req = OllamaEmbeddingRequest {
             model: request.model,
-            input: request.text.as_ref(),
+            input: texts,
         };
         let resp = http::request(|| self.reqwest_client.post(self.embed_url.as_str()).json(&req))
             .await
             .context("Ollama API error")?;
 
         let embedding_resp: OllamaEmbeddingResponse = resp.json().await.context("Invalid JSON")?;
 
-        // Extract the first embedding (index 0)
-        let embedding = embedding_resp
-            .embeddings
-            .into_iter()
-            .next()
-            .context("Ollama API returned no embeddings")?;
-
-        Ok(super::LlmEmbeddingResponse { embedding })
+        Ok(super::LlmEmbeddingResponse {
+            embeddings: embedding_resp.embeddings,
+        })
     }
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
```
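With `input` now a `Vec<&str>`, the request serializes as a JSON array, which is the batch form Ollama's embed endpoint accepts. An illustrative sketch (model name made up):

```rust
let req = OllamaEmbeddingRequest {
    model: "nomic-embed-text", // illustrative model name
    input: vec!["first chunk", "second chunk"],
};
// Serializes to: {"model":"nomic-embed-text","input":["first chunk","second chunk"]}
assert_eq!(
    serde_json::to_value(&req)?,
    serde_json::json!({
        "model": "nomic-embed-text",
        "input": ["first chunk", "second chunk"],
    })
);
```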

src/llm/openai.rs (3 additions, 7 deletions)

```diff
@@ -184,11 +184,12 @@ impl LlmEmbeddingClient for Client {
     ) -> Result<super::LlmEmbeddingResponse> {
         let response = retryable::run(
             || async {
+                let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
                 self.client
                     .embeddings()
                     .create(CreateEmbeddingRequest {
                         model: request.model.to_string(),
-                        input: EmbeddingInput::String(request.text.to_string()),
+                        input: EmbeddingInput::StringArray(texts),
                         dimensions: request.output_dimension,
                         ..Default::default()
                     })
@@ -198,12 +199,7 @@ impl LlmEmbeddingClient for Client {
         )
         .await?;
         Ok(super::LlmEmbeddingResponse {
-            embedding: response
-                .data
-                .into_iter()
-                .next()
-                .ok_or_else(|| anyhow::anyhow!("No embedding returned from OpenAI"))?
-                .embedding,
+            embeddings: response.data.into_iter().map(|e| e.embedding).collect(),
         })
     }
```

src/llm/voyage.rs (7 additions, 6 deletions)

```diff
@@ -66,8 +66,9 @@ impl LlmEmbeddingClient for Client {
     ) -> Result<LlmEmbeddingResponse> {
         let url = "https://api.voyageai.com/v1/embeddings";
 
+        let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
         let mut payload = serde_json::json!({
-            "input": request.text,
+            "input": texts,
             "model": request.model,
         });
 
@@ -86,12 +87,12 @@ impl LlmEmbeddingClient for Client {
 
         let embedding_resp: EmbedResponse = resp.json().await.context("Invalid JSON")?;
 
-        if embedding_resp.data.is_empty() {
-            bail!("No embedding data in response");
-        }
-
         Ok(LlmEmbeddingResponse {
-            embedding: embedding_resp.data[0].embedding.clone(),
+            embeddings: embedding_resp
+                .data
+                .into_iter()
+                .map(|d| d.embedding)
+                .collect(),
         })
     }
```

src/ops/factory_bases.rs (2 additions, 2 deletions)

```diff
@@ -377,8 +377,8 @@ pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
         None
     }
 
-    fn into_fn_executor(self) -> Box<dyn SimpleFunctionExecutor> {
-        Box::new(BatchedFunctionExecutorWrapper::new(self))
+    fn into_fn_executor(self) -> impl SimpleFunctionExecutor {
+        BatchedFunctionExecutorWrapper::new(self)
     }
 }
```
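`BatchedFunctionExecutorWrapper` itself is not part of this diff. Conceptually it adapts a `BatchedFunctionExecutor` to the single-call `SimpleFunctionExecutor` interface, roughly along these lines; this is a simplified sketch, not the real implementation, which may also coalesce concurrent calls into larger batches and forward methods like `behavior_version`:

```rust
struct BatchedFunctionExecutorWrapper<E: BatchedFunctionExecutor> {
    inner: E,
}

#[async_trait]
impl<E: BatchedFunctionExecutor> SimpleFunctionExecutor for BatchedFunctionExecutorWrapper<E> {
    async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
        // Degenerate batch of one; the real wrapper can batch across calls.
        let mut results = self.inner.evaluate_batch(vec![input]).await?;
        results.pop().ok_or_else(|| anyhow::anyhow!("empty batch result"))
    }
}
```

Returning `impl SimpleFunctionExecutor` instead of `Box<dyn SimpleFunctionExecutor>` lets `into_fn_executor` hand back a concrete type, matching the `Result<impl SimpleFunctionExecutor>` factory signature used in `embed_text.rs` below.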

src/ops/functions/embed_text.rs (42 additions, 23 deletions)

```diff
@@ -27,7 +27,7 @@ struct Executor {
 }
 
 #[async_trait]
-impl SimpleFunctionExecutor for Executor {
+impl BatchedFunctionExecutor for Executor {
     fn behavior_version(&self) -> Option<u32> {
         self.args.client.behavior_version()
     }
@@ -36,37 +36,56 @@ impl SimpleFunctionExecutor for Executor {
         true
     }
 
-    async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
-        let text = self.args.text.value(&input)?.as_str()?;
+    async fn evaluate_batch(&self, args: Vec<Vec<Value>>) -> Result<Vec<Value>> {
+        let texts = args
+            .iter()
+            .map(|arg| {
+                Ok(Cow::Borrowed(
+                    self.args.text.value(&arg)?.as_str()?.as_ref(),
+                ))
+            })
+            .collect::<Result<_>>()?;
         let req = LlmEmbeddingRequest {
             model: &self.spec.model,
-            text: Cow::Borrowed(text),
+            texts,
             output_dimension: self.spec.output_dimension,
             task_type: self
                 .spec
                 .task_type
                 .as_ref()
                 .map(|s| Cow::Borrowed(s.as_str())),
         };
-        let embedding = self.args.client.embed_text(req).await?;
-        if embedding.embedding.len() != self.args.expected_output_dimension {
-            if self.spec.output_dimension.is_some() {
-                api_bail!(
-                    "Expected output dimension {expected} but got {actual} from the embedding API. \
-                    Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
-                    expected = self.args.expected_output_dimension,
-                    actual = embedding.embedding.len()
-                );
-            } else {
-                bail!(
-                    "Expected output dimension {expected} but got {actual} from the embedding API. \
-                    Consider setting `output_dimension` to {actual} as a workaround.",
-                    expected = self.args.expected_output_dimension,
-                    actual = embedding.embedding.len()
-                )
-            }
+        let resp = self.args.client.embed_text(req).await?;
+        if resp.embeddings.len() != args.len() {
+            api_bail!(
+                "Expected {expected} embeddings but got {actual} from the embedding API.",
+                expected = args.len(),
+                actual = resp.embeddings.len()
+            );
         }
-        Ok(embedding.embedding.into())
+        resp.embeddings
+            .into_iter()
+            .map(|embedding| {
+                if embedding.len() != self.args.expected_output_dimension {
+                    if self.spec.output_dimension.is_some() {
+                        api_bail!(
+                            "Expected output dimension {expected} but got {actual} from the embedding API. \
+                            Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
+                            expected = self.args.expected_output_dimension,
+                            actual = embedding.len(),
+                        );
+                    } else {
+                        bail!(
+                            "Expected output dimension {expected} but got {actual} from the embedding API. \
+                            Consider setting `output_dimension` to {actual} as a workaround.",
+                            expected = self.args.expected_output_dimension,
+                            actual = embedding.len(),
+                        );
+                    }
+                };
+                Ok(embedding.into())
+            })
+            .collect::<Result<Vec<value::Value>>>()
     }
 }
@@ -121,7 +140,7 @@ impl SimpleFunctionFactoryBase for Factory {
         args: Args,
         _context: Arc<FlowInstanceContext>,
     ) -> Result<impl SimpleFunctionExecutor> {
-        Ok(Executor { spec, args })
+        Ok(Executor { spec, args }.into_fn_executor())
     }
 }
```
