From c8f06882bd8fc88429523fbbf0cb525685510d2a Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Sat, 18 Oct 2025 12:29:54 -0700 Subject: [PATCH] feat(custom-source): make custom source API more robust --- python/cocoindex/op.py | 18 ++++++++---------- src/ops/py_factory.rs | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 74fe74868..03fc37f7a 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -485,9 +485,8 @@ class _SourceExecutorContext: AsyncIterator[PartialSourceRow[Any, Any]] | Iterator[PartialSourceRow[Any, Any]], ] - _get_value_fn: Callable[ - [Any, SourceReadOptions], Awaitable[PartialSourceRowData[Any]] - ] + _orig_get_value_fn: Callable[..., Any] + _get_value_fn: Callable[..., Awaitable[PartialSourceRowData[Any]]] _provides_ordinal_fn: Callable[[], bool] | None def __init__( @@ -504,7 +503,8 @@ def __init__( self._value_encoder = make_engine_value_encoder(value_type_info) self._list_fn = _get_required_method(executor, "list") - self._get_value_fn = to_async_call(_get_required_method(executor, "get_value")) + self._orig_get_value_fn = _get_required_method(executor, "get_value") + self._get_value_fn = to_async_call(self._orig_get_value_fn) self._provides_ordinal_fn = getattr(executor, "provides_ordinal", None) def provides_ordinal(self) -> bool: @@ -521,11 +521,9 @@ async def list_async( Return an async iterator that yields individual rows one by one. Each yielded item is a tuple of (key, data). """ - # Convert the options dict to SourceReadOptions read_options = load_engine_object(SourceReadOptions, options) - - # Call the user's list method - list_result = self._list_fn(read_options) + args = _build_args(self._list_fn, 0, options=read_options) + list_result = self._list_fn(*args) # Handle both sync and async iterators if hasattr(list_result, "__aiter__"): @@ -548,8 +546,8 @@ async def get_value_async( ) -> dict[str, Any]: key = self._key_decoder(raw_key) read_options = load_engine_object(SourceReadOptions, options) - - row_data = await self._get_value_fn(key, read_options) + args = _build_args(self._orig_get_value_fn, 1, key=key, options=read_options) + row_data = await self._get_value_fn(*args) return self._encode_source_row_data(row_data) def _encode_source_row_data( diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index 12baf5a54..b6c4a174c 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -355,7 +355,7 @@ impl PySourceExecutor { if py_err.is_instance_of::(py) { Ok(None) } else { - Err(anyhow!("Error from async iterator: {}", py_err)) + Err(py_err).to_result_with_py_trace(py) } } }