Skip to content

Commit c260cb0

Browse files
authored
refactor: create context manager for field paths for decoders (#811)
1 parent 50025f8 commit c260cb0

File tree

1 file changed

+105
-87
lines changed

1 file changed

+105
-87
lines changed

python/cocoindex/convert.py

Lines changed: 105 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Utilities to convert between Python and engine values.
33
"""
44

5+
from __future__ import annotations
6+
57
import dataclasses
68
import datetime
79
import inspect
@@ -29,6 +31,24 @@
2931
)
3032

3133

34+
class ChildFieldPath:
35+
"""Context manager to append a field to field_path on enter and pop it on exit."""
36+
37+
_field_path: list[str]
38+
_field_name: str
39+
40+
def __init__(self, field_path: list[str], field_name: str):
41+
self._field_path: list[str] = field_path
42+
self._field_name = field_name
43+
44+
def __enter__(self) -> ChildFieldPath:
45+
self._field_path.append(self._field_name)
46+
return self
47+
48+
def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
49+
self._field_path.pop()
50+
51+
3252
def encode_engine_value(value: Any) -> Any:
3353
"""Encode a Python value to an engine value."""
3454
if dataclasses.is_dataclass(value):
@@ -106,59 +126,58 @@ def make_engine_value_decoder(
106126
)
107127

108128
if src_type_kind in TABLE_TYPES:
109-
field_path.append("[*]")
110-
engine_fields_schema = src_type["row"]["fields"]
129+
with ChildFieldPath(field_path, "[*]"):
130+
engine_fields_schema = src_type["row"]["fields"]
111131

112-
if src_type_kind == "LTable":
113-
if isinstance(dst_type_variant, AnalyzedAnyType):
114-
return _make_engine_ltable_to_list_dict_decoder(
115-
field_path, engine_fields_schema
116-
)
117-
if not isinstance(dst_type_variant, AnalyzedListType):
118-
raise ValueError(
119-
f"Type mismatch for `{''.join(field_path)}`: "
120-
f"declared `{dst_type_info.core_type}`, a list type expected"
132+
if src_type_kind == "LTable":
133+
if isinstance(dst_type_variant, AnalyzedAnyType):
134+
return _make_engine_ltable_to_list_dict_decoder(
135+
field_path, engine_fields_schema
136+
)
137+
if not isinstance(dst_type_variant, AnalyzedListType):
138+
raise ValueError(
139+
f"Type mismatch for `{''.join(field_path)}`: "
140+
f"declared `{dst_type_info.core_type}`, a list type expected"
141+
)
142+
row_decoder = make_engine_struct_decoder(
143+
field_path,
144+
engine_fields_schema,
145+
analyze_type_info(dst_type_variant.elem_type),
121146
)
122-
row_decoder = make_engine_struct_decoder(
123-
field_path,
124-
engine_fields_schema,
125-
analyze_type_info(dst_type_variant.elem_type),
126-
)
127147

128-
def decode(value: Any) -> Any | None:
129-
if value is None:
130-
return None
131-
return [row_decoder(v) for v in value]
148+
def decode(value: Any) -> Any | None:
149+
if value is None:
150+
return None
151+
return [row_decoder(v) for v in value]
152+
153+
elif src_type_kind == "KTable":
154+
if isinstance(dst_type_variant, AnalyzedAnyType):
155+
return _make_engine_ktable_to_dict_dict_decoder(
156+
field_path, engine_fields_schema
157+
)
158+
if not isinstance(dst_type_variant, AnalyzedDictType):
159+
raise ValueError(
160+
f"Type mismatch for `{''.join(field_path)}`: "
161+
f"declared `{dst_type_info.core_type}`, a dict type expected"
162+
)
132163

133-
elif src_type_kind == "KTable":
134-
if isinstance(dst_type_variant, AnalyzedAnyType):
135-
return _make_engine_ktable_to_dict_dict_decoder(
136-
field_path, engine_fields_schema
164+
key_field_schema = engine_fields_schema[0]
165+
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
166+
key_decoder = make_engine_value_decoder(
167+
field_path, key_field_schema["type"], dst_type_variant.key_type
137168
)
138-
if not isinstance(dst_type_variant, AnalyzedDictType):
139-
raise ValueError(
140-
f"Type mismatch for `{''.join(field_path)}`: "
141-
f"declared `{dst_type_info.core_type}`, a dict type expected"
169+
field_path.pop()
170+
value_decoder = make_engine_struct_decoder(
171+
field_path,
172+
engine_fields_schema[1:],
173+
analyze_type_info(dst_type_variant.value_type),
142174
)
143175

144-
key_field_schema = engine_fields_schema[0]
145-
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
146-
key_decoder = make_engine_value_decoder(
147-
field_path, key_field_schema["type"], dst_type_variant.key_type
148-
)
149-
field_path.pop()
150-
value_decoder = make_engine_struct_decoder(
151-
field_path,
152-
engine_fields_schema[1:],
153-
analyze_type_info(dst_type_variant.value_type),
154-
)
155-
156-
def decode(value: Any) -> Any | None:
157-
if value is None:
158-
return None
159-
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
176+
def decode(value: Any) -> Any | None:
177+
if value is None:
178+
return None
179+
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
160180

161-
field_path.pop()
162181
return decode
163182

164183
if src_type_kind == "Union":
@@ -173,22 +192,22 @@ def decode(value: Any) -> Any | None:
173192
src_type_variants = src_type["types"]
174193
decoders = []
175194
for i, src_type_variant in enumerate(src_type_variants):
176-
src_field_path = field_path + [f"[{i}]"]
177-
decoder = None
178-
for dst_type_variant in dst_type_variants:
179-
try:
180-
decoder = make_engine_value_decoder(
181-
src_field_path, src_type_variant, dst_type_variant
195+
with ChildFieldPath(field_path, f"[{i}]"):
196+
decoder = None
197+
for dst_type_variant in dst_type_variants:
198+
try:
199+
decoder = make_engine_value_decoder(
200+
field_path, src_type_variant, dst_type_variant
201+
)
202+
break
203+
except ValueError:
204+
pass
205+
if decoder is None:
206+
raise ValueError(
207+
f"Type mismatch for `{''.join(field_path)}`: "
208+
f"cannot find matched target type for source type variant {src_type_variant}"
182209
)
183-
break
184-
except ValueError:
185-
pass
186-
if decoder is None:
187-
raise ValueError(
188-
f"Type mismatch for `{''.join(field_path)}`: "
189-
f"cannot find matched target type for source type variant {src_type_variant}"
190-
)
191-
decoders.append(decoder)
210+
decoders.append(decoder)
192211
return lambda value: decoders[value[0]](value[1])
193212

194213
if isinstance(dst_type_variant, AnalyzedAnyType):
@@ -325,25 +344,24 @@ def make_closure_for_value(
325344
name: str, param: inspect.Parameter
326345
) -> Callable[[list[Any]], Any]:
327346
src_idx = src_name_to_idx.get(name)
328-
if src_idx is not None:
329-
field_path.append(f".{name}")
330-
field_decoder = make_engine_value_decoder(
331-
field_path, src_fields[src_idx]["type"], param.annotation
332-
)
333-
field_path.pop()
334-
return (
335-
lambda values: field_decoder(values[src_idx])
336-
if len(values) > src_idx
337-
else param.default
338-
)
347+
with ChildFieldPath(field_path, f".{name}"):
348+
if src_idx is not None:
349+
field_decoder = make_engine_value_decoder(
350+
field_path, src_fields[src_idx]["type"], param.annotation
351+
)
352+
return (
353+
lambda values: field_decoder(values[src_idx])
354+
if len(values) > src_idx
355+
else param.default
356+
)
339357

340-
default_value = param.default
341-
if default_value is inspect.Parameter.empty:
342-
raise ValueError(
343-
f"Field without default value is missing in input: {''.join(field_path)}"
344-
)
358+
default_value = param.default
359+
if default_value is inspect.Parameter.empty:
360+
raise ValueError(
361+
f"Field without default value is missing in input: {''.join(field_path)}"
362+
)
345363

346-
return lambda _: default_value
364+
return lambda _: default_value
347365

348366
field_value_decoder = [
349367
make_closure_for_value(name, param) for (name, param) in parameters.items()
@@ -363,13 +381,12 @@ def _make_engine_struct_to_dict_decoder(
363381
field_decoders = []
364382
for i, field_schema in enumerate(src_fields):
365383
field_name = field_schema["name"]
366-
field_path.append(f".{field_name}")
367-
field_decoder = make_engine_value_decoder(
368-
field_path,
369-
field_schema["type"],
370-
Any, # Use Any for recursive decoding
371-
)
372-
field_path.pop()
384+
with ChildFieldPath(field_path, f".{field_name}"):
385+
field_decoder = make_engine_value_decoder(
386+
field_path,
387+
field_schema["type"],
388+
Any, # Use Any for recursive decoding
389+
)
373390
field_decoders.append((field_name, field_decoder))
374391

375392
def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
@@ -426,9 +443,10 @@ def _make_engine_ktable_to_dict_dict_decoder(
426443
value_fields_schema = src_fields[1:]
427444

428445
# Create decoders
429-
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
430-
key_decoder = make_engine_value_decoder(field_path, key_field_schema["type"], Any)
431-
field_path.pop()
446+
with ChildFieldPath(field_path, f".{key_field_schema.get('name', KEY_FIELD_NAME)}"):
447+
key_decoder = make_engine_value_decoder(
448+
field_path, key_field_schema["type"], Any
449+
)
432450

433451
value_decoder = _make_engine_struct_to_dict_decoder(field_path, value_fields_schema)
434452

0 commit comments

Comments
 (0)