Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 201 additions & 4 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,15 +909,193 @@ def _inner(handler: Callable[[str], Any]) -> Callable[[str], Any]:
return _inner


class TransientFlowWrapper(Flow):
    """
    Flow variant for transient flows, which do not support persistence.

    Every persistence-oriented operation (setup, drop, update,
    evaluate-and-dump, internal engine flow access, query handlers) is
    overridden to raise `NotImplementedError`. The supported entry points are
    `evaluate()` and `evaluate_async()`.
    """

    def __init__(self, name: str, fl_def: Callable[[FlowBuilder, DataScope], None]):
        # The definition is kept because it is re-run on every evaluation.
        self._fl_def = fl_def
        # The base constructor takes an engine-flow creator; the one passed
        # here always raises, since no regular engine flow exists for a
        # transient flow.
        super().__init__(name, self._create_transient_engine_flow)

    def _create_transient_engine_flow(self) -> _engine.Flow:
        """
        Stand-in engine-flow creator handed to the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Transient flows don't support this operation. Use evaluate() instead."
        )

    def setup(self, report_to_stdout: bool = False) -> None:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError("Setup is not supported for transient flows")

    async def setup_async(self, report_to_stdout: bool = False) -> None:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError("Setup is not supported for transient flows")

    def drop(self, report_to_stdout: bool = False) -> None:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError("Drop is not supported for transient flows")

    async def drop_async(self, report_to_stdout: bool = False) -> None:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError("Drop is not supported for transient flows")

    def update(self, /, *, reexport_targets: bool = False) -> _engine.IndexUpdateInfo:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError("Live updates are not supported for transient flows")

    async def update_async(
        self, /, *, reexport_targets: bool = False
    ) -> _engine.IndexUpdateInfo:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError("Live updates are not supported for transient flows")

    def evaluate_and_dump(
        self, options: EvaluateAndDumpOptions
    ) -> _engine.IndexUpdateInfo:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError(
            "Evaluate and dump is not supported for transient flows"
        )

    def internal_flow(self) -> _engine.Flow:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError(
            "Internal flow access is not supported for transient flows"
        )

    async def internal_flow_async(self) -> _engine.Flow:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError(
            "Internal flow access is not supported for transient flows"
        )

    def add_query_handler(
        self,
        name: str,
        handler: Callable[[str], Any],
        /,
        *,
        result_fields: QueryHandlerResultFields | None = None,
    ) -> None:
        """
        Unsupported for transient flows; always raises `NotImplementedError`.
        """
        raise NotImplementedError(
            "Query handlers are not supported for transient flows"
        )

    def evaluate(self, **input_values: Any) -> Any:
        """
        Synchronous counterpart of `evaluate_async()`.

        Args:
            **input_values: Input values as keyword arguments.

        Returns:
            Whatever `evaluate_async()` produces for the same inputs.
        """
        pending = self.evaluate_async(**input_values)
        return execution_context.run(pending)

    async def evaluate_async(self, **input_values: Any) -> Any:
        """
        Build the transient flow from its definition and evaluate it once.

        Each keyword argument is declared as a direct input of the flow (typed
        from the runtime type of its value) and exposed to the flow definition
        through the root scope under the same name.

        Args:
            **input_values: Input values as keyword arguments.

        Returns:
            The result of evaluating the transient flow.

        Raises:
            ValueError: If the type of an input value cannot be encoded.
        """
        full_name = get_flow_full_name(self._name)
        validate_full_flow_name(full_name)
        builder_state = _FlowBuilderState(full_name)

        # Declare one direct input per keyword argument, remembering the slice
        # each declaration yields.
        slices_by_name: dict[str, DataSlice[Any]] = {}
        for arg_name, arg_value in input_values.items():
            value_type = type(arg_value)
            enriched = encode_enriched_type(value_type)
            if enriched is None:
                raise ValueError(
                    f"Input value `{arg_name}` has unsupported type {value_type}"
                )

            engine_slice = builder_state.engine_flow_builder.add_direct_input(
                arg_name, dump_engine_object(enriched)
            )
            slices_by_name[arg_name] = DataSlice(
                _DataSliceState(builder_state, engine_slice)
            )

        # The root scope is created only after all inputs have been declared.
        scope = DataScope(
            builder_state, builder_state.engine_flow_builder.root_scope()
        )

        # Make every input reachable from the flow definition by its name.
        for arg_name, arg_slice in slices_by_name.items():
            scope[arg_name] = arg_slice

        # Run the user's definition; it is expected to set the flow output.
        self._fl_def(FlowBuilder(builder_state), scope)

        engine_flow = (
            await builder_state.engine_flow_builder.build_transient_flow_async(
                execution_context.event_loop
            )
        )

        # Values go to the engine positionally, in keyword-argument order,
        # which is the order the inputs were declared in above.
        positional_args = list(input_values.values())
        return await engine_flow.evaluate_async(positional_args)


def _create_lazy_flow(
name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]
name: str | None,
fl_def: Callable[[FlowBuilder, DataScope], None],
*,
transient: bool = False,
) -> Flow:
"""
Create a flow without really building it yet.
The flow will be built the first time when it's really needed.

Args:
name: The name of the flow.
fl_def: The flow definition function.
transient: If True, creates a transient flow that doesn't maintain state.
"""
flow_name = _flow_name_builder.build_name(name, prefix="_flow_")

if transient:
return TransientFlowWrapper(flow_name, fl_def)

def _create_engine_flow() -> _engine.Flow:
flow_full_name = get_flow_full_name(flow_name)
validate_full_flow_name(flow_full_name)
Expand All @@ -944,14 +1122,25 @@ def get_flow_full_name(name: str) -> str:
return f"{setting.get_app_namespace(trailing_delimiter='.')}{name}"


def open_flow(
    name: str,
    fl_def: Callable[[FlowBuilder, DataScope], None],
    *,
    transient: bool = False,
) -> Flow:
    """
    Open a flow, with the given name and definition.

    Args:
        name: The name of the flow.
        fl_def: The flow definition function.
        transient: If True, creates a transient flow that doesn't maintain state
            and doesn't support live updates. Default is False.

    Returns:
        The newly created flow, which is also registered under `name`.

    Raises:
        KeyError: If a flow with the same name is already registered.
    """
    # Check and insert under the same lock so two callers cannot both register
    # the same name.
    with _flows_lock:
        if name in _flows:
            raise KeyError(f"Flow with name {name} already exists")
        fl = _flows[name] = _create_lazy_flow(name, fl_def, transient=transient)
    return fl


Expand All @@ -971,11 +1160,19 @@ def remove_flow(fl: Flow) -> None:

def flow_def(
    name: str | None = None,
    transient: bool = False,
) -> Callable[[Callable[[FlowBuilder, DataScope], None]], Flow]:
    """
    A decorator to wrap the flow definition.

    Args:
        name: The name of the flow. If None, uses the function name.
        transient: If True, creates a transient flow that doesn't maintain state
            and doesn't support live updates. Default is False.

    Returns:
        A decorator that opens a flow for the decorated definition function
        and returns that flow in place of the function.
    """

    def _register(fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
        # An empty or missing name falls back to the decorated function's name.
        return open_flow(name or fl_def.__name__, fl_def, transient=transient)

    return _register


def flow_names() -> list[str]:
Expand Down