Skip to content

Commit 38f11e5

Browse files
authored
feat(async-fn): Allow custom functions to be async. (#349)
1 parent 774a296 commit 38f11e5

File tree

10 files changed

+168
-49
lines changed

10 files changed

+168
-49
lines changed

python/cocoindex/flow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from . import op
2020
from .convert import dump_engine_object
2121
from .typing import encode_enriched_type
22+
from .runtime import op_execution_context
2223

2324
class _NameBuilder:
2425
_existing_names: set[str]
@@ -475,7 +476,7 @@ def _create_engine_flow() -> _engine.Flow:
475476
root_scope = DataScope(
476477
flow_builder_state, flow_builder_state.engine_flow_builder.root_scope())
477478
fl_def(FlowBuilder(flow_builder_state), root_scope)
478-
return flow_builder_state.engine_flow_builder.build_flow()
479+
return flow_builder_state.engine_flow_builder.build_flow(op_execution_context.event_loop)
479480

480481
return Flow(_create_engine_flow)
481482

@@ -570,7 +571,8 @@ def __init__(
570571
output = flow_fn(**kwargs)
571572
flow_builder_state.engine_flow_builder.set_direct_output(
572573
_data_slice_state(output).engine_data_slice)
573-
self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow()
574+
self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow(
575+
op_execution_context.event_loop)
574576

575577
def __str__(self):
576578
return str(self._engine_flow)

python/cocoindex/lib.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
"""
22
Library level functions and states.
33
"""
4+
import asyncio
45
import os
56
import sys
67
import functools
78
import inspect
8-
import asyncio
9-
from typing import Callable, Self, Any
9+
10+
from typing import Callable, Self
1011
from dataclasses import dataclass
1112

1213
from . import _engine
1314
from . import flow, query, cli
1415

16+
1517
def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
1618
value = os.getenv(env_name)
1719
if value is None:

python/cocoindex/op.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Facilities for defining cocoindex operations.
33
"""
4+
import asyncio
45
import dataclasses
56
import inspect
67

@@ -78,6 +79,7 @@ def _register_op_factory(
7879
category: OpCategory,
7980
expected_args: list[tuple[str, inspect.Parameter]],
8081
expected_return,
82+
is_async: bool,
8183
executor_cls: type,
8284
spec_cls: type,
8385
op_args: OpArgs,
@@ -168,6 +170,19 @@ def __call__(self, *args, **kwargs):
168170
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
169171
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
170172
for arg_name, arg in kwargs.items()}
173+
if is_async:
174+
async def _inner():
175+
if op_args.gpu:
176+
await asyncio.to_thread(_gpu_dispatch_lock.acquire)
177+
try:
178+
output = await super(_WrappedClass, self).__call__(
179+
*converted_args, **converted_kwargs)
180+
finally:
181+
if op_args.gpu:
182+
_gpu_dispatch_lock.release()
183+
return to_engine_value(output)
184+
return _inner()
185+
171186
if op_args.gpu:
172187
# For GPU executions, data-level parallelism is applied, so we don't want to
173188
# execute different tasks in parallel.
@@ -189,7 +204,8 @@ def __call__(self, *args, **kwargs):
189204
if category == OpCategory.FUNCTION:
190205
_engine.register_function_factory(
191206
spec_cls.__name__,
192-
_FunctionExecutorFactory(spec_cls, _WrappedClass))
207+
_FunctionExecutorFactory(spec_cls, _WrappedClass),
208+
is_async)
193209
else:
194210
raise ValueError(f"Unsupported executor type {category}")
195211

@@ -214,6 +230,7 @@ def _inner(cls: type[Executor]) -> type:
214230
category=spec_cls._op_category,
215231
expected_args=list(sig.parameters.items())[1:], # First argument is `self`
216232
expected_return=sig.return_annotation,
233+
is_async=inspect.iscoroutinefunction(cls.__call__),
217234
executor_cls=cls,
218235
spec_cls=spec_cls,
219236
op_args=op_args)
@@ -249,6 +266,7 @@ class _Spec(FunctionSpec):
249266
category=OpCategory.FUNCTION,
250267
expected_args=list(sig.parameters.items()),
251268
expected_return=sig.return_annotation,
269+
is_async=inspect.iscoroutinefunction(fn),
252270
executor_cls=_Executor,
253271
spec_cls=_Spec,
254272
op_args=op_args)

python/cocoindex/runtime.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import threading
2+
import asyncio
3+
4+
class _OpExecutionContext:
5+
_lock: threading.Lock
6+
_event_loop: asyncio.AbstractEventLoop | None = None
7+
8+
def __init__(self):
9+
self._lock = threading.Lock()
10+
11+
@property
12+
def event_loop(self) -> asyncio.AbstractEventLoop:
13+
"""Get the event loop for the cocoindex library."""
14+
with self._lock:
15+
if self._event_loop is None:
16+
self._event_loop = asyncio.new_event_loop()
17+
asyncio.set_event_loop(self._event_loop)
18+
threading.Thread(target=self._event_loop.run_forever, daemon=True).start()
19+
return self._event_loop
20+
21+
op_execution_context = _OpExecutionContext()

src/builder/analyzed_flow.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::prelude::*;
1+
use crate::{ops::interface::FlowInstanceContext, prelude::*};
22

33
use super::{analyzer, plan};
44
use crate::{
@@ -19,12 +19,16 @@ pub struct AnalyzedFlow {
1919
impl AnalyzedFlow {
2020
pub async fn from_flow_instance(
2121
flow_instance: crate::base::spec::FlowInstanceSpec,
22+
flow_instance_ctx: Arc<FlowInstanceContext>,
2223
existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
2324
registry: &ExecutorFactoryRegistry,
2425
) -> Result<Self> {
25-
let ctx = analyzer::build_flow_instance_context(&flow_instance.name);
26-
let (data_schema, execution_plan_fut, desired_state) =
27-
analyzer::analyze_flow(&flow_instance, &ctx, existing_flow_ss, registry)?;
26+
let (data_schema, execution_plan_fut, desired_state) = analyzer::analyze_flow(
27+
&flow_instance,
28+
&flow_instance_ctx,
29+
existing_flow_ss,
30+
registry,
31+
)?;
2832
let setup_status_check =
2933
setup::check_flow_setup_status(Some(&desired_state), existing_flow_ss)?;
3034
let execution_plan = if setup_status_check.is_up_to_date() {
@@ -72,8 +76,9 @@ impl AnalyzedTransientFlow {
7276
pub async fn from_transient_flow(
7377
transient_flow: spec::TransientFlowSpec,
7478
registry: &ExecutorFactoryRegistry,
79+
py_exec_ctx: Option<crate::py::PythonExecutionContext>,
7580
) -> Result<Self> {
76-
let ctx = analyzer::build_flow_instance_context(&transient_flow.name);
81+
let ctx = analyzer::build_flow_instance_context(&transient_flow.name, py_exec_ctx);
7782
let (output_type, data_schema, execution_plan_fut) =
7883
analyzer::analyze_transient_flow(&transient_flow, &ctx, registry)?;
7984
Ok(Self {

src/builder/analyzer.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,10 +1023,14 @@ impl AnalyzerContext<'_> {
10231023
}
10241024
}
10251025

1026-
pub fn build_flow_instance_context(flow_inst_name: &str) -> Arc<FlowInstanceContext> {
1026+
pub fn build_flow_instance_context(
1027+
flow_inst_name: &str,
1028+
py_exec_ctx: Option<crate::py::PythonExecutionContext>,
1029+
) -> Arc<FlowInstanceContext> {
10271030
Arc::new(FlowInstanceContext {
10281031
flow_instance_name: flow_inst_name.to_string(),
10291032
auth_registry: get_auth_registry().clone(),
1033+
py_exec_ctx: py_exec_ctx.map(Arc::new),
10301034
})
10311035
}
10321036

src/builder/flow_builder.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ impl FlowBuilder {
347347
.get(name)
348348
.cloned();
349349
let root_data_scope = Arc::new(Mutex::new(DataScopeBuilder::new()));
350-
let flow_inst_context = build_flow_instance_context(name);
350+
let flow_inst_context = build_flow_instance_context(name, None);
351351
let result = Self {
352352
lib_context,
353353
flow_inst_context,
@@ -636,17 +636,22 @@ impl FlowBuilder {
636636
}))
637637
}
638638

639-
pub fn build_flow(&self, py: Python<'_>) -> PyResult<py::Flow> {
639+
pub fn build_flow(&self, py: Python<'_>, py_event_loop: Py<PyAny>) -> PyResult<py::Flow> {
640640
let spec = spec::FlowInstanceSpec {
641641
name: self.flow_instance_name.clone(),
642642
import_ops: self.import_ops.clone(),
643643
reactive_ops: self.reactive_ops.clone(),
644644
export_ops: self.export_ops.clone(),
645645
};
646+
let flow_instance_ctx = build_flow_instance_context(
647+
&self.flow_instance_name,
648+
Some(crate::py::PythonExecutionContext::new(py, py_event_loop)),
649+
);
646650
let analyzed_flow = py
647651
.allow_threads(|| {
648652
get_runtime().block_on(super::AnalyzedFlow::from_flow_instance(
649653
spec,
654+
flow_instance_ctx,
650655
self.existing_flow_ss.as_ref(),
651656
&crate::ops::executor_factory_registry(),
652657
))
@@ -669,7 +674,11 @@ impl FlowBuilder {
669674
Ok(py::Flow(flow_ctx))
670675
}
671676

672-
pub fn build_transient_flow(&self, py: Python<'_>) -> PyResult<py::TransientFlow> {
677+
pub fn build_transient_flow(
678+
&self,
679+
py: Python<'_>,
680+
py_event_loop: Py<PyAny>,
681+
) -> PyResult<py::TransientFlow> {
673682
if self.direct_input_fields.is_empty() {
674683
return Err(PyException::new_err("expect at least one direct input"));
675684
}
@@ -684,11 +693,13 @@ impl FlowBuilder {
684693
reactive_ops: self.reactive_ops.clone(),
685694
output_value: direct_output_value.clone(),
686695
};
696+
let py_ctx = crate::py::PythonExecutionContext::new(py, py_event_loop);
687697
let analyzed_flow = py
688698
.allow_threads(|| {
689699
get_runtime().block_on(super::AnalyzedTransientFlow::from_transient_flow(
690700
spec,
691701
&crate::ops::executor_factory_registry(),
702+
Some(py_ctx),
692703
))
693704
})
694705
.into_py_result()?;

src/ops/interface.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use serde::Serialize;
1313
pub struct FlowInstanceContext {
1414
pub flow_instance_name: String,
1515
pub auth_registry: Arc<AuthRegistry>,
16+
pub py_exec_ctx: Option<Arc<crate::py::PythonExecutionContext>>,
1617
}
1718

1819
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]

0 commit comments

Comments
 (0)