Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import _engine
from . import vector
from . import op
from .typing import dump_type
from .typing import encode_type

class _NameBuilder:
_existing_names: set[str]
Expand Down Expand Up @@ -419,7 +419,7 @@ def __init__(
inspect.Parameter.KEYWORD_ONLY):
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
param_name, dump_type(param_type))
param_name, encode_type(param_type))
kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))

output = flow_fn(**kwargs)
Expand Down
16 changes: 8 additions & 8 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import Enum
from threading import Lock

from .typing import dump_type
from .typing import encode_type
from . import _engine


Expand Down Expand Up @@ -57,14 +57,14 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
spec = self._spec_cls(**spec)
executor = self._executor_cls(spec)
result_type = executor.analyze(*args, **kwargs)
return (dump_type(result_type), executor)
return (encode_type(result_type), executor)

def to_engine_value(value: Any) -> Any:
def _to_engine_value(value: Any) -> Any:
"""Convert a Python value to an engine value."""
if dataclasses.is_dataclass(value):
return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
elif isinstance(value, list) or isinstance(value, tuple):
return [to_engine_value(v) for v in value]
return [_to_engine_value(v) for v in value]
return value

_gpu_dispatch_lock = Lock()
Expand Down Expand Up @@ -123,7 +123,7 @@ def analyze(self, *args, **kwargs):
if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD:
raise ValueError(f"Too many positional arguments: {len(args)} > {next_param_idx}")
if arg_param.annotation is not inspect.Parameter.empty:
arg.validate_arg(arg_name, dump_type(arg_param.annotation))
arg.validate_arg(arg_name, encode_type(arg_param.annotation))
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
next_param_idx += 1

Expand All @@ -139,7 +139,7 @@ def analyze(self, *args, **kwargs):
raise ValueError(f"Unexpected keyword argument: {kwarg_name}")
arg_param = expected_arg[1]
if arg_param.annotation is not inspect.Parameter.empty:
kwarg.validate_arg(kwarg_name, dump_type(arg_param.annotation))
kwarg.validate_arg(kwarg_name, encode_type(arg_param.annotation))

missing_args = [name for (name, arg) in expected_kwargs
if arg.default is inspect.Parameter.empty
Expand Down Expand Up @@ -173,7 +173,7 @@ def __call__(self, *args, **kwargs):
output = super().__call__(*args, **kwargs)
else:
output = super().__call__(*args, **kwargs)
return to_engine_value(output)
return _to_engine_value(output)

_WrappedClass.__name__ = cls.__name__

Expand Down
182 changes: 105 additions & 77 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def __init__(self, key: str, value: Any):
self.key = key
self.value = value

Annotation = Vector | TypeKind | TypeAttr

Float32 = Annotated[float, TypeKind('Float32')]
Float64 = Annotated[float, TypeKind('Float64')]
Range = Annotated[tuple[int, int], TypeKind('Range')]
Expand Down Expand Up @@ -43,95 +45,121 @@ class List: # type: ignore[unreachable]
def __class_getitem__(cls, item: type[R]):
return Annotated[list[item], TypeKind('List')]

def _find_annotation(metadata, cls):
for m in iter(metadata):
if isinstance(m, cls):
return m
return None
def _dump_field_schema(field: dataclasses.Field) -> dict[str, Any]:
    """Encode one dataclass field as its enriched-type JSON with the field name attached."""
    schema = _encode_enriched_type(field.type)
    schema['name'] = field.name
    return schema
@dataclasses.dataclass
class AnalyzedTypeInfo:
    """
    Analyzed info of a Python type.
    """
    # Type kind name, e.g. 'Struct', 'Vector', 'List', 'Table', 'Str', 'Int64'.
    kind: str
    # The Vector annotation (carries dimension), when the type was annotated as a vector.
    vector_info: Vector | None
    # Element type for sequence-like kinds (Vector/List/Table); None otherwise.
    elem_type: type | None
    # Dataclass fields when kind is 'Struct'; None otherwise.
    struct_fields: tuple[dataclasses.Field, ...] | None
    # Extra attributes collected from TypeAttr annotations, if any.
    attrs: dict[str, Any] | None

def _get_origin_type_and_metadata(t):
def analyze_type_info(t) -> AnalyzedTypeInfo:
    """
    Analyze a Python type and return the analyzed info.

    Args:
        t: A Python type, optionally wrapped in `typing.Annotated` with
           `Vector`, `TypeKind`, and/or `TypeAttr` annotations.

    Returns:
        An `AnalyzedTypeInfo` describing the type's kind, vector info,
        element type, struct fields, and collected attributes.

    Raises:
        ValueError: If the type is unsupported, or an explicit `TypeKind`
            annotation conflicts with the underlying Python type.
    """
    annotations: tuple[Annotation, ...] = ()
    if typing.get_origin(t) is Annotated:
        annotations = t.__metadata__
        t = t.__origin__
    base_type = typing.get_origin(t)

    attrs: dict[str, Any] | None = None
    vector_info: Vector | None = None
    kind: str | None = None
    for annotation in annotations:
        if isinstance(annotation, TypeAttr):
            if attrs is None:
                attrs = dict()
            attrs[annotation.key] = annotation.value
        elif isinstance(annotation, Vector):
            vector_info = annotation
        elif isinstance(annotation, TypeKind):
            kind = annotation.kind

    struct_fields = None
    elem_type = None
    if dataclasses.is_dataclass(t):
        # Dataclasses map to Struct; an explicit TypeKind must agree.
        if kind is None:
            kind = 'Struct'
        elif kind != 'Struct':
            raise ValueError(f"Unexpected type kind for struct: {kind}")
        struct_fields = dataclasses.fields(t)
    elif base_type is collections.abc.Sequence or base_type is list:
        # Sequences default to Vector (when a Vector annotation is present)
        # or List; an explicit TypeKind may also select Table.
        if kind is None:
            kind = 'Vector' if vector_info is not None else 'List'
        elif kind not in ('Vector', 'List', 'Table'):
            raise ValueError(f"Unexpected type kind for list: {kind}")

        args = typing.get_args(t)
        if len(args) != 1:
            raise ValueError(f"{kind} must have exactly one type argument")
        elem_type = args[0]
    elif kind is None:
        if t is bytes:
            kind = 'Bytes'
        elif t is str:
            kind = 'Str'
        elif t is bool:
            kind = 'Bool'
        elif t is int:
            kind = 'Int64'
        elif t is float:
            kind = 'Float64'
        else:
            # Report the original type: `base_type` (typing.get_origin) is
            # None for plain classes, which made the old message useless.
            raise ValueError(f"type unsupported yet: {t}")

    return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
                            struct_fields=struct_fields, attrs=attrs)

def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
    """
    Encode an `AnalyzedTypeInfo` into CocoIndex's JSON type representation.

    Raises:
        ValueError: If the info required by the kind (struct fields, vector
            info, element type) is missing.
    """
    encoded_type: dict[str, Any] = { 'kind': type_info.kind }

    if type_info.kind == 'Struct':
        if type_info.struct_fields is None:
            raise ValueError("Struct type must have a struct fields")
        encoded_type['fields'] = [_dump_field_schema(field) for field in type_info.struct_fields]

    elif type_info.kind == 'Vector':
        if type_info.vector_info is None:
            raise ValueError("Vector type must have a vector info")
        if type_info.elem_type is None:
            raise ValueError("Vector type must have an element type")
        encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
        encoded_type['dimension'] = type_info.vector_info.dim

    elif type_info.kind in ('List', 'Table'):
        if type_info.elem_type is None:
            raise ValueError(f"{type_info.kind} type must have an element type")
        # Rows of a List/Table must be structs; fixed misspelled local
        # (`row_type_inof` -> `row_type_info`).
        row_type_info = analyze_type_info(type_info.elem_type)
        if row_type_info.struct_fields is None:
            raise ValueError(f"{type_info.kind} type must have a struct fields")
        encoded_type['row'] = {
            'fields': [_dump_field_schema(field) for field in row_type_info.struct_fields],
        }

    return encoded_type

def _encode_enriched_type(t) -> dict[str, Any]:
    """Encode a Python type, including its attributes, into the enriched-type JSON dict."""
    type_info = analyze_type_info(t)
    result: dict[str, Any] = {'type': _encode_type(type_info)}
    if type_info.attrs is not None:
        result['attrs'] = type_info.attrs
    return result


def dump_type(t) -> dict[str, Any] | None:
def encode_type(t) -> dict[str, Any] | None:
"""
Convert a Python type to a CocoIndex's type in JSON.
"""
if t is None:
return None
return _dump_enriched_type(t)
return _encode_enriched_type(t)