diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 4feae4dc1..7e10e69fa 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -19,6 +19,7 @@ from . import op from .convert import dump_engine_object from .typing import encode_enriched_type +from .runtime import op_execution_context class _NameBuilder: _existing_names: set[str] @@ -475,7 +476,7 @@ def _create_engine_flow() -> _engine.Flow: root_scope = DataScope( flow_builder_state, flow_builder_state.engine_flow_builder.root_scope()) fl_def(FlowBuilder(flow_builder_state), root_scope) - return flow_builder_state.engine_flow_builder.build_flow() + return flow_builder_state.engine_flow_builder.build_flow(op_execution_context.event_loop) return Flow(_create_engine_flow) @@ -570,7 +571,8 @@ def __init__( output = flow_fn(**kwargs) flow_builder_state.engine_flow_builder.set_direct_output( _data_slice_state(output).engine_data_slice) - self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow() + self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow( + op_execution_context.event_loop) def __str__(self): return str(self._engine_flow) diff --git a/python/cocoindex/lib.py b/python/cocoindex/lib.py index 44525f231..391a47084 100644 --- a/python/cocoindex/lib.py +++ b/python/cocoindex/lib.py @@ -1,17 +1,19 @@ """ Library level functions and states. """ +import asyncio import os import sys import functools import inspect -import asyncio -from typing import Callable, Self, Any + +from typing import Callable, Self from dataclasses import dataclass from . import _engine from . import flow, query, cli + def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False): value = os.getenv(env_name) if value is None: diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 5d3532c9e..f97ee8536 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -1,6 +1,7 @@ """ Facilities for defining cocoindex operations. """ +import asyncio import dataclasses import inspect @@ -78,6 +79,7 @@ def _register_op_factory( category: OpCategory, expected_args: list[tuple[str, inspect.Parameter]], expected_return, + is_async: bool, executor_cls: type, spec_cls: type, op_args: OpArgs, @@ -168,6 +170,19 @@ def __call__(self, *args, **kwargs): converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args)) converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg) for arg_name, arg in kwargs.items()} + if is_async: + async def _inner(): + if op_args.gpu: + await asyncio.to_thread(_gpu_dispatch_lock.acquire) + try: + output = await super(_WrappedClass, self).__call__( + *converted_args, **converted_kwargs) + finally: + if op_args.gpu: + _gpu_dispatch_lock.release() + return to_engine_value(output) + return _inner() + if op_args.gpu: # For GPU executions, data-level parallelism is applied, so we don't want to # execute different tasks in parallel. @@ -189,7 +204,8 @@ def __call__(self, *args, **kwargs): if category == OpCategory.FUNCTION: _engine.register_function_factory( spec_cls.__name__, - _FunctionExecutorFactory(spec_cls, _WrappedClass)) + _FunctionExecutorFactory(spec_cls, _WrappedClass), + is_async) else: raise ValueError(f"Unsupported executor type {category}") @@ -214,6 +230,7 @@ def _inner(cls: type[Executor]) -> type: category=spec_cls._op_category, expected_args=list(sig.parameters.items())[1:], # First argument is `self` expected_return=sig.return_annotation, + is_async=inspect.iscoroutinefunction(cls.__call__), executor_cls=cls, spec_cls=spec_cls, op_args=op_args) @@ -249,6 +266,7 @@ class _Spec(FunctionSpec): category=OpCategory.FUNCTION, expected_args=list(sig.parameters.items()), expected_return=sig.return_annotation, + is_async=inspect.iscoroutinefunction(fn), executor_cls=_Executor, spec_cls=_Spec, op_args=op_args) diff --git a/python/cocoindex/runtime.py b/python/cocoindex/runtime.py new file mode 100644 index 000000000..48bf5f579 --- /dev/null +++ b/python/cocoindex/runtime.py @@ -0,0 +1,21 @@ +import threading +import asyncio + +class _OpExecutionContext: + _lock: threading.Lock + _event_loop: asyncio.AbstractEventLoop | None = None + + def __init__(self): + self._lock = threading.Lock() + + @property + def event_loop(self) -> asyncio.AbstractEventLoop: + """Get the event loop for the cocoindex library.""" + with self._lock: + if self._event_loop is None: + self._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._event_loop) + threading.Thread(target=self._event_loop.run_forever, daemon=True).start() + return self._event_loop + +op_execution_context = _OpExecutionContext() diff --git a/src/builder/analyzed_flow.rs b/src/builder/analyzed_flow.rs index e4e6759fd..3bd037787 100644 --- a/src/builder/analyzed_flow.rs +++ b/src/builder/analyzed_flow.rs @@ -1,4 +1,4 @@ -use crate::prelude::*; +use crate::{ops::interface::FlowInstanceContext, prelude::*}; use super::{analyzer, plan}; use crate::{ @@ -19,12 +19,16 @@ pub struct AnalyzedFlow { impl AnalyzedFlow { pub async fn from_flow_instance( flow_instance: crate::base::spec::FlowInstanceSpec, + flow_instance_ctx: Arc, existing_flow_ss: Option<&setup::FlowSetupState>, registry: &ExecutorFactoryRegistry, ) -> Result { - let ctx = analyzer::build_flow_instance_context(&flow_instance.name); - let (data_schema, execution_plan_fut, desired_state) = - analyzer::analyze_flow(&flow_instance, &ctx, existing_flow_ss, registry)?; + let (data_schema, execution_plan_fut, desired_state) = analyzer::analyze_flow( + &flow_instance, + &flow_instance_ctx, + existing_flow_ss, + registry, + )?; let setup_status_check = setup::check_flow_setup_status(Some(&desired_state), existing_flow_ss)?; let execution_plan = if setup_status_check.is_up_to_date() { @@ -72,8 +76,9 @@ impl AnalyzedTransientFlow { pub async fn from_transient_flow( transient_flow: spec::TransientFlowSpec, registry: &ExecutorFactoryRegistry, + py_exec_ctx: Option, ) -> Result { - let ctx = analyzer::build_flow_instance_context(&transient_flow.name); + let ctx = analyzer::build_flow_instance_context(&transient_flow.name, py_exec_ctx); let (output_type, data_schema, execution_plan_fut) = analyzer::analyze_transient_flow(&transient_flow, &ctx, registry)?; Ok(Self { diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs index a489c5e85..d96518c62 100644 --- a/src/builder/analyzer.rs +++ b/src/builder/analyzer.rs @@ -1023,10 +1023,14 @@ impl AnalyzerContext<'_> { } } -pub fn build_flow_instance_context(flow_inst_name: &str) -> Arc { +pub fn build_flow_instance_context( + flow_inst_name: &str, + py_exec_ctx: Option, +) -> Arc { Arc::new(FlowInstanceContext { flow_instance_name: flow_inst_name.to_string(), auth_registry: get_auth_registry().clone(), + py_exec_ctx: py_exec_ctx.map(Arc::new), }) } diff --git a/src/builder/flow_builder.rs b/src/builder/flow_builder.rs index 177b85f97..1354f7344 100644 --- a/src/builder/flow_builder.rs +++ b/src/builder/flow_builder.rs @@ -347,7 +347,7 @@ impl FlowBuilder { .get(name) .cloned(); let root_data_scope = Arc::new(Mutex::new(DataScopeBuilder::new())); - let flow_inst_context = build_flow_instance_context(name); + let flow_inst_context = build_flow_instance_context(name, None); let result = Self { lib_context, flow_inst_context, @@ -636,17 +636,22 @@ impl FlowBuilder { })) } - pub fn build_flow(&self, py: Python<'_>) -> PyResult { + pub fn build_flow(&self, py: Python<'_>, py_event_loop: Py) -> PyResult { let spec = spec::FlowInstanceSpec { name: self.flow_instance_name.clone(), import_ops: self.import_ops.clone(), reactive_ops: self.reactive_ops.clone(), export_ops: self.export_ops.clone(), }; + let flow_instance_ctx = build_flow_instance_context( + &self.flow_instance_name, + Some(crate::py::PythonExecutionContext::new(py, py_event_loop)), + ); let analyzed_flow = py .allow_threads(|| { get_runtime().block_on(super::AnalyzedFlow::from_flow_instance( spec, + flow_instance_ctx, self.existing_flow_ss.as_ref(), &crate::ops::executor_factory_registry(), )) @@ -669,7 +674,11 @@ impl FlowBuilder { Ok(py::Flow(flow_ctx)) } - pub fn build_transient_flow(&self, py: Python<'_>) -> PyResult { + pub fn build_transient_flow( + &self, + py: Python<'_>, + py_event_loop: Py, + ) -> PyResult { if self.direct_input_fields.is_empty() { return Err(PyException::new_err("expect at least one direct input")); } @@ -684,11 +693,13 @@ impl FlowBuilder { reactive_ops: self.reactive_ops.clone(), output_value: direct_output_value.clone(), }; + let py_ctx = crate::py::PythonExecutionContext::new(py, py_event_loop); let analyzed_flow = py .allow_threads(|| { get_runtime().block_on(super::AnalyzedTransientFlow::from_transient_flow( spec, &crate::ops::executor_factory_registry(), + Some(py_ctx), )) }) .into_py_result()?; diff --git a/src/ops/interface.rs b/src/ops/interface.rs index 927fd86c4..67aec53ea 100644 --- a/src/ops/interface.rs +++ b/src/ops/interface.rs @@ -13,6 +13,7 @@ use serde::Serialize; pub struct FlowInstanceContext { pub flow_instance_name: String, pub auth_registry: Arc, + pub py_exec_ctx: Option>, } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index 17d8cb703..7dddb7a88 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -14,7 +14,7 @@ use crate::{ builder::plan, py, }; -use anyhow::Result; +use anyhow::{anyhow, Result}; use super::interface::{FlowInstanceContext, SimpleFunctionExecutor, SimpleFunctionFactory}; @@ -39,6 +39,9 @@ impl PyOpArgSchema { struct PyFunctionExecutor { py_function_executor: Py, + is_async: bool, + py_exec_ctx: Arc, + num_positional_args: usize, kw_args_names: Vec>, result_type: schema::EnrichedValueType, @@ -47,47 +50,76 @@ struct PyFunctionExecutor { behavior_version: Option, } +impl PyFunctionExecutor { + fn call_py_fn<'py>( + &self, + py: Python<'py>, + input: Vec, + ) -> Result> { + let mut args = Vec::with_capacity(self.num_positional_args); + for v in input[0..self.num_positional_args].iter() { + args.push(py::value_to_py_object(py, v)?); + } + + let kwargs = if self.kw_args_names.is_empty() { + None + } else { + let mut kwargs = Vec::with_capacity(self.kw_args_names.len()); + for (name, v) in self + .kw_args_names + .iter() + .zip(input[self.num_positional_args..].iter()) + { + kwargs.push((name.bind(py), py::value_to_py_object(py, v)?)); + } + Some(kwargs) + }; + + let result = self.py_function_executor.call( + py, + PyTuple::new(py, args.into_iter())?, + kwargs + .map(|kwargs| -> Result<_> { Ok(kwargs.into_py_dict(py)?) }) + .transpose()? + .as_ref(), + )?; + Ok(result.into_bound(py)) + } +} + #[async_trait] impl SimpleFunctionExecutor for Arc { async fn evaluate(&self, input: Vec) -> Result { let self = self.clone(); - let result = tokio::task::spawn_blocking(move || { + let result = if self.is_async { + let result_fut = Python::with_gil(|py| -> Result<_> { + let result = self.call_py_fn(py, input)?; + let task_locals = pyo3_async_runtimes::TaskLocals::new( + self.py_exec_ctx.event_loop.bind(py).clone(), + ); + Ok(pyo3_async_runtimes::into_future_with_locals( + &task_locals, + result, + )?) + })?; + let result = result_fut.await?; Python::with_gil(|py| -> Result<_> { - let mut args = Vec::with_capacity(self.num_positional_args); - for v in input[0..self.num_positional_args].iter() { - args.push(py::value_to_py_object(py, v)?); - } - - let kwargs = if self.kw_args_names.is_empty() { - None - } else { - let mut kwargs = Vec::with_capacity(self.kw_args_names.len()); - for (name, v) in self - .kw_args_names - .iter() - .zip(input[self.num_positional_args..].iter()) - { - kwargs.push((name.bind(py), py::value_to_py_object(py, v)?)); - } - Some(kwargs) - }; - - let result = self.py_function_executor.call( - py, - PyTuple::new(py, args.into_iter())?, - kwargs - .map(|kwargs| -> Result<_> { Ok(kwargs.into_py_dict(py)?) }) - .transpose()? - .as_ref(), - )?; - Ok(py::value_from_py_object( &self.result_type.typ, - result.bind(py), + &result.into_bound(py), )?) + })? + } else { + tokio::task::spawn_blocking(move || { + Python::with_gil(|py| -> Result<_> { + Ok(py::value_from_py_object( + &self.result_type.typ, + &self.call_py_fn(py, input)?, + )?) + }) }) - }) - .await??; + .await?? + }; Ok(result) } @@ -102,6 +134,7 @@ impl SimpleFunctionExecutor for Arc { pub(crate) struct PyFunctionFactory { pub py_function_factory: Py, + pub is_async: bool, } impl SimpleFunctionFactory for PyFunctionFactory { @@ -109,7 +142,7 @@ impl SimpleFunctionFactory for PyFunctionFactory { self: Arc, spec: serde_json::Value, input_schema: Vec, - _context: Arc, + context: Arc, ) -> Result<( schema::EnrichedValueType, BoxFuture<'static, Result>>, @@ -157,6 +190,11 @@ impl SimpleFunctionFactory for PyFunctionFactory { let executor_fut = { let result_type = result_type.clone(); async move { + let py_exec_ctx = context + .py_exec_ctx + .as_ref() + .ok_or_else(|| anyhow!("Python execution context is missing"))? + .clone(); let executor = tokio::task::spawn_blocking(move || -> Result<_> { let (enable_cache, behavior_version) = Python::with_gil(|py| -> anyhow::Result<_> { @@ -171,6 +209,8 @@ impl SimpleFunctionFactory for PyFunctionFactory { })?; Ok(Box::new(Arc::new(PyFunctionExecutor { py_function_executor: executor, + is_async: self.is_async, + py_exec_ctx, num_positional_args, kw_args_names, result_type, diff --git a/src/py/mod.rs b/src/py/mod.rs index 726d7db0f..cbc88592b 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -16,6 +16,16 @@ use std::collections::btree_map; mod convert; pub use convert::*; +pub struct PythonExecutionContext { + pub event_loop: Py, +} + +impl PythonExecutionContext { + pub fn new(_py: Python<'_>, event_loop: Py) -> Self { + Self { event_loop } + } +} + pub trait IntoPyResult { fn into_py_result(self) -> PyResult; } @@ -58,9 +68,14 @@ fn stop(py: Python<'_>) -> PyResult<()> { } #[pyfunction] -fn register_function_factory(name: String, py_function_factory: Py) -> PyResult<()> { +fn register_function_factory( + name: String, + py_function_factory: Py, + is_async: bool, +) -> PyResult<()> { let factory = PyFunctionFactory { py_function_factory, + is_async, }; register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result() }