Skip to content

Commit ec214e3

Browse files
authored
Refactor logic to parse Python type - to support dataclass type binding. #19 (#69)
Refactor logic to parse Python type - to support dataclass type binding.
1 parent 3353fc3 commit ec214e3

File tree

3 files changed

+115
-87
lines changed

3 files changed

+115
-87
lines changed

python/cocoindex/flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import _engine
1313
from . import vector
1414
from . import op
15-
from .typing import dump_type
15+
from .typing import encode_type
1616

1717
class _NameBuilder:
1818
_existing_names: set[str]
@@ -419,7 +419,7 @@ def __init__(
419419
inspect.Parameter.KEYWORD_ONLY):
420420
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
421421
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
422-
param_name, dump_type(param_type))
422+
param_name, encode_type(param_type))
423423
kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))
424424

425425
output = flow_fn(**kwargs)

python/cocoindex/op.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from enum import Enum
99
from threading import Lock
1010

11-
from .typing import dump_type
11+
from .typing import encode_type
1212
from . import _engine
1313

1414

@@ -57,14 +57,14 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
5757
spec = self._spec_cls(**spec)
5858
executor = self._executor_cls(spec)
5959
result_type = executor.analyze(*args, **kwargs)
60-
return (dump_type(result_type), executor)
60+
return (encode_type(result_type), executor)
6161

62-
def to_engine_value(value: Any) -> Any:
62+
def _to_engine_value(value: Any) -> Any:
6363
"""Convert a Python value to an engine value."""
6464
if dataclasses.is_dataclass(value):
65-
return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
65+
return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
6666
elif isinstance(value, list) or isinstance(value, tuple):
67-
return [to_engine_value(v) for v in value]
67+
return [_to_engine_value(v) for v in value]
6868
return value
6969

7070
_gpu_dispatch_lock = Lock()
@@ -123,7 +123,7 @@ def analyze(self, *args, **kwargs):
123123
if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD:
124124
raise ValueError(f"Too many positional arguments: {len(args)} > {next_param_idx}")
125125
if arg_param.annotation is not inspect.Parameter.empty:
126-
arg.validate_arg(arg_name, dump_type(arg_param.annotation))
126+
arg.validate_arg(arg_name, encode_type(arg_param.annotation))
127127
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
128128
next_param_idx += 1
129129

@@ -139,7 +139,7 @@ def analyze(self, *args, **kwargs):
139139
raise ValueError(f"Unexpected keyword argument: {kwarg_name}")
140140
arg_param = expected_arg[1]
141141
if arg_param.annotation is not inspect.Parameter.empty:
142-
kwarg.validate_arg(kwarg_name, dump_type(arg_param.annotation))
142+
kwarg.validate_arg(kwarg_name, encode_type(arg_param.annotation))
143143

144144
missing_args = [name for (name, arg) in expected_kwargs
145145
if arg.default is inspect.Parameter.empty
@@ -173,7 +173,7 @@ def __call__(self, *args, **kwargs):
173173
output = super().__call__(*args, **kwargs)
174174
else:
175175
output = super().__call__(*args, **kwargs)
176-
return to_engine_value(output)
176+
return _to_engine_value(output)
177177

178178
_WrappedClass.__name__ = cls.__name__
179179

python/cocoindex/typing.py

Lines changed: 105 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def __init__(self, key: str, value: Any):
1616
self.key = key
1717
self.value = value
1818

19+
Annotation = Vector | TypeKind | TypeAttr
20+
1921
Float32 = Annotated[float, TypeKind('Float32')]
2022
Float64 = Annotated[float, TypeKind('Float64')]
2123
Range = Annotated[tuple[int, int], TypeKind('Range')]
@@ -43,95 +45,121 @@ class List: # type: ignore[unreachable]
4345
def __class_getitem__(cls, item: type[R]):
4446
return Annotated[list[item], TypeKind('List')]
4547

46-
def _find_annotation(metadata, cls):
47-
for m in iter(metadata):
48-
if isinstance(m, cls):
49-
return m
50-
return None
48+
def _dump_field_schema(field: dataclasses.Field) -> dict[str, Any]:
49+
encoded = _encode_enriched_type(field.type)
50+
encoded['name'] = field.name
51+
return encoded
52+
@dataclasses.dataclass
53+
class AnalyzedTypeInfo:
54+
"""
55+
Analyzed info of a Python type.
56+
"""
57+
kind: str
58+
vector_info: Vector | None
59+
elem_type: type | None
60+
struct_fields: tuple[dataclasses.Field, ...] | None
61+
attrs: dict[str, Any] | None
5162

52-
def _get_origin_type_and_metadata(t):
63+
def analyze_type_info(t) -> AnalyzedTypeInfo:
64+
"""
65+
Analyze a Python type and return the analyzed info.
66+
"""
67+
annotations: tuple[Annotation, ...] = ()
5368
if typing.get_origin(t) is Annotated:
54-
return (t.__origin__, t.__metadata__)
55-
return (t, ())
56-
57-
def _dump_fields_schema(cls: type) -> list[dict[str, Any]]:
58-
return [
59-
{
60-
'name': field.name,
61-
**_dump_enriched_type(field.type)
62-
}
63-
for field in dataclasses.fields(cls)
64-
]
65-
66-
def _dump_type(t, metadata):
67-
origin_type = typing.get_origin(t)
68-
type_kind = _find_annotation(metadata, TypeKind)
69-
if origin_type is collections.abc.Sequence or origin_type is list:
70-
args = typing.get_args(t)
71-
elem_type, elem_type_metadata = _get_origin_type_and_metadata(args[0])
72-
vector_annot = _find_annotation(metadata, Vector)
73-
if vector_annot is not None:
74-
encoded_type = {
75-
'kind': 'Vector',
76-
'element_type': _dump_type(elem_type, elem_type_metadata),
77-
'dimension': vector_annot.dim,
78-
}
79-
elif dataclasses.is_dataclass(elem_type):
80-
if type_kind is not None and type_kind.kind == 'Table':
81-
encoded_type = {
82-
'kind': 'Table',
83-
'row': { 'fields': _dump_fields_schema(elem_type) },
84-
}
85-
else:
86-
encoded_type = {
87-
'kind': 'List',
88-
'row': { 'fields': _dump_fields_schema(elem_type) },
89-
}
90-
else:
91-
raise ValueError(f"Unsupported type: {t}")
92-
elif dataclasses.is_dataclass(t):
93-
encoded_type = {
94-
'kind': 'Struct',
95-
'fields': _dump_fields_schema(t),
96-
}
97-
else:
98-
if type_kind is not None:
99-
kind = type_kind.kind
100-
else:
101-
if t is bytes:
102-
kind = 'Bytes'
103-
elif t is str:
104-
kind = 'Str'
105-
elif t is bool:
106-
kind = 'Bool'
107-
elif t is int:
108-
kind = 'Int64'
109-
elif t is float:
110-
kind = 'Float64'
111-
else:
112-
raise ValueError(f"type unsupported yet: {t}")
113-
encoded_type = { 'kind': kind }
114-
115-
return encoded_type
69+
annotations = t.__metadata__
70+
t = t.__origin__
71+
base_type = typing.get_origin(t)
11672

117-
def _dump_enriched_type(t) -> dict[str, Any]:
118-
t, metadata = _get_origin_type_and_metadata(t)
119-
enriched_type_json = {'type': _dump_type(t, metadata)}
12073
attrs = None
121-
for attr in metadata:
74+
vector_info = None
75+
kind = None
76+
for attr in annotations:
12277
if isinstance(attr, TypeAttr):
12378
if attrs is None:
12479
attrs = dict()
12580
attrs[attr.key] = attr.value
126-
if attrs is not None:
127-
enriched_type_json['attrs'] = attrs
128-
return enriched_type_json
81+
elif isinstance(attr, Vector):
82+
vector_info = attr
83+
elif isinstance(attr, TypeKind):
84+
kind = attr.kind
85+
86+
struct_fields = None
87+
elem_type = None
88+
if dataclasses.is_dataclass(t):
89+
if kind is None:
90+
kind = 'Struct'
91+
elif kind != 'Struct':
92+
raise ValueError(f"Unexpected type kind for struct: {kind}")
93+
struct_fields = dataclasses.fields(t)
94+
elif base_type is collections.abc.Sequence or base_type is list:
95+
if kind is None:
96+
kind = 'Vector' if vector_info is not None else 'List'
97+
elif kind not in ('Vector', 'List', 'Table'):
98+
raise ValueError(f"Unexpected type kind for list: {kind}")
99+
100+
args = typing.get_args(t)
101+
if len(args) != 1:
102+
raise ValueError(f"{kind} must have exactly one type argument")
103+
elem_type = args[0]
104+
elif kind is None:
105+
if base_type is collections.abc.Sequence or base_type is list:
106+
kind = 'Vector' if vector_info is not None else 'List'
107+
elif t is bytes:
108+
kind = 'Bytes'
109+
elif t is str:
110+
kind = 'Str'
111+
elif t is bool:
112+
kind = 'Bool'
113+
elif t is int:
114+
kind = 'Int64'
115+
elif t is float:
116+
kind = 'Float64'
117+
else:
118+
raise ValueError(f"type unsupported yet: {base_type}")
119+
120+
return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
121+
struct_fields=struct_fields, attrs=attrs)
122+
123+
def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
124+
encoded_type: dict[str, Any] = { 'kind': type_info.kind }
125+
126+
if type_info.kind == 'Struct':
127+
if type_info.struct_fields is None:
128+
raise ValueError("Struct type must have a struct fields")
129+
encoded_type['fields'] = [_dump_field_schema(field) for field in type_info.struct_fields]
130+
131+
elif type_info.kind == 'Vector':
132+
if type_info.vector_info is None:
133+
raise ValueError("Vector type must have a vector info")
134+
if type_info.elem_type is None:
135+
raise ValueError("Vector type must have an element type")
136+
encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
137+
encoded_type['dimension'] = type_info.vector_info.dim
138+
139+
elif type_info.kind in ('List', 'Table'):
140+
if type_info.elem_type is None:
141+
raise ValueError(f"{type_info.kind} type must have an element type")
142+
row_type_inof = analyze_type_info(type_info.elem_type)
143+
if row_type_inof.struct_fields is None:
144+
raise ValueError(f"{type_info.kind} type must have a struct fields")
145+
encoded_type['row'] = {
146+
'fields': [_dump_field_schema(field) for field in row_type_inof.struct_fields],
147+
}
148+
149+
return encoded_type
150+
151+
def _encode_enriched_type(t) -> dict[str, Any]:
152+
enriched_type_info = analyze_type_info(t)
153+
encoded = {'type': _encode_type(enriched_type_info)}
154+
if enriched_type_info.attrs is not None:
155+
encoded['attrs'] = enriched_type_info.attrs
156+
return encoded
129157

130158

131-
def dump_type(t) -> dict[str, Any] | None:
159+
def encode_type(t) -> dict[str, Any] | None:
132160
"""
133161
Convert a Python type to a CocoIndex's type in JSON.
134162
"""
135163
if t is None:
136164
return None
137-
return _dump_enriched_type(t)
165+
return _encode_enriched_type(t)

0 commit comments

Comments
 (0)