Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 60 additions & 29 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,25 @@
from __future__ import annotations

import asyncio
import re
import inspect
import datetime
import functools

import inspect
import re
from dataclasses import dataclass
from enum import Enum
from threading import Lock
from typing import (
Any,
Callable,
Generic,
NamedTuple,
Sequence,
TypeVar,
Generic,
cast,
get_args,
get_origin,
NamedTuple,
cast,
)
from threading import Lock
from enum import Enum
from dataclasses import dataclass

from rich.text import Text
from rich.tree import Tree

Expand All @@ -32,9 +32,10 @@
from . import op
from . import setting
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
from .typing import encode_enriched_type
from .op import FunctionSpec
from .runtime import execution_context
from .setup import make_setup_bundle, make_drop_bundle
from .setup import make_drop_bundle, make_setup_bundle
from .typing import encode_enriched_type


class _NameBuilder:
Expand Down Expand Up @@ -92,6 +93,30 @@ def _spec_kind(spec: Any) -> str:
return cast(str, spec.__class__.__name__)


def _transform_helper(
flow_builder_state: _FlowBuilderState,
fn_spec: FunctionSpec,
transform_args: list[tuple[Any, str | None]],
name: str | None = None,
) -> DataSlice[Any]:
if not isinstance(fn_spec, FunctionSpec):
raise ValueError("transform() can only be called on a CocoIndex function")

return _create_data_slice(
flow_builder_state,
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
_spec_kind(fn_spec),
dump_engine_object(fn_spec),
transform_args,
target_scope,
flow_builder_state.field_name_builder.build_name(
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
),
),
name,
)


T = TypeVar("T")
S = TypeVar("S")

Expand Down Expand Up @@ -191,31 +216,19 @@ def transform(
"""
Apply a function to the data slice.
"""
if not isinstance(fn_spec, op.FunctionSpec):
raise ValueError("transform() can only be called on a CocoIndex function")

transform_args: list[tuple[Any, str | None]]
transform_args = [(self._state.engine_data_slice, None)]
transform_args: list[tuple[Any, str | None]] = [
(self._state.engine_data_slice, None)
]
transform_args += [
(self._state.flow_builder_state.get_data_slice(v), None) for v in args
]
transform_args += [
(self._state.flow_builder_state.get_data_slice(v), k)
for (k, v) in kwargs.items()
for k, v in kwargs.items()
]

flow_builder_state = self._state.flow_builder_state
return _create_data_slice(
flow_builder_state,
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
_spec_kind(fn_spec),
dump_engine_object(fn_spec),
transform_args,
target_scope,
flow_builder_state.field_name_builder.build_name(
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
),
),
return _transform_helper(
self._state.flow_builder_state, fn_spec, transform_args
)

def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S:
Expand Down Expand Up @@ -446,6 +459,24 @@ def add_source(
name,
)

def transform(
self, fn_spec: FunctionSpec, *args: Any, **kwargs: Any
) -> DataSlice[Any]:
"""
Apply a function to inputs, returning a DataSlice.
"""
transform_args: list[tuple[Any, str | None]] = [
(self._state.get_data_slice(v), None) for v in args
]
transform_args += [
(self._state.get_data_slice(v), k) for k, v in kwargs.items()
]

if not transform_args:
raise ValueError("At least one input is required for transformation")

return _transform_helper(self._state, fn_spec, transform_args)

def declare(self, spec: op.DeclarationSpec) -> None:
"""
Add a declaration to the flow.
Expand Down
7 changes: 3 additions & 4 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import asyncio
import dataclasses
import inspect

from typing import Protocol, Any, Callable, Awaitable, dataclass_transform
from enum import Enum
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform

from .typing import encode_enriched_type, resolve_forward_ref
from .convert import encode_engine_value, make_engine_value_decoder
from . import _engine # type: ignore
from .convert import encode_engine_value, make_engine_value_decoder
from .typing import encode_enriched_type, resolve_forward_ref


class OpCategory(Enum):
Expand Down