
Commit d18dd6e

feat(batching): add an optional max_batch_size option (#1261)
* feat: batching util support optional batch size
* feat: support `max_batch_size` in functions
1 parent c1eaee9 commit d18dd6e

File tree: 11 files changed, +375 −51 lines

docs/docs/custom_ops/custom_functions.mdx
Lines changed: 22 additions & 0 deletions

````diff
@@ -145,6 +145,8 @@ Custom functions take the following additional parameters:
 * `batching: bool`: Whether the executor will consume requests in batch.
   See the [Batching](#batching) section below for details.
 
+* `max_batch_size: int | None`: The maximum batch size for the executor.
+
 * `behavior_version: int`: The version of the behavior of the function.
   When the version is changed, the function will be re-executed even if cache is enabled.
   It's required to be set if `cache` is `True`.
@@ -221,5 +223,25 @@ class ComputeSomethingExecutor:
     ...
 ```
 
+### Controlling Batch Size
+
+You can control the maximum batch size using the `max_batch_size` parameter. This is useful for:
+* Limiting memory usage when processing large batches
+* Reducing latency by flushing batches before they grow too large
+* Working with APIs that have request size limits
+
+```python
+@cocoindex.op.function(batching=True, max_batch_size=32)
+def compute_something(args: list[str]) -> list[str]:
+    ...
+```
+
+With `max_batch_size` set, a batch will be flushed when either:
+
+1. No ongoing batches are running
+2. The pending batch size reaches `max_batch_size`
+
+This ensures that requests don't wait indefinitely for a batch to fill up, while still allowing efficient batch processing.
+
 </TabItem>
 </Tabs>
````
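For quick reference, here is a self-contained sketch of the new option in use, based on the decorator form documented in the diff above (`uppercase_batch` is a hypothetical example, not part of this commit):

```python
import cocoindex

# With batching=True, individual calls are gathered into one list-valued call;
# max_batch_size=32 caps how many pending requests may be merged before the
# batch is flushed (per the flush conditions documented above).
@cocoindex.op.function(batching=True, max_batch_size=32)
def uppercase_batch(args: list[str]) -> list[str]:
    # One output element per input element, in the same order.
    return [s.upper() for s in args]
```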

examples/code_embedding/main.py
Lines changed: 6 additions & 6 deletions

```diff
@@ -16,12 +16,12 @@ def code_to_embedding(
     Embed the text using a SentenceTransformer model.
     """
     # You can also switch to Voyage embedding model:
-    # return text.transform(
-    #     cocoindex.functions.EmbedText(
-    #         api_type=cocoindex.LlmApiType.VOYAGE,
-    #         model="voyage-code-3",
-    #     )
-    # )
+    # return text.transform(
+    #     cocoindex.functions.EmbedText(
+    #         api_type=cocoindex.LlmApiType.GEMINI,
+    #         model="text-embedding-004",
+    #     )
+    # )
     return text.transform(
         cocoindex.functions.SentenceTransformerEmbed(
             model="sentence-transformers/all-MiniLM-L6-v2"
```

python/cocoindex/functions/colpali.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -125,6 +125,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
     gpu=True,
     cache=True,
     batching=True,
+    max_batch_size=32,
     behavior_version=1,
 )
 class ColPaliEmbedImageExecutor:
@@ -204,6 +205,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
     cache=True,
     behavior_version=1,
     batching=True,
+    max_batch_size=32,
 )
 class ColPaliEmbedQueryExecutor:
     """Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
```

python/cocoindex/functions/sbert.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -31,6 +31,7 @@ class SentenceTransformerEmbed(op.FunctionSpec):
     gpu=True,
     cache=True,
     batching=True,
+    max_batch_size=512,
     behavior_version=1,
     arg_relationship=(op.ArgRelationship.EMBEDDING_ORIGIN_TEXT, "text"),
 )
```
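The built-in functions now ship with concrete caps: 32 for the ColPali image and query embedders above, 512 for `SentenceTransformerEmbed` here, and 64 for `EmbedText` (see src/ops/functions/embed_text.rs below). Existing flows pick these defaults up without any change; as a reminder, usage mirrors examples/code_embedding/main.py (`text` is assumed to be a CocoIndex data slice):

```python
import cocoindex

def code_to_embedding(text):
    # SentenceTransformerEmbed batches internally; with this commit, up to
    # max_batch_size=512 pending texts are merged per model invocation.
    return text.transform(
        cocoindex.functions.SentenceTransformerEmbed(
            model="sentence-transformers/all-MiniLM-L6-v2"
        )
    )
```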

python/cocoindex/op.py
Lines changed: 11 additions & 3 deletions

```diff
@@ -151,6 +151,7 @@ class OpArgs:
     - gpu: Whether the executor will be executed on GPU.
     - cache: Whether the executor will be cached.
     - batching: Whether the executor will be batched.
+    - max_batch_size: The maximum batch size for the executor. Only valid if `batching` is True.
     - behavior_version: The behavior version of the executor. Cache will be invalidated if it
       changes. Must be provided if `cache` is True.
     - arg_relationship: It specifies the relationship between an input argument and the output,
@@ -161,6 +162,7 @@
     gpu: bool = False
     cache: bool = False
     batching: bool = False
+    max_batch_size: int | None = None
     behavior_version: int | None = None
     arg_relationship: tuple[ArgRelationship, str] | None = None
 
@@ -389,11 +391,17 @@ def enable_cache(self) -> bool:
         def behavior_version(self) -> int | None:
             return op_args.behavior_version
 
+        def batching_options(self) -> dict[str, Any] | None:
+            if op_args.batching:
+                return {
+                    "max_batch_size": op_args.max_batch_size,
+                }
+            else:
+                return None
+
     if category == OpCategory.FUNCTION:
         _engine.register_function_factory(
-            op_kind,
-            _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor),
-            op_args.batching,
+            op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor)
         )
     else:
         raise ValueError(f"Unsupported executor type {category}")
```
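To make the docstring's "Only valid if `batching` is True" concrete: per the `batching_options()` wiring above, the engine receives a dict when batching is enabled and `None` otherwise, so `max_batch_size` is simply ignored for non-batched functions. A hedged sketch of the three combinations (function names are hypothetical):

```python
import cocoindex

# Batching on, capped: the engine sees {"max_batch_size": 64}.
@cocoindex.op.function(batching=True, max_batch_size=64)
def capped(args: list[str]) -> list[str]: ...

# Batching on, uncapped: the engine sees {"max_batch_size": None}.
@cocoindex.op.function(batching=True)
def uncapped(args: list[str]) -> list[str]: ...

# Batching off: batching_options() returns None; max_batch_size would have no effect.
@cocoindex.op.function()
def plain(arg: str) -> str: ...
```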

src/execution/source_indexer.rs
Lines changed: 4 additions & 1 deletion

```diff
@@ -304,7 +304,10 @@ impl SourceIndexingContext {
                 rows_to_retry,
             }),
             setup_execution_ctx,
-            update_once_batcher: batching::Batcher::new(UpdateOnceRunner),
+            update_once_batcher: batching::Batcher::new(
+                UpdateOnceRunner,
+                batching::BatchingOptions::default(),
+            ),
         }))
     }
 
```
src/ops/factory_bases.rs
Lines changed: 4 additions & 1 deletion

```diff
@@ -381,6 +381,8 @@ pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
     fn into_fn_executor(self) -> impl SimpleFunctionExecutor {
         BatchedFunctionExecutorWrapper::new(self)
     }
+
+    fn batching_options(&self) -> batching::BatchingOptions;
 }
 
 #[async_trait]
@@ -404,10 +406,11 @@ struct BatchedFunctionExecutorWrapper<E: BatchedFunctionExecutor> {
 
 impl<E: BatchedFunctionExecutor> BatchedFunctionExecutorWrapper<E> {
     fn new(executor: E) -> Self {
+        let batching_options = executor.batching_options();
         Self {
             enable_cache: executor.enable_cache(),
             behavior_version: executor.behavior_version(),
-            batcher: batching::Batcher::new(executor),
+            batcher: batching::Batcher::new(executor, batching_options),
         }
     }
 }
```

src/ops/functions/embed_text.rs
Lines changed: 8 additions & 0 deletions

```diff
@@ -36,6 +36,14 @@ impl BatchedFunctionExecutor for Executor {
         true
     }
 
+    fn batching_options(&self) -> batching::BatchingOptions {
+        // A safe default for most embeddings providers.
+        // May tune it for specific providers later.
+        batching::BatchingOptions {
+            max_batch_size: Some(64),
+        }
+    }
+
     async fn evaluate_batch(&self, args: Vec<Vec<Value>>) -> Result<Vec<Value>> {
         let texts = args
             .iter()
```

src/ops/py_factory.rs
Lines changed: 37 additions & 20 deletions

```diff
@@ -121,6 +121,7 @@ struct PyBatchedFunctionExecutor {
 
     enable_cache: bool,
     behavior_version: Option<u32>,
+    batching_options: batching::BatchingOptions,
 }
 
 #[async_trait]
@@ -168,11 +169,13 @@ impl BatchedFunctionExecutor for PyBatchedFunctionExecutor {
     fn behavior_version(&self) -> Option<u32> {
         self.behavior_version
     }
+    fn batching_options(&self) -> batching::BatchingOptions {
+        self.batching_options.clone()
+    }
 }
 
 pub(crate) struct PyFunctionFactory {
     pub py_function_factory: Py<PyAny>,
-    pub batching: bool,
 }
 
 #[async_trait]
@@ -237,7 +240,7 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
             .as_ref()
             .ok_or_else(|| anyhow!("Python execution context is missing"))?
             .clone();
-        let (prepare_fut, enable_cache, behavior_version) =
+        let (prepare_fut, enable_cache, behavior_version, batching_options) =
             Python::with_gil(|py| -> anyhow::Result<_> {
                 let prepare_coro = executor
                     .call_method(py, "prepare", (), None)
@@ -257,31 +260,45 @@
                     .call_method(py, "behavior_version", (), None)
                     .to_result_with_py_trace(py)?
                     .extract::<Option<u32>>(py)?;
-                Ok((prepare_fut, enable_cache, behavior_version))
+                let batching_options = executor
+                    .call_method(py, "batching_options", (), None)
+                    .to_result_with_py_trace(py)?
+                    .extract::<crate::py::Pythonized<Option<batching::BatchingOptions>>>(
+                        py,
+                    )?
+                    .into_inner();
+                Ok((
+                    prepare_fut,
+                    enable_cache,
+                    behavior_version,
+                    batching_options,
+                ))
             })?;
         prepare_fut.await?;
-        let executor: Box<dyn interface::SimpleFunctionExecutor> = if self.batching {
-            Box::new(
-                PyBatchedFunctionExecutor {
+        let executor: Box<dyn interface::SimpleFunctionExecutor> =
+            if let Some(batching_options) = batching_options {
+                Box::new(
+                    PyBatchedFunctionExecutor {
+                        py_function_executor: executor,
+                        py_exec_ctx,
+                        result_type,
+                        enable_cache,
+                        behavior_version,
+                        batching_options,
+                    }
+                    .into_fn_executor(),
+                )
+            } else {
+                Box::new(Arc::new(PyFunctionExecutor {
                     py_function_executor: executor,
                     py_exec_ctx,
+                    num_positional_args,
+                    kw_args_names,
                     result_type,
                     enable_cache,
                     behavior_version,
-                }
-                .into_fn_executor(),
-            )
-        } else {
-            Box::new(Arc::new(PyFunctionExecutor {
-                py_function_executor: executor,
-                py_exec_ctx,
-                num_positional_args,
-                kw_args_names,
-                result_type,
-                enable_cache,
-                behavior_version,
-            }))
-        };
+                }))
+            };
         Ok(executor)
     }
 };
```
src/py/mod.rs
Lines changed: 1 addition & 6 deletions

```diff
@@ -156,14 +156,9 @@ fn register_source_connector(name: String, py_source_connector: Py<PyAny>) -> Py
 }
 
 #[pyfunction]
-fn register_function_factory(
-    name: String,
-    py_function_factory: Py<PyAny>,
-    batching: bool,
-) -> PyResult<()> {
+fn register_function_factory(name: String, py_function_factory: Py<PyAny>) -> PyResult<()> {
     let factory = PyFunctionFactory {
         py_function_factory,
-        batching,
     };
     register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result()
 }
```
