diff --git a/python/cocoindex/functions/sbert.py b/python/cocoindex/functions/sbert.py
index b1415a51..93f9fc44 100644
--- a/python/cocoindex/functions/sbert.py
+++ b/python/cocoindex/functions/sbert.py
@@ -31,6 +31,7 @@ class SentenceTransformerEmbed(op.FunctionSpec):
 @op.executor_class(
     gpu=True,
     cache=True,
+    batching=True,
     behavior_version=1,
     arg_relationship=(op.ArgRelationship.EMBEDDING_ORIGIN_TEXT, "text"),
 )
@@ -57,7 +58,9 @@ def analyze(self) -> type:
         dim = self._model.get_sentence_embedding_dimension()
         return Vector[np.float32, Literal[dim]]  # type: ignore

-    def __call__(self, text: str) -> NDArray[np.float32]:
+    def __call__(self, text: list[str]) -> list[NDArray[np.float32]]:
         assert self._model is not None
-        result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
-        return result
+        results: list[NDArray[np.float32]] = self._model.encode(
+            text, convert_to_numpy=True
+        )
+        return results
diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py
index cee693a0..3a15fd5a 100644
--- a/python/cocoindex/op.py
+++ b/python/cocoindex/op.py
@@ -32,6 +32,7 @@
 )
 from .typing import (
     KEY_FIELD_NAME,
+    AnalyzedListType,
     AnalyzedTypeInfo,
     StructSchema,
     StructType,
@@ -45,6 +46,7 @@
     EnrichedValueType,
     decode_engine_field_schemas,
     FieldSchema,
+    ValueType,
 )
 from .runtime import to_async_call
 from .index import IndexOptions
@@ -149,6 +151,7 @@ class OpArgs:
     """
     - gpu: Whether the executor will be executed on GPU.
     - cache: Whether the executor will be cached.
+    - batching: Whether the executor will be batched.
     - behavior_version: The behavior version of the executor.
       Cache will be invalidated if it changes. Must be provided if `cache` is True.
     - arg_relationship: It specifies the relationship between an input argument and the output,
@@ -158,6 +161,7 @@ class OpArgs:

     gpu: bool = False
     cache: bool = False
+    batching: bool = False
     behavior_version: int | None = None
     arg_relationship: tuple[ArgRelationship, str] | None = None

@@ -168,6 +172,16 @@ class _ArgInfo:
     is_required: bool


+def _make_batched_engine_value_decoder(
+    field_path: list[str], src_type: ValueType, dst_type_info: AnalyzedTypeInfo
+) -> Callable[[Any], Any]:
+    if not isinstance(dst_type_info.variant, AnalyzedListType):
+        raise ValueError("Expected arguments for batching function to be a list type")
+    elem_type_info = analyze_type_info(dst_type_info.variant.elem_type)
+    base_decoder = make_engine_value_decoder(field_path, src_type, elem_type_info)
+    return lambda value: [base_decoder(v) for v in value]
+
+
 def _register_op_factory(
     category: OpCategory,
     expected_args: list[tuple[str, inspect.Parameter]],
@@ -181,6 +195,10 @@ def _register_op_factory(
     Register an op factory.
""" + if op_args.batching: + if len(expected_args) != 1: + raise ValueError("Batching is only supported for single argument functions") + class _WrappedExecutor: _executor: Any _args_info: list[_ArgInfo] @@ -208,7 +226,7 @@ def analyze_schema( """ self._args_info = [] self._kwargs_info = {} - attributes = [] + attributes = {} potentially_missing_required_arg = False def process_arg( @@ -220,14 +238,17 @@ def process_arg( if op_args.arg_relationship is not None: related_attr, related_arg_name = op_args.arg_relationship if related_arg_name == arg_name: - attributes.append( - TypeAttr(related_attr.value, actual_arg.analyzed_value) - ) + attributes[related_attr.value] = actual_arg.analyzed_value type_info = analyze_type_info(arg_param.annotation) enriched = EnrichedValueType.decode(actual_arg.value_type) - decoder = make_engine_value_decoder( - [arg_name], enriched.type, type_info - ) + if op_args.batching: + decoder = _make_batched_engine_value_decoder( + [arg_name], enriched.type, type_info + ) + else: + decoder = make_engine_value_decoder( + [arg_name], enriched.type, type_info + ) is_required = not type_info.nullable if is_required and actual_arg.value_type.get("nullable", False): potentially_missing_required_arg = True @@ -302,20 +323,32 @@ def process_arg( if len(missing_args) > 0: raise ValueError(f"Missing arguments: {', '.join(missing_args)}") + analyzed_expected_return_type = analyze_type_info(expected_return) + self._result_encoder = make_engine_value_encoder( + analyzed_expected_return_type + ) + base_analyze_method = getattr(self._executor, "analyze", None) if base_analyze_method is not None: - result_type = base_analyze_method() + analyzed_result_type = analyze_type_info(base_analyze_method()) else: - result_type = expected_return + if op_args.batching: + if not isinstance( + analyzed_expected_return_type.variant, AnalyzedListType + ): + raise ValueError( + "Expected return type for batching function to be a list type" + ) + analyzed_result_type = analyze_type_info( + analyzed_expected_return_type.variant.elem_type + ) + else: + analyzed_result_type = analyzed_expected_return_type if len(attributes) > 0: - result_type = Annotated[result_type, *attributes] - - analyzed_result_type_info = analyze_type_info(result_type) - encoded_type = encode_enriched_type_info(analyzed_result_type_info) + analyzed_result_type.attrs = attributes if potentially_missing_required_arg: - encoded_type["nullable"] = True - - self._result_encoder = make_engine_value_encoder(analyzed_result_type_info) + analyzed_result_type.nullable = True + encoded_type = encode_enriched_type_info(analyzed_result_type) return encoded_type @@ -359,7 +392,9 @@ def behavior_version(self) -> int | None: if category == OpCategory.FUNCTION: _engine.register_function_factory( - op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor) + op_kind, + _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor), + op_args.batching, ) else: raise ValueError(f"Unsupported executor type {category}") diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs index 6deb89cb..b6e5426e 100644 --- a/src/ops/factory_bases.rs +++ b/src/ops/factory_bases.rs @@ -10,7 +10,11 @@ use crate::base::schema::*; use crate::base::spec::*; use crate::builder::plan::AnalyzedValueMapping; use crate::setup; -// SourceFactoryBase + +//////////////////////////////////////////////////////// +// Op Args +//////////////////////////////////////////////////////// + pub struct OpArgResolver<'arg> { name: String, resolved_op_arg: Option<(usize, 
@@ -204,6 +208,10 @@ impl<'a> OpArgsResolver<'a> {
     }
 }

+////////////////////////////////////////////////////////
+// Source
+////////////////////////////////////////////////////////
+
 #[async_trait]
 pub trait SourceFactoryBase: SourceFactory + Send + Sync + 'static {
     type Spec: DeserializeOwned + Send + Sync;
@@ -254,7 +262,9 @@ impl<T: SourceFactoryBase> SourceFactory for T {
     }
 }

-// SimpleFunctionFactoryBase
+////////////////////////////////////////////////////////
+// Function
+////////////////////////////////////////////////////////

 #[async_trait]
 pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'static {
@@ -355,6 +365,70 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
     }
 }

+#[async_trait]
+pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
+    async fn evaluate_batch(&self, args: Vec<Vec<value::Value>>) -> Result<Vec<value::Value>>;
+
+    fn enable_cache(&self) -> bool {
+        false
+    }
+
+    fn behavior_version(&self) -> Option<u32> {
+        None
+    }
+
+    fn into_fn_executor(self) -> Box<dyn SimpleFunctionExecutor> {
+        Box::new(BatchedFunctionExecutorWrapper::new(self))
+    }
+}
+
+#[async_trait]
+impl<E: BatchedFunctionExecutor> batching::Runner for E {
+    type Input = Vec<value::Value>;
+    type Output = value::Value;
+
+    async fn run(
+        &self,
+        inputs: Vec<Self::Input>,
+    ) -> Result<impl ExactSizeIterator<Item = Self::Output>> {
+        Ok(self.evaluate_batch(inputs).await?.into_iter())
+    }
+}
+
+struct BatchedFunctionExecutorWrapper<E: BatchedFunctionExecutor> {
+    batcher: batching::Batcher<E>,
+    enable_cache: bool,
+    behavior_version: Option<u32>,
+}
+
+impl<E: BatchedFunctionExecutor> BatchedFunctionExecutorWrapper<E> {
+    fn new(executor: E) -> Self {
+        Self {
+            enable_cache: executor.enable_cache(),
+            behavior_version: executor.behavior_version(),
+            batcher: batching::Batcher::new(executor),
+        }
+    }
+}
+
+#[async_trait]
+impl<E: BatchedFunctionExecutor> SimpleFunctionExecutor for BatchedFunctionExecutorWrapper<E> {
+    async fn evaluate(&self, args: Vec<value::Value>) -> Result<value::Value> {
+        self.batcher.run(args).await
+    }
+
+    fn enable_cache(&self) -> bool {
+        self.enable_cache
+    }
+    fn behavior_version(&self) -> Option<u32> {
+        self.behavior_version
+    }
+}
+
+////////////////////////////////////////////////////////
+// Target
+////////////////////////////////////////////////////////
+
 pub struct TypedExportDataCollectionBuildOutput<F: TargetFactoryBase + ?Sized> {
     pub export_context: BoxFuture<'static, Result<Arc<F::ExportContext>>>,
     pub setup_key: F::SetupKey,
@@ -636,6 +710,10 @@ fn from_json_combined_state(
     })
 }

+////////////////////////////////////////////////////////
+// Target Attachment
+////////////////////////////////////////////////////////
+
 pub struct TypedTargetAttachmentState<F: TargetAttachmentFactoryBase + ?Sized> {
     pub setup_key: F::SetupKey,
     pub setup_state: F::SetupState,
diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs
index acf85305..be16b5fb 100644
--- a/src/ops/py_factory.rs
+++ b/src/ops/py_factory.rs
@@ -1,4 +1,4 @@
-use crate::{prelude::*, py::future::from_py_future};
+use crate::{ops::sdk::BatchedFunctionExecutor, prelude::*, py::future::from_py_future};

 use pyo3::{
     Bound, IntoPyObjectExt, Py, PyAny, Python, pyclass, pymethods,
@@ -114,8 +114,65 @@ impl interface::SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
     }
 }

+struct PyBatchedFunctionExecutor {
+    py_function_executor: Py<PyAny>,
+    py_exec_ctx: Arc<PythonExecutionContext>,
+    result_type: schema::EnrichedValueType,
+
+    enable_cache: bool,
+    behavior_version: Option<u32>,
+}
+
+#[async_trait]
+impl BatchedFunctionExecutor for PyBatchedFunctionExecutor {
+    async fn evaluate_batch(&self, args: Vec<Vec<value::Value>>) -> Result<Vec<value::Value>> {
+        let result_fut = Python::with_gil(|py| -> pyo3::PyResult<_> {
+            let py_args = PyList::new(
+                py,
+                args.into_iter()
+                    .map(|v| {
+                        py::value_to_py_object(
+                            py,
+                            v.get(0).ok_or_else(|| {
+                                pyo3::PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                                    "Expected a list of lists",
+                                )
+                            })?,
+                        )
+                    })
+                    .collect::<pyo3::PyResult<Vec<_>>>()?,
+            )?;
+            let result_coro = self.py_function_executor.call1(py, (py_args,))?;
+            let task_locals =
+                pyo3_async_runtimes::TaskLocals::new(self.py_exec_ctx.event_loop.bind(py).clone());
+            Ok(from_py_future(
+                py,
+                &task_locals,
+                result_coro.into_bound(py),
+            )?)
+        })?;
+        let result = result_fut.await;
+        Python::with_gil(|py| -> Result<_> {
+            let result = result.to_result_with_py_trace(py)?;
+            let result_bound = result.into_bound(py);
+            let result_list = result_bound.extract::<Vec<Bound<PyAny>>>()?;
+            Ok(result_list
+                .into_iter()
+                .map(|v| py::value_from_py_object(&self.result_type.typ, &v))
+                .collect::<Result<Vec<_>>>()?)
+        })
+    }
+    fn enable_cache(&self) -> bool {
+        self.enable_cache
+    }
+    fn behavior_version(&self) -> Option<u32> {
+        self.behavior_version
+    }
+}
+
 pub(crate) struct PyFunctionFactory {
     pub py_function_factory: Py<PyAny>,
+    pub batching: bool,
 }

 #[async_trait]
@@ -203,16 +260,27 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
                 Ok((prepare_fut, enable_cache, behavior_version))
             })?;
             prepare_fut.await?;
-            Ok(Box::new(Arc::new(PyFunctionExecutor {
-                py_function_executor: executor,
-                py_exec_ctx,
-                num_positional_args,
-                kw_args_names,
-                result_type,
-                enable_cache,
-                behavior_version,
-            }))
-                as Box<dyn interface::SimpleFunctionExecutor>)
+            let executor = if self.batching {
+                PyBatchedFunctionExecutor {
+                    py_function_executor: executor,
+                    py_exec_ctx,
+                    result_type,
+                    enable_cache,
+                    behavior_version,
+                }
+                .into_fn_executor()
+            } else {
+                Box::new(Arc::new(PyFunctionExecutor {
+                    py_function_executor: executor,
+                    py_exec_ctx,
+                    num_positional_args,
+                    kw_args_names,
+                    result_type,
+                    enable_cache,
+                    behavior_version,
+                })) as Box<dyn interface::SimpleFunctionExecutor>
+            };
+            Ok(executor)
         }
     };

diff --git a/src/prelude.rs b/src/prelude.rs
index 5f5f0f36..e668c776 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -26,7 +26,7 @@ pub(crate) use crate::ops::interface;
 pub(crate) use crate::service::error::{ApiError, invariance_violation};
 pub(crate) use crate::setup;
 pub(crate) use crate::setup::AuthRegistry;
-pub(crate) use crate::utils::{self, concur_control, retryable};
+pub(crate) use crate::utils::{self, batching, concur_control, retryable};
 pub(crate) use crate::{api_bail, api_error};

 pub(crate) use anyhow::{anyhow, bail};
diff --git a/src/py/mod.rs b/src/py/mod.rs
index f0dcca38..b362a414 100644
--- a/src/py/mod.rs
+++ b/src/py/mod.rs
@@ -156,9 +156,14 @@ fn register_source_connector(name: String, py_source_connector: Py<PyAny>) -> Py
 }

 #[pyfunction]
-fn register_function_factory(name: String, py_function_factory: Py<PyAny>) -> PyResult<()> {
+fn register_function_factory(
+    name: String,
+    py_function_factory: Py<PyAny>,
+    batching: bool,
+) -> PyResult<()> {
     let factory = PyFunctionFactory {
         py_function_factory,
+        batching,
     };
     register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result()
 }
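
Note on usage: any custom function can opt into the new batching support the same way the SentenceTransformerEmbed executor above does: pass batching=True to op.executor_class, accept the single argument as a list, and return a list with one result per input in the same order. A minimal sketch under those assumptions (the ReverseText spec and executor below are made-up names for illustration, not part of this change):

from cocoindex import op


class ReverseText(op.FunctionSpec):
    """Hypothetical spec: reverse each input string."""


@op.executor_class(batching=True)
class ReverseTextExecutor:
    """Executor for ReverseText; receives its inputs in batches."""

    spec: ReverseText

    def __call__(self, text: list[str]) -> list[str]:
        # With batching=True the executor gets a batch of inputs and must
        # return one output per input, preserving order.
        return [t[::-1] for t in text]

Flows still apply the function one value at a time; grouping concurrent calls into the list the executor sees is handled inside the engine.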
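
The Rust side delegates the actual grouping to the utils::batching module, which is not part of this diff. As a rough mental model only, not the engine's implementation: the wrapper behaves like a micro-batcher that queues concurrent single-value calls, evaluates all of them with one batched call, and hands each caller the output at its own index. A sketch of that gather/scatter semantics:

import asyncio
from typing import Any, Awaitable, Callable


class MiniBatcher:
    """Illustrative gather/scatter sketch; not the engine's batching code."""

    def __init__(self, run_batch: Callable[[list[Any]], Awaitable[list[Any]]]):
        self._run_batch = run_batch  # plays the role of evaluate_batch above
        self._pending: list[tuple[Any, asyncio.Future[Any]]] = []
        self._flush_scheduled = False

    async def run(self, single_input: Any) -> Any:
        # Queue this call's input and wait for its own slice of the batch result.
        loop = asyncio.get_running_loop()
        fut: asyncio.Future[Any] = loop.create_future()
        self._pending.append((single_input, fut))
        if not self._flush_scheduled:
            self._flush_scheduled = True
            loop.create_task(self._flush())
        return await fut

    async def _flush(self) -> None:
        # Drain everything queued so far and evaluate it as one batched call.
        batch, self._pending = self._pending, []
        self._flush_scheduled = False
        inputs = [inp for inp, _ in batch]
        try:
            outputs = await self._run_batch(inputs)
            if len(outputs) != len(inputs):
                raise ValueError("batched call must return one output per input")
            for (_, fut), out in zip(batch, outputs):
                fut.set_result(out)
        except Exception as exc:
            # A failed batch fails every call that was folded into it.
            for _, fut in batch:
                fut.set_exception(exc)

How the real batcher decides when to flush, and how it interacts with caching and concurrency control, is defined in utils::batching; only the list-in/list-out contract shown above is established by this change.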