Skip to content

Commit b941fdd

Browse files
authored
feat: add strong-typed representation of cocoindex type in Python (#1019)
1 parent 7a01503 commit b941fdd

File tree

5 files changed

+258
-57
lines changed

5 files changed

+258
-57
lines changed

python/cocoindex/convert.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
import inspect
1010
import warnings
1111
from enum import Enum
12-
from typing import Any, Callable, Mapping, Sequence, Type, get_origin
12+
from typing import Any, Callable, Mapping, get_origin
1313

1414
import numpy as np
1515

1616
from .typing import (
17-
TABLE_TYPES,
1817
AnalyzedAnyType,
1918
AnalyzedBasicType,
2019
AnalyzedDictType,
@@ -27,6 +26,11 @@
2726
encode_enriched_type,
2827
is_namedtuple_type,
2928
is_numpy_number_type,
29+
ValueType,
30+
FieldSchema,
31+
BasicValueType,
32+
StructType,
33+
TableType,
3034
)
3135

3236

@@ -172,7 +176,7 @@ def encode_basic_value(value: Any) -> Any:
172176

173177
def make_engine_key_decoder(
174178
field_path: list[str],
175-
key_fields_schema: list[dict[str, Any]],
179+
key_fields_schema: list[FieldSchema],
176180
dst_type_info: AnalyzedTypeInfo,
177181
) -> Callable[[Any], Any]:
178182
"""
@@ -183,7 +187,7 @@ def make_engine_key_decoder(
183187
):
184188
single_key_decoder = make_engine_value_decoder(
185189
field_path,
186-
key_fields_schema[0]["type"],
190+
key_fields_schema[0].value_type.type,
187191
dst_type_info,
188192
for_key=True,
189193
)
@@ -203,7 +207,7 @@ def key_decoder(value: list[Any]) -> Any:
203207

204208
def make_engine_value_decoder(
205209
field_path: list[str],
206-
src_type: dict[str, Any],
210+
src_type: ValueType,
207211
dst_type_info: AnalyzedTypeInfo,
208212
for_key: bool = False,
209213
) -> Callable[[Any], Any]:
@@ -219,7 +223,7 @@ def make_engine_value_decoder(
219223
A decoder from an engine value to a Python value.
220224
"""
221225

222-
src_type_kind = src_type["kind"]
226+
src_type_kind = src_type.kind
223227

224228
dst_type_variant = dst_type_info.variant
225229

@@ -229,19 +233,19 @@ def make_engine_value_decoder(
229233
f"declared `{dst_type_info.core_type}`, an unsupported type"
230234
)
231235

232-
if src_type_kind == "Struct":
236+
if isinstance(src_type, StructType): # type: ignore[redundant-cast]
233237
return make_engine_struct_decoder(
234238
field_path,
235-
src_type["fields"],
239+
src_type.fields,
236240
dst_type_info,
237241
for_key=for_key,
238242
)
239243

240-
if src_type_kind in TABLE_TYPES:
244+
if isinstance(src_type, TableType): # type: ignore[redundant-cast]
241245
with ChildFieldPath(field_path, "[*]"):
242-
engine_fields_schema = src_type["row"]["fields"]
246+
engine_fields_schema = src_type.row.fields
243247

244-
if src_type_kind == "LTable":
248+
if src_type.kind == "LTable":
245249
if isinstance(dst_type_variant, AnalyzedAnyType):
246250
dst_elem_type = Any
247251
elif isinstance(dst_type_variant, AnalyzedListType):
@@ -262,7 +266,7 @@ def decode(value: Any) -> Any | None:
262266
return None
263267
return [row_decoder(v) for v in value]
264268

265-
elif src_type_kind == "KTable":
269+
elif src_type.kind == "KTable":
266270
if isinstance(dst_type_variant, AnalyzedAnyType):
267271
key_type, value_type = Any, Any
268272
elif isinstance(dst_type_variant, AnalyzedDictType):
@@ -274,7 +278,7 @@ def decode(value: Any) -> Any | None:
274278
f"declared `{dst_type_info.core_type}`, a dict type expected"
275279
)
276280

277-
num_key_parts = src_type.get("num_key_parts", 1)
281+
num_key_parts = src_type.num_key_parts or 1
278282
key_decoder = make_engine_key_decoder(
279283
field_path,
280284
engine_fields_schema[0:num_key_parts],
@@ -298,7 +302,7 @@ def decode(value: Any) -> Any | None:
298302

299303
return decode
300304

301-
if src_type_kind == "Union":
305+
if isinstance(src_type, BasicValueType) and src_type.kind == "Union":
302306
if isinstance(dst_type_variant, AnalyzedAnyType):
303307
return lambda value: value[1]
304308

@@ -307,7 +311,10 @@ def decode(value: Any) -> Any | None:
307311
if isinstance(dst_type_variant, AnalyzedUnionType)
308312
else [dst_type_info]
309313
)
310-
src_type_variants = src_type["types"]
314+
# mypy: union info exists for Union kind
315+
assert src_type.union is not None # type: ignore[unreachable]
316+
src_type_variants_basic: list[BasicValueType] = src_type.union.variants
317+
src_type_variants = src_type_variants_basic
311318
decoders = []
312319
for i, src_type_variant in enumerate(src_type_variants):
313320
with ChildFieldPath(field_path, f"[{i}]"):
@@ -331,7 +338,7 @@ def decode(value: Any) -> Any | None:
331338
if isinstance(dst_type_variant, AnalyzedAnyType):
332339
return lambda value: value
333340

334-
if src_type_kind == "Vector":
341+
if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
335342
field_path_str = "".join(field_path)
336343
if not isinstance(dst_type_variant, AnalyzedListType):
337344
raise ValueError(
@@ -350,9 +357,11 @@ def decode(value: Any) -> Any | None:
350357
if is_numpy_number_type(dst_type_variant.elem_type):
351358
scalar_dtype = dst_type_variant.elem_type
352359
else:
360+
# mypy: vector info exists for Vector kind
361+
assert src_type.vector is not None # type: ignore[unreachable]
353362
vec_elem_decoder = make_engine_value_decoder(
354363
field_path + ["[*]"],
355-
src_type["element_type"],
364+
src_type.vector.element_type,
356365
analyze_type_info(
357366
dst_type_variant.elem_type if dst_type_variant else Any
358367
),
@@ -432,7 +441,7 @@ def _get_auto_default_for_type(
432441

433442
def make_engine_struct_decoder(
434443
field_path: list[str],
435-
src_fields: list[dict[str, Any]],
444+
src_fields: list[FieldSchema],
436445
dst_type_info: AnalyzedTypeInfo,
437446
for_key: bool = False,
438447
) -> Callable[[list[Any]], Any]:
@@ -461,7 +470,7 @@ def make_engine_struct_decoder(
461470
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
462471
)
463472

464-
src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
473+
src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
465474
dst_struct_type = dst_type_variant.struct_type
466475

467476
parameters: Mapping[str, inspect.Parameter]
@@ -493,7 +502,10 @@ def make_closure_for_field(
493502
with ChildFieldPath(field_path, f".{name}"):
494503
if src_idx is not None:
495504
field_decoder = make_engine_value_decoder(
496-
field_path, src_fields[src_idx]["type"], type_info, for_key=for_key
505+
field_path,
506+
src_fields[src_idx].value_type.type,
507+
type_info,
508+
for_key=for_key,
497509
)
498510
return lambda values: field_decoder(values[src_idx])
499511

@@ -526,19 +538,19 @@ def make_closure_for_field(
526538

527539
def _make_engine_struct_to_dict_decoder(
528540
field_path: list[str],
529-
src_fields: list[dict[str, Any]],
541+
src_fields: list[FieldSchema],
530542
value_type_annotation: Any,
531543
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
532544
"""Make a decoder from engine field values to a Python dict."""
533545

534546
field_decoders = []
535547
value_type_info = analyze_type_info(value_type_annotation)
536548
for field_schema in src_fields:
537-
field_name = field_schema["name"]
549+
field_name = field_schema.name
538550
with ChildFieldPath(field_path, f".{field_name}"):
539551
field_decoder = make_engine_value_decoder(
540552
field_path,
541-
field_schema["type"],
553+
field_schema.value_type.type,
542554
value_type_info,
543555
)
544556
field_decoders.append((field_name, field_decoder))
@@ -560,19 +572,19 @@ def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
560572

561573
def _make_engine_struct_to_tuple_decoder(
562574
field_path: list[str],
563-
src_fields: list[dict[str, Any]],
575+
src_fields: list[FieldSchema],
564576
) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
565577
"""Make a decoder from engine field values to a Python tuple."""
566578

567579
field_decoders = []
568580
value_type_info = analyze_type_info(Any)
569581
for field_schema in src_fields:
570-
field_name = field_schema["name"]
582+
field_name = field_schema.name
571583
with ChildFieldPath(field_path, f".{field_name}"):
572584
field_decoders.append(
573585
make_engine_value_decoder(
574586
field_path,
575-
field_schema["type"],
587+
field_schema.value_type.type,
576588
value_type_info,
577589
)
578590
)

python/cocoindex/flow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from .op import FunctionSpec
4141
from .runtime import execution_context, to_async_call
4242
from .setup import SetupChangeBundle
43-
from .typing import analyze_type_info, encode_enriched_type
43+
from .typing import analyze_type_info, encode_enriched_type, decode_engine_value_type
4444
from .query_handler import QueryHandlerInfo, QueryHandlerResultFields
4545
from .validation import (
4646
validate_flow_name,
@@ -1164,7 +1164,9 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
11641164
inspect.signature(self._flow_fn).return_annotation
11651165
)
11661166
result_decoder = make_engine_value_decoder(
1167-
[], engine_return_type["type"], analyze_type_info(python_return_type)
1167+
[],
1168+
decode_engine_value_type(engine_return_type["type"]),
1169+
analyze_type_info(python_return_type),
11681170
)
11691171

11701172
return TransformFlowInfo(engine_flow, result_decoder)

python/cocoindex/op.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
analyze_type_info,
3131
AnalyzedAnyType,
3232
AnalyzedDictType,
33+
EnrichedValueType,
34+
decode_engine_field_schemas,
3335
)
3436
from .runtime import to_async_call
3537

@@ -212,8 +214,9 @@ def process_arg(
212214
TypeAttr(related_attr.value, actual_arg.analyzed_value)
213215
)
214216
type_info = analyze_type_info(arg_param.annotation)
217+
enriched = EnrichedValueType.decode(actual_arg.value_type)
215218
decoder = make_engine_value_decoder(
216-
[arg_name], actual_arg.value_type["type"], type_info
219+
[arg_name], enriched.type, type_info
217220
)
218221
is_required = not type_info.nullable
219222
if is_required and actual_arg.value_type.get("nullable", False):
@@ -527,10 +530,14 @@ def create_export_context(
527530
)
528531

529532
key_decoder = make_engine_key_decoder(
530-
["(key)"], key_fields_schema, analyze_type_info(key_annotation)
533+
["(key)"],
534+
decode_engine_field_schemas(key_fields_schema),
535+
analyze_type_info(key_annotation),
531536
)
532537
value_decoder = make_engine_struct_decoder(
533-
["(value)"], value_fields_schema, analyze_type_info(value_annotation)
538+
["(value)"],
539+
decode_engine_field_schemas(value_fields_schema),
540+
analyze_type_info(value_annotation),
534541
)
535542

536543
loaded_spec = _load_spec_from_engine(self._spec_cls, spec)

0 commit comments

Comments
 (0)