Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from . import op
from .convert import dump_engine_object
from .typing import encode_enriched_type
from .runtime import op_execution_context

class _NameBuilder:
_existing_names: set[str]
Expand Down Expand Up @@ -475,7 +476,7 @@ def _create_engine_flow() -> _engine.Flow:
root_scope = DataScope(
flow_builder_state, flow_builder_state.engine_flow_builder.root_scope())
fl_def(FlowBuilder(flow_builder_state), root_scope)
return flow_builder_state.engine_flow_builder.build_flow()
return flow_builder_state.engine_flow_builder.build_flow(op_execution_context.event_loop)

return Flow(_create_engine_flow)

Expand Down Expand Up @@ -570,7 +571,8 @@ def __init__(
output = flow_fn(**kwargs)
flow_builder_state.engine_flow_builder.set_direct_output(
_data_slice_state(output).engine_data_slice)
self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow()
self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow(
op_execution_context.event_loop)

def __str__(self):
return str(self._engine_flow)
Expand Down
6 changes: 4 additions & 2 deletions python/cocoindex/lib.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""
Library level functions and states.
"""
import asyncio
import os
import sys
import functools
import inspect
import asyncio
from typing import Callable, Self, Any

from typing import Callable, Self
from dataclasses import dataclass

from . import _engine
from . import flow, query, cli


def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
value = os.getenv(env_name)
if value is None:
Expand Down
20 changes: 19 additions & 1 deletion python/cocoindex/op.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Facilities for defining cocoindex operations.
"""
import asyncio
import dataclasses
import inspect

Expand Down Expand Up @@ -78,6 +79,7 @@ def _register_op_factory(
category: OpCategory,
expected_args: list[tuple[str, inspect.Parameter]],
expected_return,
is_async: bool,
executor_cls: type,
spec_cls: type,
op_args: OpArgs,
Expand Down Expand Up @@ -168,6 +170,19 @@ def __call__(self, *args, **kwargs):
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
for arg_name, arg in kwargs.items()}
if is_async:
async def _inner():
if op_args.gpu:
await asyncio.to_thread(_gpu_dispatch_lock.acquire)
try:
output = await super(_WrappedClass, self).__call__(
*converted_args, **converted_kwargs)
finally:
if op_args.gpu:
_gpu_dispatch_lock.release()
return to_engine_value(output)
return _inner()

if op_args.gpu:
# For GPU executions, data-level parallelism is applied, so we don't want to
# execute different tasks in parallel.
Expand All @@ -189,7 +204,8 @@ def __call__(self, *args, **kwargs):
if category == OpCategory.FUNCTION:
_engine.register_function_factory(
spec_cls.__name__,
_FunctionExecutorFactory(spec_cls, _WrappedClass))
_FunctionExecutorFactory(spec_cls, _WrappedClass),
is_async)
else:
raise ValueError(f"Unsupported executor type {category}")

Expand All @@ -214,6 +230,7 @@ def _inner(cls: type[Executor]) -> type:
category=spec_cls._op_category,
expected_args=list(sig.parameters.items())[1:], # First argument is `self`
expected_return=sig.return_annotation,
is_async=inspect.iscoroutinefunction(cls.__call__),
executor_cls=cls,
spec_cls=spec_cls,
op_args=op_args)
Expand Down Expand Up @@ -249,6 +266,7 @@ class _Spec(FunctionSpec):
category=OpCategory.FUNCTION,
expected_args=list(sig.parameters.items()),
expected_return=sig.return_annotation,
is_async=inspect.iscoroutinefunction(fn),
executor_cls=_Executor,
spec_cls=_Spec,
op_args=op_args)
Expand Down
21 changes: 21 additions & 0 deletions python/cocoindex/runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import threading
import asyncio

class _OpExecutionContext:
_lock: threading.Lock
_event_loop: asyncio.AbstractEventLoop | None = None

def __init__(self):
self._lock = threading.Lock()

@property
def event_loop(self) -> asyncio.AbstractEventLoop:
"""Get the event loop for the cocoindex library."""
with self._lock:
if self._event_loop is None:
self._event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._event_loop)
threading.Thread(target=self._event_loop.run_forever, daemon=True).start()
return self._event_loop

op_execution_context = _OpExecutionContext()
15 changes: 10 additions & 5 deletions src/builder/analyzed_flow.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::prelude::*;
use crate::{ops::interface::FlowInstanceContext, prelude::*};

use super::{analyzer, plan};
use crate::{
Expand All @@ -19,12 +19,16 @@ pub struct AnalyzedFlow {
impl AnalyzedFlow {
pub async fn from_flow_instance(
flow_instance: crate::base::spec::FlowInstanceSpec,
flow_instance_ctx: Arc<FlowInstanceContext>,
existing_flow_ss: Option<&setup::FlowSetupState<setup::ExistingMode>>,
registry: &ExecutorFactoryRegistry,
) -> Result<Self> {
let ctx = analyzer::build_flow_instance_context(&flow_instance.name);
let (data_schema, execution_plan_fut, desired_state) =
analyzer::analyze_flow(&flow_instance, &ctx, existing_flow_ss, registry)?;
let (data_schema, execution_plan_fut, desired_state) = analyzer::analyze_flow(
&flow_instance,
&flow_instance_ctx,
existing_flow_ss,
registry,
)?;
let setup_status_check =
setup::check_flow_setup_status(Some(&desired_state), existing_flow_ss)?;
let execution_plan = if setup_status_check.is_up_to_date() {
Expand Down Expand Up @@ -72,8 +76,9 @@ impl AnalyzedTransientFlow {
pub async fn from_transient_flow(
transient_flow: spec::TransientFlowSpec,
registry: &ExecutorFactoryRegistry,
py_exec_ctx: Option<crate::py::PythonExecutionContext>,
) -> Result<Self> {
let ctx = analyzer::build_flow_instance_context(&transient_flow.name);
let ctx = analyzer::build_flow_instance_context(&transient_flow.name, py_exec_ctx);
let (output_type, data_schema, execution_plan_fut) =
analyzer::analyze_transient_flow(&transient_flow, &ctx, registry)?;
Ok(Self {
Expand Down
6 changes: 5 additions & 1 deletion src/builder/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1023,10 +1023,14 @@ impl AnalyzerContext<'_> {
}
}

pub fn build_flow_instance_context(flow_inst_name: &str) -> Arc<FlowInstanceContext> {
/// Builds the shared context for the flow instance named `flow_inst_name`.
///
/// The context carries an owned copy of the name, a clone of the handle
/// returned by `get_auth_registry()`, and — when `py_exec_ctx` is `Some` —
/// the given Python execution context wrapped in an `Arc`.
pub fn build_flow_instance_context(
    flow_inst_name: &str,
    py_exec_ctx: Option<crate::py::PythonExecutionContext>,
) -> Arc<FlowInstanceContext> {
    let flow_instance_name = flow_inst_name.to_owned();
    let auth_registry = get_auth_registry().clone();
    let py_exec_ctx = py_exec_ctx.map(Arc::new);
    let context = FlowInstanceContext {
        flow_instance_name,
        auth_registry,
        py_exec_ctx,
    };
    Arc::new(context)
}

Expand Down
17 changes: 14 additions & 3 deletions src/builder/flow_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ impl FlowBuilder {
.get(name)
.cloned();
let root_data_scope = Arc::new(Mutex::new(DataScopeBuilder::new()));
let flow_inst_context = build_flow_instance_context(name);
let flow_inst_context = build_flow_instance_context(name, None);
let result = Self {
lib_context,
flow_inst_context,
Expand Down Expand Up @@ -636,17 +636,22 @@ impl FlowBuilder {
}))
}

pub fn build_flow(&self, py: Python<'_>) -> PyResult<py::Flow> {
pub fn build_flow(&self, py: Python<'_>, py_event_loop: Py<PyAny>) -> PyResult<py::Flow> {
let spec = spec::FlowInstanceSpec {
name: self.flow_instance_name.clone(),
import_ops: self.import_ops.clone(),
reactive_ops: self.reactive_ops.clone(),
export_ops: self.export_ops.clone(),
};
let flow_instance_ctx = build_flow_instance_context(
&self.flow_instance_name,
Some(crate::py::PythonExecutionContext::new(py, py_event_loop)),
);
let analyzed_flow = py
.allow_threads(|| {
get_runtime().block_on(super::AnalyzedFlow::from_flow_instance(
spec,
flow_instance_ctx,
self.existing_flow_ss.as_ref(),
&crate::ops::executor_factory_registry(),
))
Expand All @@ -669,7 +674,11 @@ impl FlowBuilder {
Ok(py::Flow(flow_ctx))
}

pub fn build_transient_flow(&self, py: Python<'_>) -> PyResult<py::TransientFlow> {
pub fn build_transient_flow(
&self,
py: Python<'_>,
py_event_loop: Py<PyAny>,
) -> PyResult<py::TransientFlow> {
if self.direct_input_fields.is_empty() {
return Err(PyException::new_err("expect at least one direct input"));
}
Expand All @@ -684,11 +693,13 @@ impl FlowBuilder {
reactive_ops: self.reactive_ops.clone(),
output_value: direct_output_value.clone(),
};
let py_ctx = crate::py::PythonExecutionContext::new(py, py_event_loop);
let analyzed_flow = py
.allow_threads(|| {
get_runtime().block_on(super::AnalyzedTransientFlow::from_transient_flow(
spec,
&crate::ops::executor_factory_registry(),
Some(py_ctx),
))
})
.into_py_result()?;
Expand Down
1 change: 1 addition & 0 deletions src/ops/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use serde::Serialize;
/// Context attached to a single flow instance.
pub struct FlowInstanceContext {
/// Name of the flow instance this context belongs to.
pub flow_instance_name: String,
/// Shared handle to the auth registry.
pub auth_registry: Arc<AuthRegistry>,
/// Shared Python execution context, if one was supplied when the context
/// was built; `None` otherwise.
pub py_exec_ctx: Option<Arc<crate::py::PythonExecutionContext>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
Expand Down
Loading