Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions python/cocoindex/functions.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
"""All builtin functions."""

from typing import Annotated, Any, TYPE_CHECKING, Literal
import dataclasses
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeVar

import numpy as np
from numpy.typing import NDArray
import dataclasses

from .typing import Float32, Vector, TypeAttr
from . import op, llm
from . import llm, op
from .flow import DataSlice
from .typing import TypeAttr, Vector

# Libraries that are heavy to import. Lazily import them later.
if TYPE_CHECKING:
import sentence_transformers

T = TypeVar("T")


class ParseJson(op.FunctionSpec):
    """Function spec that parses a text into a JSON object.

    Calling an instance forwards ``text`` and ``language`` by keyword to
    ``op.FunctionSpec.__call__`` and returns whatever that call returns.
    """

    def __call__(
        self, *, text: DataSlice[T], language: str | None = "json"
    ) -> DataSlice[T]:
        """Apply this spec to ``text``; ``language`` defaults to ``"json"``."""
        # Insertion order keeps ``text`` ahead of ``language`` for the parent.
        forwarded = {"text": text, "language": language}
        return super().__call__(**forwarded)


@dataclasses.dataclass
class CustomLanguageSpec:
Expand All @@ -31,6 +40,23 @@ class SplitRecursively(op.FunctionSpec):

custom_languages: list[CustomLanguageSpec] = dataclasses.field(default_factory=list)

def __call__(
    self,
    *,
    text: DataSlice[T],
    chunk_size: int,
    min_chunk_size: int | None = None,
    chunk_overlap: int | None = None,
    language: DataSlice[T] | None = None,
) -> DataSlice[T]:
    """Forward the splitting arguments, by keyword, to the parent ``__call__``.

    All parameters are keyword-only. ``text`` and ``chunk_size`` are
    required; ``min_chunk_size``, ``chunk_overlap`` and ``language`` default
    to ``None`` and are passed through unchanged.
    """
    # The parent receives the keywords in exactly this insertion order.
    call_kwargs = {
        "text": text,
        "chunk_size": chunk_size,
        "language": language,
        "min_chunk_size": min_chunk_size,
        "chunk_overlap": chunk_overlap,
    }
    return super().__call__(**call_kwargs)


class EmbedText(op.FunctionSpec):
"""Embed a text into a vector space."""
Expand All @@ -49,6 +75,11 @@ class ExtractByLlm(op.FunctionSpec):
output_type: type
instruction: str | None = None

def __call__(
    self, *, text: DataSlice[T] | None = None, image: DataSlice[T] | None = None
) -> DataSlice[T]:
    """Forward the optional ``text`` and ``image`` inputs to the parent ``__call__``.

    Both parameters are keyword-only, default to ``None``, and are passed
    through unchanged.
    """
    inputs = {"text": text, "image": image}
    return super().__call__(**inputs)


class SentenceTransformerEmbed(op.FunctionSpec):
"""
Expand Down
57 changes: 50 additions & 7 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import asyncio
import dataclasses
import inspect

from typing import Protocol, Any, Callable, Awaitable, dataclass_transform
from enum import Enum
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform

from .typing import encode_enriched_type, resolve_forward_ref
from .convert import encode_engine_value, make_engine_value_decoder
from . import _engine # type: ignore
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
from .flow import DataSlice, _create_data_slice, _spec_kind, _to_snake_case
from .typing import encode_enriched_type, resolve_forward_ref


class OpCategory(Enum):
Expand Down Expand Up @@ -49,7 +49,51 @@ class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: dis


class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION):  # pylint: disable=too-few-public-methods
    """A function spec.

    All its subclasses can be instantiated similar to a dataclass, i.e.
    ClassName(field1=value1, field2=value2, ...), and an instance can be
    called like a function: spec(args...).

    For non-chain-style calls, use spec(args...) with at least one DataSlice argument.
    For chain-style calls, use data_slice.transform(spec, args...).
    """

    def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[Any]:
        """Apply this spec to the given arguments, returning a DataSlice.

        Args:
            *args: Positional inputs, forwarded without a name. ``None``
                entries are dropped.
            **kwargs: Named inputs, forwarded under their keyword. ``None``
                entries are dropped.

        Returns:
            The DataSlice produced by ``_create_data_slice`` for this
            transform.

        Raises:
            ValueError: If none of the arguments is a DataSlice, since the
                flow context is taken from the first one found.
        """
        # Positional arguments are scanned before keyword ones, so a
        # positional DataSlice wins when both kinds are present.
        first_data_slice = next(
            (v for v in (*args, *kwargs.values()) if isinstance(v, DataSlice)),
            None,
        )
        if first_data_slice is None:
            raise ValueError(
                "At least one DataSlice argument is required to provide flow context"
            )
        flow_builder_state = first_data_slice._state.flow_builder_state

        # Each entry is (value, keyword-or-None); ``None`` arguments are omitted.
        transform_args: list[tuple[Any, str | None]] = [
            (flow_builder_state.get_data_slice(v), None) for v in args if v is not None
        ]
        transform_args += [
            (flow_builder_state.get_data_slice(v), k)
            for k, v in kwargs.items()
            if v is not None
        ]

        def _build_transform(target_scope: Any, name: Any) -> Any:
            # Evaluated only when _create_data_slice invokes the builder; the
            # spec kind is looked up once and reused for the field-name prefix.
            kind = _spec_kind(self)
            return flow_builder_state.engine_flow_builder.transform(
                kind,
                dump_engine_object(self),
                transform_args,
                target_scope,
                flow_builder_state.field_name_builder.build_name(
                    name, prefix=_to_snake_case(kind) + "_"
                ),
            )

        return _create_data_slice(flow_builder_state, _build_transform)


class TargetSpec(metaclass=SpecMeta, category=OpCategory.TARGET): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -311,8 +355,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
return fn(*args, **kwargs)

class _Spec(FunctionSpec):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return fn(*args, **kwargs)
pass

_Spec.__name__ = op_name
_Spec.__doc__ = fn.__doc__
Expand Down
Loading