Skip to content

Commit 00f923a

Browse files
authored
feat(async-fn): allow prepare to be async; rs always call py in async (#351)
1 parent 38f11e5 commit 00f923a

File tree

3 files changed: 62 additions and 91 deletions.

python/cocoindex/op.py

Lines changed: 18 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
import dataclasses
66
import inspect
77

8-
from typing import get_type_hints, Protocol, Any, Callable, dataclass_transform
8+
from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass_transform
99
from enum import Enum
10-
from threading import Lock
10+
from functools import partial
1111

1212
from .typing import encode_enriched_type
1313
from .convert import to_engine_value, make_engine_value_converter
@@ -61,7 +61,7 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
6161
return (encode_enriched_type(result_type), executor)
6262

6363

64-
_gpu_dispatch_lock = Lock()
64+
_gpu_dispatch_lock = asyncio.Lock()
6565

6666
@dataclasses.dataclass
6767
class OpArgs:
@@ -75,11 +75,15 @@ class OpArgs:
7575
cache: bool = False
7676
behavior_version: int | None = None
7777

78+
def _to_async_call(call: Callable) -> Callable[..., Awaitable[Any]]:
79+
if inspect.iscoroutinefunction(call):
80+
return call
81+
return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))
82+
7883
def _register_op_factory(
7984
category: OpCategory,
8085
expected_args: list[tuple[str, inspect.Parameter]],
8186
expected_return,
82-
is_async: bool,
8387
executor_cls: type,
8488
spec_cls: type,
8589
op_args: OpArgs,
@@ -97,10 +101,12 @@ def behavior_version(self):
97101
class _WrappedClass(executor_cls, _Fallback):
98102
_args_converters: list[Callable[[Any], Any]]
99103
_kwargs_converters: dict[str, Callable[[str, Any], Any]]
104+
_acall: Callable
100105

101106
def __init__(self, spec):
102107
super().__init__()
103108
self.spec = spec
109+
self._acall = _to_async_call(super().__call__)
104110

105111
def analyze(self, *args, **kwargs):
106112
"""
@@ -157,42 +163,30 @@ def analyze(self, *args, **kwargs):
157163
else:
158164
return expected_return
159165

160-
def prepare(self):
166+
async def prepare(self):
161167
"""
162168
Prepare for execution.
163169
It's executed after `analyze` and before any `__call__` execution.
164170
"""
165-
setup_method = getattr(executor_cls, 'prepare', None)
171+
setup_method = getattr(super(), 'prepare', None)
166172
if setup_method is not None:
167-
setup_method(self)
173+
await _to_async_call(setup_method)()
168174

169-
def __call__(self, *args, **kwargs):
175+
async def __call__(self, *args, **kwargs):
170176
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
171177
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
172178
for arg_name, arg in kwargs.items()}
173-
if is_async:
174-
async def _inner():
175-
if op_args.gpu:
176-
await asyncio.to_thread(_gpu_dispatch_lock.acquire)
177-
try:
178-
output = await super(_WrappedClass, self).__call__(
179-
*converted_args, **converted_kwargs)
180-
finally:
181-
if op_args.gpu:
182-
_gpu_dispatch_lock.release()
183-
return to_engine_value(output)
184-
return _inner()
185179

186180
if op_args.gpu:
187181
# For GPU executions, data-level parallelism is applied, so we don't want to
188182
# execute different tasks in parallel.
189183
# Besides, multiprocessing is more appropriate for pytorch.
190184
# For now, we use a lock to ensure only one task is executed at a time.
191185
# TODO: Implement multi-processing dispatching.
192-
with _gpu_dispatch_lock:
193-
output = super().__call__(*converted_args, **converted_kwargs)
186+
async with _gpu_dispatch_lock:
187+
output = await self._acall(*converted_args, **converted_kwargs)
194188
else:
195-
output = super().__call__(*converted_args, **converted_kwargs)
189+
output = await self._acall(*converted_args, **converted_kwargs)
196190
return to_engine_value(output)
197191

198192
_WrappedClass.__name__ = executor_cls.__name__
@@ -203,9 +197,7 @@ async def _inner():
203197

204198
if category == OpCategory.FUNCTION:
205199
_engine.register_function_factory(
206-
spec_cls.__name__,
207-
_FunctionExecutorFactory(spec_cls, _WrappedClass),
208-
is_async)
200+
spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass))
209201
else:
210202
raise ValueError(f"Unsupported executor type {category}")
211203

@@ -230,7 +222,6 @@ def _inner(cls: type[Executor]) -> type:
230222
category=spec_cls._op_category,
231223
expected_args=list(sig.parameters.items())[1:], # First argument is `self`
232224
expected_return=sig.return_annotation,
233-
is_async=inspect.iscoroutinefunction(cls.__call__),
234225
executor_cls=cls,
235226
spec_cls=spec_cls,
236227
op_args=op_args)
@@ -266,7 +257,6 @@ class _Spec(FunctionSpec):
266257
category=OpCategory.FUNCTION,
267258
expected_args=list(sig.parameters.items()),
268259
expected_return=sig.return_annotation,
269-
is_async=inspect.iscoroutinefunction(fn),
270260
executor_cls=_Executor,
271261
spec_cls=_Spec,
272262
op_args=op_args)

src/ops/py_factory.rs

Lines changed: 43 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@ impl PyOpArgSchema {
3939

4040
struct PyFunctionExecutor {
4141
py_function_executor: Py<PyAny>,
42-
is_async: bool,
4342
py_exec_ctx: Arc<crate::py::PythonExecutionContext>,
4443

4544
num_positional_args: usize,
@@ -91,36 +90,22 @@ impl PyFunctionExecutor {
9190
impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
9291
async fn evaluate(&self, input: Vec<value::Value>) -> Result<value::Value> {
9392
let self = self.clone();
94-
let result = if self.is_async {
95-
let result_fut = Python::with_gil(|py| -> Result<_> {
96-
let result = self.call_py_fn(py, input)?;
97-
let task_locals = pyo3_async_runtimes::TaskLocals::new(
98-
self.py_exec_ctx.event_loop.bind(py).clone(),
99-
);
100-
Ok(pyo3_async_runtimes::into_future_with_locals(
101-
&task_locals,
102-
result,
103-
)?)
104-
})?;
105-
let result = result_fut.await?;
106-
Python::with_gil(|py| -> Result<_> {
107-
Ok(py::value_from_py_object(
108-
&self.result_type.typ,
109-
&result.into_bound(py),
110-
)?)
111-
})?
112-
} else {
113-
tokio::task::spawn_blocking(move || {
114-
Python::with_gil(|py| -> Result<_> {
115-
Ok(py::value_from_py_object(
116-
&self.result_type.typ,
117-
&self.call_py_fn(py, input)?,
118-
)?)
119-
})
120-
})
121-
.await??
122-
};
123-
Ok(result)
93+
let result_fut = Python::with_gil(|py| -> Result<_> {
94+
let result_coro = self.call_py_fn(py, input)?;
95+
let task_locals =
96+
pyo3_async_runtimes::TaskLocals::new(self.py_exec_ctx.event_loop.bind(py).clone());
97+
Ok(pyo3_async_runtimes::into_future_with_locals(
98+
&task_locals,
99+
result_coro,
100+
)?)
101+
})?;
102+
let result = result_fut.await?;
103+
Python::with_gil(|py| -> Result<_> {
104+
Ok(py::value_from_py_object(
105+
&self.result_type.typ,
106+
&result.into_bound(py),
107+
)?)
108+
})
124109
}
125110

126111
fn enable_cache(&self) -> bool {
@@ -134,7 +119,6 @@ impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
134119

135120
pub(crate) struct PyFunctionFactory {
136121
pub py_function_factory: Py<PyAny>,
137-
pub is_async: bool,
138122
}
139123

140124
impl SimpleFunctionFactory for PyFunctionFactory {
@@ -195,31 +179,33 @@ impl SimpleFunctionFactory for PyFunctionFactory {
195179
.as_ref()
196180
.ok_or_else(|| anyhow!("Python execution context is missing"))?
197181
.clone();
198-
let executor = tokio::task::spawn_blocking(move || -> Result<_> {
199-
let (enable_cache, behavior_version) =
200-
Python::with_gil(|py| -> anyhow::Result<_> {
201-
executor.call_method(py, "prepare", (), None)?;
202-
let enable_cache = executor
203-
.call_method(py, "enable_cache", (), None)?
204-
.extract::<bool>(py)?;
205-
let behavior_version = executor
206-
.call_method(py, "behavior_version", (), None)?
207-
.extract::<Option<u32>>(py)?;
208-
Ok((enable_cache, behavior_version))
209-
})?;
210-
Ok(Box::new(Arc::new(PyFunctionExecutor {
211-
py_function_executor: executor,
212-
is_async: self.is_async,
213-
py_exec_ctx,
214-
num_positional_args,
215-
kw_args_names,
216-
result_type,
217-
enable_cache,
218-
behavior_version,
219-
})) as Box<dyn SimpleFunctionExecutor>)
220-
})
221-
.await??;
222-
Ok(executor)
182+
let (prepare_fut, enable_cache, behavior_version) =
183+
Python::with_gil(|py| -> anyhow::Result<_> {
184+
let prepare_coro = executor.call_method(py, "prepare", (), None)?;
185+
let prepare_fut = pyo3_async_runtimes::into_future_with_locals(
186+
&pyo3_async_runtimes::TaskLocals::new(
187+
py_exec_ctx.event_loop.bind(py).clone(),
188+
),
189+
prepare_coro.into_bound(py),
190+
)?;
191+
let enable_cache = executor
192+
.call_method(py, "enable_cache", (), None)?
193+
.extract::<bool>(py)?;
194+
let behavior_version = executor
195+
.call_method(py, "behavior_version", (), None)?
196+
.extract::<Option<u32>>(py)?;
197+
Ok((prepare_fut, enable_cache, behavior_version))
198+
})?;
199+
prepare_fut.await?;
200+
Ok(Box::new(Arc::new(PyFunctionExecutor {
201+
py_function_executor: executor,
202+
py_exec_ctx,
203+
num_positional_args,
204+
kw_args_names,
205+
result_type,
206+
enable_cache,
207+
behavior_version,
208+
})) as Box<dyn SimpleFunctionExecutor>)
223209
}
224210
};
225211

src/py/mod.rs

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -68,14 +68,9 @@ fn stop(py: Python<'_>) -> PyResult<()> {
6868
}
6969

7070
#[pyfunction]
71-
fn register_function_factory(
72-
name: String,
73-
py_function_factory: Py<PyAny>,
74-
is_async: bool,
75-
) -> PyResult<()> {
71+
fn register_function_factory(name: String, py_function_factory: Py<PyAny>) -> PyResult<()> {
7672
let factory = PyFunctionFactory {
7773
py_function_factory,
78-
is_async,
7974
};
8075
register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result()
8176
}

0 commit comments

Comments
 (0)