Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/cocoindex/functions/sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class SentenceTransformerEmbed(op.FunctionSpec):
@op.executor_class(
gpu=True,
cache=True,
batching=True,
behavior_version=1,
arg_relationship=(op.ArgRelationship.EMBEDDING_ORIGIN_TEXT, "text"),
)
Expand All @@ -57,7 +58,9 @@ def analyze(self) -> type:
dim = self._model.get_sentence_embedding_dimension()
return Vector[np.float32, Literal[dim]] # type: ignore

def __call__(self, text: str) -> NDArray[np.float32]:
def __call__(self, text: list[str]) -> list[NDArray[np.float32]]:
assert self._model is not None
result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
return result
results: list[NDArray[np.float32]] = self._model.encode(
text, convert_to_numpy=True
)
return results
69 changes: 52 additions & 17 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from .typing import (
KEY_FIELD_NAME,
AnalyzedListType,
AnalyzedTypeInfo,
StructSchema,
StructType,
Expand All @@ -45,6 +46,7 @@
EnrichedValueType,
decode_engine_field_schemas,
FieldSchema,
ValueType,
)
from .runtime import to_async_call
from .index import IndexOptions
Expand Down Expand Up @@ -149,6 +151,7 @@ class OpArgs:
"""
- gpu: Whether the executor will be executed on GPU.
- cache: Whether the executor will be cached.
- batching: Whether the executor will be batched.
- behavior_version: The behavior version of the executor. Cache will be invalidated if it
changes. Must be provided if `cache` is True.
- arg_relationship: It specifies the relationship between an input argument and the output,
Expand All @@ -158,6 +161,7 @@ class OpArgs:

gpu: bool = False
cache: bool = False
batching: bool = False
behavior_version: int | None = None
arg_relationship: tuple[ArgRelationship, str] | None = None

Expand All @@ -168,6 +172,16 @@ class _ArgInfo:
is_required: bool


def _make_batched_engine_value_decoder(
field_path: list[str], src_type: ValueType, dst_type_info: AnalyzedTypeInfo
) -> Callable[[Any], Any]:
if not isinstance(dst_type_info.variant, AnalyzedListType):
raise ValueError("Expected arguments for batching function to be a list type")
elem_type_info = analyze_type_info(dst_type_info.variant.elem_type)
base_decoder = make_engine_value_decoder(field_path, src_type, elem_type_info)
return lambda value: [base_decoder(v) for v in value]


def _register_op_factory(
category: OpCategory,
expected_args: list[tuple[str, inspect.Parameter]],
Expand All @@ -181,6 +195,10 @@ def _register_op_factory(
Register an op factory.
"""

if op_args.batching:
if len(expected_args) != 1:
raise ValueError("Batching is only supported for single argument functions")

class _WrappedExecutor:
_executor: Any
_args_info: list[_ArgInfo]
Expand Down Expand Up @@ -208,7 +226,7 @@ def analyze_schema(
"""
self._args_info = []
self._kwargs_info = {}
attributes = []
attributes = {}
potentially_missing_required_arg = False

def process_arg(
Expand All @@ -220,14 +238,17 @@ def process_arg(
if op_args.arg_relationship is not None:
related_attr, related_arg_name = op_args.arg_relationship
if related_arg_name == arg_name:
attributes.append(
TypeAttr(related_attr.value, actual_arg.analyzed_value)
)
attributes[related_attr.value] = actual_arg.analyzed_value
type_info = analyze_type_info(arg_param.annotation)
enriched = EnrichedValueType.decode(actual_arg.value_type)
decoder = make_engine_value_decoder(
[arg_name], enriched.type, type_info
)
if op_args.batching:
decoder = _make_batched_engine_value_decoder(
[arg_name], enriched.type, type_info
)
else:
decoder = make_engine_value_decoder(
[arg_name], enriched.type, type_info
)
is_required = not type_info.nullable
if is_required and actual_arg.value_type.get("nullable", False):
potentially_missing_required_arg = True
Expand Down Expand Up @@ -302,20 +323,32 @@ def process_arg(
if len(missing_args) > 0:
raise ValueError(f"Missing arguments: {', '.join(missing_args)}")

analyzed_expected_return_type = analyze_type_info(expected_return)
self._result_encoder = make_engine_value_encoder(
analyzed_expected_return_type
)

base_analyze_method = getattr(self._executor, "analyze", None)
if base_analyze_method is not None:
result_type = base_analyze_method()
analyzed_result_type = analyze_type_info(base_analyze_method())
else:
result_type = expected_return
if op_args.batching:
if not isinstance(
analyzed_expected_return_type.variant, AnalyzedListType
):
raise ValueError(
"Expected return type for batching function to be a list type"
)
analyzed_result_type = analyze_type_info(
analyzed_expected_return_type.variant.elem_type
)
else:
analyzed_result_type = analyzed_expected_return_type
if len(attributes) > 0:
result_type = Annotated[result_type, *attributes]

analyzed_result_type_info = analyze_type_info(result_type)
encoded_type = encode_enriched_type_info(analyzed_result_type_info)
analyzed_result_type.attrs = attributes
if potentially_missing_required_arg:
encoded_type["nullable"] = True

self._result_encoder = make_engine_value_encoder(analyzed_result_type_info)
analyzed_result_type.nullable = True
encoded_type = encode_enriched_type_info(analyzed_result_type)

return encoded_type

Expand Down Expand Up @@ -359,7 +392,9 @@ def behavior_version(self) -> int | None:

if category == OpCategory.FUNCTION:
_engine.register_function_factory(
op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor)
op_kind,
_EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor),
op_args.batching,
)
else:
raise ValueError(f"Unsupported executor type {category}")
Expand Down
82 changes: 80 additions & 2 deletions src/ops/factory_bases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ use crate::base::schema::*;
use crate::base::spec::*;
use crate::builder::plan::AnalyzedValueMapping;
use crate::setup;
// SourceFactoryBase

////////////////////////////////////////////////////////
// Op Args
////////////////////////////////////////////////////////

pub struct OpArgResolver<'arg> {
name: String,
resolved_op_arg: Option<(usize, EnrichedValueType)>,
Expand Down Expand Up @@ -204,6 +208,10 @@ impl<'a> OpArgsResolver<'a> {
}
}

////////////////////////////////////////////////////////
// Source
////////////////////////////////////////////////////////

#[async_trait]
pub trait SourceFactoryBase: SourceFactory + Send + Sync + 'static {
type Spec: DeserializeOwned + Send + Sync;
Expand Down Expand Up @@ -254,7 +262,9 @@ impl<T: SourceFactoryBase> SourceFactory for T {
}
}

// SimpleFunctionFactoryBase
////////////////////////////////////////////////////////
// Function
////////////////////////////////////////////////////////

#[async_trait]
pub trait SimpleFunctionFactoryBase: SimpleFunctionFactory + Send + Sync + 'static {
Expand Down Expand Up @@ -355,6 +365,70 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
}
}

#[async_trait]
pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
async fn evaluate_batch(&self, args: Vec<Vec<value::Value>>) -> Result<Vec<value::Value>>;

fn enable_cache(&self) -> bool {
false
}

fn behavior_version(&self) -> Option<u32> {
None
}

fn into_fn_executor(self) -> Box<dyn SimpleFunctionExecutor> {
Box::new(BatchedFunctionExecutorWrapper::new(self))
}
}

#[async_trait]
impl<E: BatchedFunctionExecutor> batching::Runner for E {
type Input = Vec<value::Value>;
type Output = value::Value;

async fn run(
&self,
inputs: Vec<Self::Input>,
) -> Result<impl ExactSizeIterator<Item = Self::Output>> {
Ok(self.evaluate_batch(inputs).await?.into_iter())
}
}

struct BatchedFunctionExecutorWrapper<E: BatchedFunctionExecutor> {
batcher: batching::Batcher<E>,
enable_cache: bool,
behavior_version: Option<u32>,
}

impl<E: BatchedFunctionExecutor> BatchedFunctionExecutorWrapper<E> {
fn new(executor: E) -> Self {
Self {
enable_cache: executor.enable_cache(),
behavior_version: executor.behavior_version(),
batcher: batching::Batcher::new(executor),
}
}
}

#[async_trait]
impl<E: BatchedFunctionExecutor> SimpleFunctionExecutor for BatchedFunctionExecutorWrapper<E> {
async fn evaluate(&self, args: Vec<value::Value>) -> Result<value::Value> {
self.batcher.run(args).await
}

fn enable_cache(&self) -> bool {
self.enable_cache
}
fn behavior_version(&self) -> Option<u32> {
self.behavior_version
}
}

////////////////////////////////////////////////////////
// Target
////////////////////////////////////////////////////////

pub struct TypedExportDataCollectionBuildOutput<F: TargetFactoryBase + ?Sized> {
pub export_context: BoxFuture<'static, Result<Arc<F::ExportContext>>>,
pub setup_key: F::SetupKey,
Expand Down Expand Up @@ -636,6 +710,10 @@ fn from_json_combined_state<T: Debug + Clone + Serialize + DeserializeOwned>(
})
}

////////////////////////////////////////////////////////
// Target Attachment
////////////////////////////////////////////////////////

pub struct TypedTargetAttachmentState<F: TargetSpecificAttachmentFactoryBase + ?Sized> {
pub setup_key: F::SetupKey,
pub setup_state: F::SetupState,
Expand Down
90 changes: 79 additions & 11 deletions src/ops/py_factory.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{prelude::*, py::future::from_py_future};
use crate::{ops::sdk::BatchedFunctionExecutor, prelude::*, py::future::from_py_future};

use pyo3::{
Bound, IntoPyObjectExt, Py, PyAny, Python, pyclass, pymethods,
Expand Down Expand Up @@ -114,8 +114,65 @@ impl interface::SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
}
}

struct PyBatchedFunctionExecutor {
py_function_executor: Py<PyAny>,
py_exec_ctx: Arc<py::PythonExecutionContext>,
result_type: schema::EnrichedValueType,

enable_cache: bool,
behavior_version: Option<u32>,
}

#[async_trait]
impl BatchedFunctionExecutor for PyBatchedFunctionExecutor {
async fn evaluate_batch(&self, args: Vec<Vec<value::Value>>) -> Result<Vec<value::Value>> {
let result_fut = Python::with_gil(|py| -> pyo3::PyResult<_> {
let py_args = PyList::new(
py,
args.into_iter()
.map(|v| {
py::value_to_py_object(
py,
v.get(0).ok_or_else(|| {
pyo3::PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Expected a list of lists",
)
})?,
)
})
.collect::<pyo3::PyResult<Vec<_>>>()?,
)?;
let result_coro = self.py_function_executor.call1(py, (py_args,))?;
let task_locals =
pyo3_async_runtimes::TaskLocals::new(self.py_exec_ctx.event_loop.bind(py).clone());
Ok(from_py_future(
py,
&task_locals,
result_coro.into_bound(py),
)?)
})?;
let result = result_fut.await;
Python::with_gil(|py| -> Result<_> {
let result = result.to_result_with_py_trace(py)?;
let result_bound = result.into_bound(py);
let result_list = result_bound.extract::<Vec<Bound<'_, PyAny>>>()?;
Ok(result_list
.into_iter()
.map(|v| py::value_from_py_object(&self.result_type.typ, &v))
.collect::<pyo3::PyResult<Vec<_>>>()?)
})
}
fn enable_cache(&self) -> bool {
self.enable_cache
}
fn behavior_version(&self) -> Option<u32> {
self.behavior_version
}
}

pub(crate) struct PyFunctionFactory {
pub py_function_factory: Py<PyAny>,
pub batching: bool,
}

#[async_trait]
Expand Down Expand Up @@ -203,16 +260,27 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
Ok((prepare_fut, enable_cache, behavior_version))
})?;
prepare_fut.await?;
Ok(Box::new(Arc::new(PyFunctionExecutor {
py_function_executor: executor,
py_exec_ctx,
num_positional_args,
kw_args_names,
result_type,
enable_cache,
behavior_version,
}))
as Box<dyn interface::SimpleFunctionExecutor>)
let executor = if self.batching {
PyBatchedFunctionExecutor {
py_function_executor: executor,
py_exec_ctx,
result_type,
enable_cache,
behavior_version,
}
.into_fn_executor()
} else {
Box::new(Arc::new(PyFunctionExecutor {
py_function_executor: executor,
py_exec_ctx,
num_positional_args,
kw_args_names,
result_type,
enable_cache,
behavior_version,
})) as Box<dyn interface::SimpleFunctionExecutor>
};
Ok(executor)
}
};

Expand Down
Loading
Loading