Skip to content

Commit 70f3b71

Browse files
authored
refactor: make the logic to encode/decode values better arranged (#800)
* refactor: make the logic to encode/decode values better arranged * docs: fix docstring for new types
1 parent 92e4701 commit 70f3b71

File tree

4 files changed

+397
-407
lines changed

4 files changed

+397
-407
lines changed

python/cocoindex/convert.py

Lines changed: 141 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,19 @@
1313
from .typing import (
1414
KEY_FIELD_NAME,
1515
TABLE_TYPES,
16-
DtypeRegistry,
1716
analyze_type_info,
1817
encode_enriched_type,
19-
extract_ndarray_scalar_dtype,
2018
is_namedtuple_type,
2119
is_struct_type,
20+
AnalyzedTypeInfo,
21+
AnalyzedAnyType,
22+
AnalyzedDictType,
23+
AnalyzedListType,
24+
AnalyzedBasicType,
25+
AnalyzedUnionType,
26+
AnalyzedUnknownType,
27+
AnalyzedStructType,
28+
is_numpy_number_type,
2229
)
2330

2431

@@ -79,46 +86,88 @@ def make_engine_value_decoder(
7986
Returns:
8087
A decoder from an engine value to a Python value.
8188
"""
89+
8290
src_type_kind = src_type["kind"]
8391

84-
dst_is_any = (
85-
dst_annotation is None
86-
or dst_annotation is inspect.Parameter.empty
87-
or dst_annotation is Any
88-
)
89-
if dst_is_any:
90-
if src_type_kind == "Union":
91-
return lambda value: value[1]
92-
if src_type_kind == "Struct":
93-
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
94-
if src_type_kind in TABLE_TYPES:
95-
if src_type_kind == "LTable":
92+
dst_type_info = analyze_type_info(dst_annotation)
93+
dst_type_variant = dst_type_info.variant
94+
95+
if isinstance(dst_type_variant, AnalyzedUnknownType):
96+
raise ValueError(
97+
f"Type mismatch for `{''.join(field_path)}`: "
98+
f"declared `{dst_type_info.core_type}`, an unsupported type"
99+
)
100+
101+
if src_type_kind == "Struct":
102+
return _make_engine_struct_value_decoder(
103+
field_path,
104+
src_type["fields"],
105+
dst_type_info,
106+
)
107+
108+
if src_type_kind in TABLE_TYPES:
109+
field_path.append("[*]")
110+
engine_fields_schema = src_type["row"]["fields"]
111+
112+
if src_type_kind == "LTable":
113+
if isinstance(dst_type_variant, AnalyzedAnyType):
96114
return _make_engine_ltable_to_list_dict_decoder(
97-
field_path, src_type["row"]["fields"]
115+
field_path, engine_fields_schema
116+
)
117+
if not isinstance(dst_type_variant, AnalyzedListType):
118+
raise ValueError(
119+
f"Type mismatch for `{''.join(field_path)}`: "
120+
f"declared `{dst_type_info.core_type}`, a list type expected"
98121
)
99-
elif src_type_kind == "KTable":
122+
row_decoder = _make_engine_struct_value_decoder(
123+
field_path,
124+
engine_fields_schema,
125+
analyze_type_info(dst_type_variant.elem_type),
126+
)
127+
128+
def decode(value: Any) -> Any | None:
129+
if value is None:
130+
return None
131+
return [row_decoder(v) for v in value]
132+
133+
elif src_type_kind == "KTable":
134+
if isinstance(dst_type_variant, AnalyzedAnyType):
100135
return _make_engine_ktable_to_dict_dict_decoder(
101-
field_path, src_type["row"]["fields"]
136+
field_path, engine_fields_schema
137+
)
138+
if not isinstance(dst_type_variant, AnalyzedDictType):
139+
raise ValueError(
140+
f"Type mismatch for `{''.join(field_path)}`: "
141+
f"declared `{dst_type_info.core_type}`, a dict type expected"
102142
)
103-
return lambda value: value
104143

105-
# Handle struct -> dict binding for explicit dict annotations
106-
is_dict_annotation = False
107-
if dst_annotation is dict:
108-
is_dict_annotation = True
109-
elif getattr(dst_annotation, "__origin__", None) is dict:
110-
args = getattr(dst_annotation, "__args__", ())
111-
if args == (str, Any):
112-
is_dict_annotation = True
113-
if is_dict_annotation and src_type_kind == "Struct":
114-
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
144+
key_field_schema = engine_fields_schema[0]
145+
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
146+
key_decoder = make_engine_value_decoder(
147+
field_path, key_field_schema["type"], dst_type_variant.key_type
148+
)
149+
field_path.pop()
150+
value_decoder = _make_engine_struct_value_decoder(
151+
field_path,
152+
engine_fields_schema[1:],
153+
analyze_type_info(dst_type_variant.value_type),
154+
)
115155

116-
dst_type_info = analyze_type_info(dst_annotation)
156+
def decode(value: Any) -> Any | None:
157+
if value is None:
158+
return None
159+
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
160+
161+
field_path.pop()
162+
return decode
117163

118164
if src_type_kind == "Union":
165+
if isinstance(dst_type_variant, AnalyzedAnyType):
166+
return lambda value: value[1]
167+
119168
dst_type_variants = (
120-
dst_type_info.union_variant_types
121-
if dst_type_info.union_variant_types is not None
169+
dst_type_variant.variant_types
170+
if isinstance(dst_type_variant, AnalyzedUnionType)
122171
else [dst_annotation]
123172
)
124173
src_type_variants = src_type["types"]
@@ -142,43 +191,36 @@ def make_engine_value_decoder(
142191
decoders.append(decoder)
143192
return lambda value: decoders[value[0]](value[1])
144193

145-
if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
146-
raise ValueError(
147-
f"Type mismatch for `{''.join(field_path)}`: "
148-
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
149-
)
150-
151-
if dst_type_info.kind in ("Float32", "Float64", "Int64"):
152-
dst_core_type = dst_type_info.core_type
153-
154-
def decode_scalar(value: Any) -> Any | None:
155-
if value is None:
156-
if dst_type_info.nullable:
157-
return None
158-
raise ValueError(
159-
f"Received null for non-nullable scalar `{''.join(field_path)}`"
160-
)
161-
return dst_core_type(value)
162-
163-
return decode_scalar
194+
if isinstance(dst_type_variant, AnalyzedAnyType):
195+
return lambda value: value
164196

165197
if src_type_kind == "Vector":
166198
field_path_str = "".join(field_path)
199+
if not isinstance(dst_type_variant, AnalyzedListType):
200+
raise ValueError(
201+
f"Type mismatch for `{''.join(field_path)}`: "
202+
f"declared `{dst_type_info.core_type}`, a list type expected"
203+
)
167204
expected_dim = (
168-
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
205+
dst_type_variant.vector_info.dim
206+
if dst_type_variant and dst_type_variant.vector_info
207+
else None
169208
)
170209

171-
elem_decoder = None
210+
vec_elem_decoder = None
172211
scalar_dtype = None
173-
if dst_type_info.np_number_type is None: # for Non-NDArray vector
174-
elem_decoder = make_engine_value_decoder(
212+
if (
213+
dst_type_variant
214+
and is_numpy_number_type(dst_type_variant.elem_type)
215+
and dst_type_info.base_type is np.ndarray
216+
):
217+
scalar_dtype = dst_type_variant.elem_type
218+
else:
219+
vec_elem_decoder = make_engine_value_decoder(
175220
field_path + ["[*]"],
176221
src_type["element_type"],
177-
dst_type_info.elem_type,
222+
dst_type_variant and dst_type_variant.elem_type,
178223
)
179-
else: # for NDArray vector
180-
scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type)
181-
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
182224

183225
def decode_vector(value: Any) -> Any | None:
184226
if value is None:
@@ -197,66 +239,70 @@ def decode_vector(value: Any) -> Any | None:
197239
f"expected {expected_dim}, got {len(value)}"
198240
)
199241

200-
if elem_decoder is not None: # for Non-NDArray vector
201-
return [elem_decoder(v) for v in value]
242+
if vec_elem_decoder is not None: # for Non-NDArray vector
243+
return [vec_elem_decoder(v) for v in value]
202244
else: # for NDArray vector
203245
return np.array(value, dtype=scalar_dtype)
204246

205247
return decode_vector
206248

207-
if dst_type_info.struct_type is not None:
208-
return _make_engine_struct_value_decoder(
209-
field_path, src_type["fields"], dst_type_info.struct_type
210-
)
211-
212-
if src_type_kind in TABLE_TYPES:
213-
field_path.append("[*]")
214-
elem_type_info = analyze_type_info(dst_type_info.elem_type)
215-
if elem_type_info.struct_type is None:
249+
if isinstance(dst_type_variant, AnalyzedBasicType):
250+
if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
216251
raise ValueError(
217252
f"Type mismatch for `{''.join(field_path)}`: "
218-
f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected"
219-
)
220-
engine_fields_schema = src_type["row"]["fields"]
221-
if elem_type_info.key_type is not None:
222-
key_field_schema = engine_fields_schema[0]
223-
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
224-
key_decoder = make_engine_value_decoder(
225-
field_path, key_field_schema["type"], elem_type_info.key_type
226-
)
227-
field_path.pop()
228-
value_decoder = _make_engine_struct_value_decoder(
229-
field_path, engine_fields_schema[1:], elem_type_info.struct_type
253+
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_variant.kind})"
230254
)
231255

232-
def decode(value: Any) -> Any | None:
233-
if value is None:
234-
return None
235-
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
236-
else:
237-
elem_decoder = _make_engine_struct_value_decoder(
238-
field_path, engine_fields_schema, elem_type_info.struct_type
239-
)
256+
if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
257+
dst_core_type = dst_type_info.core_type
240258

241-
def decode(value: Any) -> Any | None:
259+
def decode_scalar(value: Any) -> Any | None:
242260
if value is None:
243-
return None
244-
return [elem_decoder(v) for v in value]
261+
if dst_type_info.nullable:
262+
return None
263+
raise ValueError(
264+
f"Received null for non-nullable scalar `{''.join(field_path)}`"
265+
)
266+
return dst_core_type(value)
245267

246-
field_path.pop()
247-
return decode
268+
return decode_scalar
248269

249270
return lambda value: value
250271

251272

252273
def _make_engine_struct_value_decoder(
253274
field_path: list[str],
254275
src_fields: list[dict[str, Any]],
255-
dst_struct_type: type,
276+
dst_type_info: AnalyzedTypeInfo,
256277
) -> Callable[[list[Any]], Any]:
257278
"""Make a decoder from an engine field values to a Python value."""
258279

280+
dst_type_variant = dst_type_info.variant
281+
282+
use_dict = False
283+
if isinstance(dst_type_variant, AnalyzedAnyType):
284+
use_dict = True
285+
elif isinstance(dst_type_variant, AnalyzedDictType):
286+
analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
287+
analyzed_value_type = analyze_type_info(dst_type_variant.value_type)
288+
use_dict = (
289+
isinstance(analyzed_key_type.variant, AnalyzedAnyType)
290+
or (
291+
isinstance(analyzed_key_type.variant, AnalyzedBasicType)
292+
and analyzed_key_type.variant.kind == "Str"
293+
)
294+
) and isinstance(analyzed_value_type.variant, AnalyzedAnyType)
295+
if use_dict:
296+
return _make_engine_struct_to_dict_decoder(field_path, src_fields)
297+
298+
if not isinstance(dst_type_variant, AnalyzedStructType):
299+
raise ValueError(
300+
f"Type mismatch for `{''.join(field_path)}`: "
301+
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
302+
)
303+
259304
src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
305+
dst_struct_type = dst_type_variant.struct_type
260306

261307
parameters: Mapping[str, inspect.Parameter]
262308
if dataclasses.is_dataclass(dst_struct_type):

python/cocoindex/tests/test_convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ class MixedStruct:
12161216
numpy_float: np.float64
12171217
python_float: float
12181218
string: str
1219-
annotated_int: Annotated[np.int64, TypeKind("int")]
1219+
annotated_int: Annotated[np.int64, TypeKind("Int64")]
12201220
annotated_float: Float32
12211221

12221222
instance = MixedStruct(

0 commit comments

Comments
 (0)