Skip to content

Commit ca0efea

Browse files
authored
feat: support batching for Python SDK functions (#1232)
1 parent a1d89ec commit ca0efea

File tree

6 files changed

+224
-35
lines changed

6 files changed

+224
-35
lines changed

python/cocoindex/functions/sbert.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class SentenceTransformerEmbed(op.FunctionSpec):
3131
@op.executor_class(
3232
gpu=True,
3333
cache=True,
34+
batching=True,
3435
behavior_version=1,
3536
arg_relationship=(op.ArgRelationship.EMBEDDING_ORIGIN_TEXT, "text"),
3637
)
@@ -57,7 +58,9 @@ def analyze(self) -> type:
5758
dim = self._model.get_sentence_embedding_dimension()
5859
return Vector[np.float32, Literal[dim]] # type: ignore
5960

60-
def __call__(self, text: str) -> NDArray[np.float32]:
61+
def __call__(self, text: list[str]) -> list[NDArray[np.float32]]:
6162
assert self._model is not None
62-
result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
63-
return result
63+
results: list[NDArray[np.float32]] = self._model.encode(
64+
text, convert_to_numpy=True
65+
)
66+
return results

python/cocoindex/op.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from .typing import (
3434
KEY_FIELD_NAME,
35+
AnalyzedListType,
3536
AnalyzedTypeInfo,
3637
StructSchema,
3738
StructType,
@@ -45,6 +46,7 @@
4546
EnrichedValueType,
4647
decode_engine_field_schemas,
4748
FieldSchema,
49+
ValueType,
4850
)
4951
from .runtime import to_async_call
5052
from .index import IndexOptions
@@ -149,6 +151,7 @@ class OpArgs:
149151
"""
150152
- gpu: Whether the executor will be executed on GPU.
151153
- cache: Whether the executor will be cached.
154+
- batching: Whether the executor will be batched.
152155
- behavior_version: The behavior version of the executor. Cache will be invalidated if it
153156
changes. Must be provided if `cache` is True.
154157
- arg_relationship: It specifies the relationship between an input argument and the output,
@@ -158,6 +161,7 @@ class OpArgs:
158161

159162
gpu: bool = False
160163
cache: bool = False
164+
batching: bool = False
161165
behavior_version: int | None = None
162166
arg_relationship: tuple[ArgRelationship, str] | None = None
163167

@@ -168,6 +172,16 @@ class _ArgInfo:
168172
is_required: bool
169173

170174

175+
def _make_batched_engine_value_decoder(
176+
field_path: list[str], src_type: ValueType, dst_type_info: AnalyzedTypeInfo
177+
) -> Callable[[Any], Any]:
178+
if not isinstance(dst_type_info.variant, AnalyzedListType):
179+
raise ValueError("Expected arguments for batching function to be a list type")
180+
elem_type_info = analyze_type_info(dst_type_info.variant.elem_type)
181+
base_decoder = make_engine_value_decoder(field_path, src_type, elem_type_info)
182+
return lambda value: [base_decoder(v) for v in value]
183+
184+
171185
def _register_op_factory(
172186
category: OpCategory,
173187
expected_args: list[tuple[str, inspect.Parameter]],
@@ -181,6 +195,10 @@ def _register_op_factory(
181195
Register an op factory.
182196
"""
183197

198+
if op_args.batching:
199+
if len(expected_args) != 1:
200+
raise ValueError("Batching is only supported for single argument functions")
201+
184202
class _WrappedExecutor:
185203
_executor: Any
186204
_args_info: list[_ArgInfo]
@@ -208,7 +226,7 @@ def analyze_schema(
208226
"""
209227
self._args_info = []
210228
self._kwargs_info = {}
211-
attributes = []
229+
attributes = {}
212230
potentially_missing_required_arg = False
213231

214232
def process_arg(
@@ -220,14 +238,17 @@ def process_arg(
220238
if op_args.arg_relationship is not None:
221239
related_attr, related_arg_name = op_args.arg_relationship
222240
if related_arg_name == arg_name:
223-
attributes.append(
224-
TypeAttr(related_attr.value, actual_arg.analyzed_value)
225-
)
241+
attributes[related_attr.value] = actual_arg.analyzed_value
226242
type_info = analyze_type_info(arg_param.annotation)
227243
enriched = EnrichedValueType.decode(actual_arg.value_type)
228-
decoder = make_engine_value_decoder(
229-
[arg_name], enriched.type, type_info
230-
)
244+
if op_args.batching:
245+
decoder = _make_batched_engine_value_decoder(
246+
[arg_name], enriched.type, type_info
247+
)
248+
else:
249+
decoder = make_engine_value_decoder(
250+
[arg_name], enriched.type, type_info
251+
)
231252
is_required = not type_info.nullable
232253
if is_required and actual_arg.value_type.get("nullable", False):
233254
potentially_missing_required_arg = True
@@ -302,20 +323,32 @@ def process_arg(
302323
if len(missing_args) > 0:
303324
raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
304325

326+
analyzed_expected_return_type = analyze_type_info(expected_return)
327+
self._result_encoder = make_engine_value_encoder(
328+
analyzed_expected_return_type
329+
)
330+
305331
base_analyze_method = getattr(self._executor, "analyze", None)
306332
if base_analyze_method is not None:
307-
result_type = base_analyze_method()
333+
analyzed_result_type = analyze_type_info(base_analyze_method())
308334
else:
309-
result_type = expected_return
335+
if op_args.batching:
336+
if not isinstance(
337+
analyzed_expected_return_type.variant, AnalyzedListType
338+
):
339+
raise ValueError(
340+
"Expected return type for batching function to be a list type"
341+
)
342+
analyzed_result_type = analyze_type_info(
343+
analyzed_expected_return_type.variant.elem_type
344+
)
345+
else:
346+
analyzed_result_type = analyzed_expected_return_type
310347
if len(attributes) > 0:
311-
result_type = Annotated[result_type, *attributes]
312-
313-
analyzed_result_type_info = analyze_type_info(result_type)
314-
encoded_type = encode_enriched_type_info(analyzed_result_type_info)
348+
analyzed_result_type.attrs = attributes
315349
if potentially_missing_required_arg:
316-
encoded_type["nullable"] = True
317-
318-
self._result_encoder = make_engine_value_encoder(analyzed_result_type_info)
350+
analyzed_result_type.nullable = True
351+
encoded_type = encode_enriched_type_info(analyzed_result_type)
319352

320353
return encoded_type
321354

@@ -359,7 +392,9 @@ def behavior_version(self) -> int | None:
359392

360393
if category == OpCategory.FUNCTION:
361394
_engine.register_function_factory(
362-
op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor)
395+
op_kind,
396+
_EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor),
397+
op_args.batching,
363398
)
364399
else:
365400
raise ValueError(f"Unsupported executor type {category}")

src/ops/factory_bases.rs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ use crate::base::schema::*;
1010
use crate::base::spec::*;
1111
use crate::builder::plan::AnalyzedValueMapping;
1212
use crate::setup;
13-
// SourceFactoryBase
13+
14+
////////////////////////////////////////////////////////
15+
// Op Args
16+
////////////////////////////////////////////////////////
17+
1418
pub struct OpArgResolver<'arg> {
1519
name: String,
1620
resolved_op_arg: Option<(usize, EnrichedValueType)>,
@@ -204,6 +208,10 @@ impl<'a> OpArgsResolver<'a> {
204208
}
205209
}
206210

211+
////////////////////////////////////////////////////////
212+
// Source
213+
////////////////////////////////////////////////////////
214+
207215
#[async_trait]
208216
pub trait SourceFactoryBase: SourceFactory + Send + Sync + 'static {
209217
type Spec: DeserializeOwned + Send + Sync;
@@ -254,7 +262,9 @@ impl<T: SourceFactoryBase> SourceFactory for T {
254262
}
255263
}
256264

257-
// SimpleFunctionFactoryBase
265+
////////////////////////////////////////////////////////
266+
// Function
267+
////////////////////////////////////////////////////////
258268

259269
#[async_trait]
260270
pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'static {
@@ -355,6 +365,70 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
355365
}
356366
}
357367

368+
#[async_trait]
369+
pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
370+
async fn evaluate_batch(&self, args: Vec<Vec<value::Value>>) -> Result<Vec<value::Value>>;
371+
372+
fn enable_cache(&self) -> bool {
373+
false
374+
}
375+
376+
fn behavior_version(&self) -> Option<u32> {
377+
None
378+
}
379+
380+
fn into_fn_executor(self) -> Box<dyn SimpleFunctionExecutor> {
381+
Box::new(BatchedFunctionExecutorWrapper::new(self))
382+
}
383+
}
384+
385+
#[async_trait]
386+
impl<E: BatchedFunctionExecutor> batching::Runner for E {
387+
type Input = Vec<value::Value>;
388+
type Output = value::Value;
389+
390+
async fn run(
391+
&self,
392+
inputs: Vec<Self::Input>,
393+
) -> Result<impl ExactSizeIterator<Item = Self::Output>> {
394+
Ok(self.evaluate_batch(inputs).await?.into_iter())
395+
}
396+
}
397+
398+
struct BatchedFunctionExecutorWrapper<E: BatchedFunctionExecutor> {
399+
batcher: batching::Batcher<E>,
400+
enable_cache: bool,
401+
behavior_version: Option<u32>,
402+
}
403+
404+
impl<E: BatchedFunctionExecutor> BatchedFunctionExecutorWrapper<E> {
405+
fn new(executor: E) -> Self {
406+
Self {
407+
enable_cache: executor.enable_cache(),
408+
behavior_version: executor.behavior_version(),
409+
batcher: batching::Batcher::new(executor),
410+
}
411+
}
412+
}
413+
414+
#[async_trait]
415+
impl<E: BatchedFunctionExecutor> SimpleFunctionExecutor for BatchedFunctionExecutorWrapper<E> {
416+
async fn evaluate(&self, args: Vec<value::Value>) -> Result<value::Value> {
417+
self.batcher.run(args).await
418+
}
419+
420+
fn enable_cache(&self) -> bool {
421+
self.enable_cache
422+
}
423+
fn behavior_version(&self) -> Option<u32> {
424+
self.behavior_version
425+
}
426+
}
427+
428+
////////////////////////////////////////////////////////
429+
// Target
430+
////////////////////////////////////////////////////////
431+
358432
pub struct TypedExportDataCollectionBuildOutput<F: TargetFactoryBase + ?Sized> {
359433
pub export_context: BoxFuture<'static, Result<Arc<F::ExportContext>>>,
360434
pub setup_key: F::SetupKey,
@@ -636,6 +710,10 @@ fn from_json_combined_state<T: Debug + Clone + Serialize + DeserializeOwned>(
636710
})
637711
}
638712

713+
////////////////////////////////////////////////////////
714+
// Target Attachment
715+
////////////////////////////////////////////////////////
716+
639717
pub struct TypedTargetAttachmentState<F: TargetSpecificAttachmentFactoryBase + ?Sized> {
640718
pub setup_key: F::SetupKey,
641719
pub setup_state: F::SetupState,

src/ops/py_factory.rs

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::{prelude::*, py::future::from_py_future};
1+
use crate::{ops::sdk::BatchedFunctionExecutor, prelude::*, py::future::from_py_future};
22

33
use pyo3::{
44
Bound, IntoPyObjectExt, Py, PyAny, Python, pyclass, pymethods,
@@ -114,8 +114,65 @@ impl interface::SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
114114
}
115115
}
116116

117+
struct PyBatchedFunctionExecutor {
118+
py_function_executor: Py<PyAny>,
119+
py_exec_ctx: Arc<py::PythonExecutionContext>,
120+
result_type: schema::EnrichedValueType,
121+
122+
enable_cache: bool,
123+
behavior_version: Option<u32>,
124+
}
125+
126+
#[async_trait]
127+
impl BatchedFunctionExecutor for PyBatchedFunctionExecutor {
128+
async fn evaluate_batch(&self, args: Vec<Vec<value::Value>>) -> Result<Vec<value::Value>> {
129+
let result_fut = Python::with_gil(|py| -> pyo3::PyResult<_> {
130+
let py_args = PyList::new(
131+
py,
132+
args.into_iter()
133+
.map(|v| {
134+
py::value_to_py_object(
135+
py,
136+
v.get(0).ok_or_else(|| {
137+
pyo3::PyErr::new::<pyo3::exceptions::PyValueError, _>(
138+
"Expected a list of lists",
139+
)
140+
})?,
141+
)
142+
})
143+
.collect::<pyo3::PyResult<Vec<_>>>()?,
144+
)?;
145+
let result_coro = self.py_function_executor.call1(py, (py_args,))?;
146+
let task_locals =
147+
pyo3_async_runtimes::TaskLocals::new(self.py_exec_ctx.event_loop.bind(py).clone());
148+
Ok(from_py_future(
149+
py,
150+
&task_locals,
151+
result_coro.into_bound(py),
152+
)?)
153+
})?;
154+
let result = result_fut.await;
155+
Python::with_gil(|py| -> Result<_> {
156+
let result = result.to_result_with_py_trace(py)?;
157+
let result_bound = result.into_bound(py);
158+
let result_list = result_bound.extract::<Vec<Bound<'_, PyAny>>>()?;
159+
Ok(result_list
160+
.into_iter()
161+
.map(|v| py::value_from_py_object(&self.result_type.typ, &v))
162+
.collect::<pyo3::PyResult<Vec<_>>>()?)
163+
})
164+
}
165+
fn enable_cache(&self) -> bool {
166+
self.enable_cache
167+
}
168+
fn behavior_version(&self) -> Option<u32> {
169+
self.behavior_version
170+
}
171+
}
172+
117173
pub(crate) struct PyFunctionFactory {
118174
pub py_function_factory: Py<PyAny>,
175+
pub batching: bool,
119176
}
120177

121178
#[async_trait]
@@ -203,16 +260,27 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
203260
Ok((prepare_fut, enable_cache, behavior_version))
204261
})?;
205262
prepare_fut.await?;
206-
Ok(Box::new(Arc::new(PyFunctionExecutor {
207-
py_function_executor: executor,
208-
py_exec_ctx,
209-
num_positional_args,
210-
kw_args_names,
211-
result_type,
212-
enable_cache,
213-
behavior_version,
214-
}))
215-
as Box<dyn interface::SimpleFunctionExecutor>)
263+
let executor = if self.batching {
264+
PyBatchedFunctionExecutor {
265+
py_function_executor: executor,
266+
py_exec_ctx,
267+
result_type,
268+
enable_cache,
269+
behavior_version,
270+
}
271+
.into_fn_executor()
272+
} else {
273+
Box::new(Arc::new(PyFunctionExecutor {
274+
py_function_executor: executor,
275+
py_exec_ctx,
276+
num_positional_args,
277+
kw_args_names,
278+
result_type,
279+
enable_cache,
280+
behavior_version,
281+
})) as Box<dyn interface::SimpleFunctionExecutor>
282+
};
283+
Ok(executor)
216284
}
217285
};
218286

0 commit comments

Comments
 (0)