diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 99c8ed080..fa0058837 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -12,7 +12,7 @@ from . import _engine from . import vector from . import op -from .typing import dump_type +from .typing import encode_type class _NameBuilder: _existing_names: set[str] @@ -419,7 +419,7 @@ def __init__( inspect.Parameter.KEYWORD_ONLY): raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name") engine_ds = flow_builder_state.engine_flow_builder.add_direct_input( - param_name, dump_type(param_type)) + param_name, encode_type(param_type)) kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds)) output = flow_fn(**kwargs) diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 9afb6ba65..0a6df995b 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -8,7 +8,7 @@ from enum import Enum from threading import Lock -from .typing import dump_type +from .typing import encode_type from . 
import _engine @@ -57,14 +57,14 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs): spec = self._spec_cls(**spec) executor = self._executor_cls(spec) result_type = executor.analyze(*args, **kwargs) - return (dump_type(result_type), executor) + return (encode_type(result_type), executor) -def to_engine_value(value: Any) -> Any: +def _to_engine_value(value: Any) -> Any: """Convert a Python value to an engine value.""" if dataclasses.is_dataclass(value): - return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)] + return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)] elif isinstance(value, list) or isinstance(value, tuple): - return [to_engine_value(v) for v in value] + return [_to_engine_value(v) for v in value] return value _gpu_dispatch_lock = Lock() @@ -123,7 +123,7 @@ def analyze(self, *args, **kwargs): if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD: raise ValueError(f"Too many positional arguments: {len(args)} > {next_param_idx}") if arg_param.annotation is not inspect.Parameter.empty: - arg.validate_arg(arg_name, dump_type(arg_param.annotation)) + arg.validate_arg(arg_name, encode_type(arg_param.annotation)) if arg_param.kind != inspect.Parameter.VAR_POSITIONAL: next_param_idx += 1 @@ -139,7 +139,7 @@ def analyze(self, *args, **kwargs): raise ValueError(f"Unexpected keyword argument: {kwarg_name}") arg_param = expected_arg[1] if arg_param.annotation is not inspect.Parameter.empty: - kwarg.validate_arg(kwarg_name, dump_type(arg_param.annotation)) + kwarg.validate_arg(kwarg_name, encode_type(arg_param.annotation)) missing_args = [name for (name, arg) in expected_kwargs if arg.default is inspect.Parameter.empty @@ -173,7 +173,7 @@ def __call__(self, *args, **kwargs): output = super().__call__(*args, **kwargs) else: output = super().__call__(*args, **kwargs) - return to_engine_value(output) + return _to_engine_value(output) 
# Annotations recognized inside typing.Annotated[...] metadata.
Annotation = Vector | TypeKind | TypeAttr

def _encode_field_schema(field: dataclasses.Field) -> dict[str, Any]:
    """Encode one dataclass field as a named field-schema entry."""
    encoded = _encode_enriched_type(field.type)
    encoded['name'] = field.name
    return encoded

@dataclasses.dataclass
class AnalyzedTypeInfo:
    """
    Analyzed info of a Python type.
    """
    kind: str                  # CocoIndex type kind, e.g. 'Struct', 'Vector', 'Int64'
    vector_info: Vector | None  # Vector annotation (carries the dimension), if any
    elem_type: type | None      # element type for Vector/List/Table kinds
    struct_fields: tuple[dataclasses.Field, ...] | None  # dataclass fields for Struct
    attrs: dict[str, Any] | None  # key/value pairs collected from TypeAttr annotations

def analyze_type_info(t) -> AnalyzedTypeInfo:
    """
    Analyze a Python type and return the analyzed info.

    Unwraps typing.Annotated metadata (TypeAttr / Vector / TypeKind), then
    classifies `t` as a struct (dataclass), a sequence (Vector/List/Table),
    or a scalar.

    Raises ValueError for contradictory annotations or unsupported types.
    """
    annotations: tuple[Annotation, ...] = ()
    if typing.get_origin(t) is Annotated:
        annotations = t.__metadata__
        t = t.__origin__
    base_type = typing.get_origin(t)

    attrs = None
    vector_info = None
    kind = None
    for attr in annotations:
        if isinstance(attr, TypeAttr):
            if attrs is None:
                attrs = {}
            attrs[attr.key] = attr.value
        elif isinstance(attr, Vector):
            vector_info = attr
        elif isinstance(attr, TypeKind):
            kind = attr.kind

    struct_fields = None
    elem_type = None
    if dataclasses.is_dataclass(t):
        if kind is None:
            kind = 'Struct'
        elif kind != 'Struct':
            raise ValueError(f"Unexpected type kind for struct: {kind}")
        struct_fields = dataclasses.fields(t)
    elif base_type is collections.abc.Sequence or base_type is list:
        if kind is None:
            # A Vector annotation upgrades a plain sequence to a Vector.
            kind = 'Vector' if vector_info is not None else 'List'
        elif kind not in ('Vector', 'List', 'Table'):
            raise ValueError(f"Unexpected type kind for list: {kind}")

        args = typing.get_args(t)
        if len(args) != 1:
            raise ValueError(f"{kind} must have exactly one type argument")
        elem_type = args[0]
    elif kind is None:
        if t is bytes:
            kind = 'Bytes'
        elif t is str:
            kind = 'Str'
        elif t is bool:
            kind = 'Bool'
        elif t is int:
            kind = 'Int64'
        elif t is float:
            kind = 'Float64'
        else:
            # Report `t` itself: typing.get_origin(t) is None for plain
            # (non-generic) unsupported types, which made the old message useless.
            raise ValueError(f"type unsupported yet: {t}")

    return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
                            struct_fields=struct_fields, attrs=attrs)

def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
    """Encode analyzed type info into CocoIndex's JSON type representation."""
    encoded_type: dict[str, Any] = { 'kind': type_info.kind }

    if type_info.kind == 'Struct':
        if type_info.struct_fields is None:
            raise ValueError("Struct type must have a struct fields")
        encoded_type['fields'] = [_encode_field_schema(field)
                                  for field in type_info.struct_fields]

    elif type_info.kind == 'Vector':
        if type_info.vector_info is None:
            raise ValueError("Vector type must have a vector info")
        if type_info.elem_type is None:
            raise ValueError("Vector type must have an element type")
        encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
        encoded_type['dimension'] = type_info.vector_info.dim

    elif type_info.kind in ('List', 'Table'):
        if type_info.elem_type is None:
            raise ValueError(f"{type_info.kind} type must have an element type")
        # The element must itself be a struct: it provides the row schema.
        row_type_info = analyze_type_info(type_info.elem_type)
        if row_type_info.struct_fields is None:
            raise ValueError(f"{type_info.kind} type must have a struct fields")
        encoded_type['row'] = {
            'fields': [_encode_field_schema(field) for field in row_type_info.struct_fields],
        }

    return encoded_type

def _encode_enriched_type(t) -> dict[str, Any]:
    """Encode a (possibly Annotated) Python type together with its attrs."""
    enriched_type_info = analyze_type_info(t)
    encoded = {'type': _encode_type(enriched_type_info)}
    if enriched_type_info.attrs is not None:
        encoded['attrs'] = enriched_type_info.attrs
    return encoded

def encode_type(t) -> dict[str, Any] | None:
    """
    Convert a Python type to a CocoIndex's type in JSON.
    """
    if t is None:
        return None
    return _encode_enriched_type(t)