Skip to content

Commit 0783cdf

Browse files
authored
refactor: make_engine_value_decoder() takes AnalyzedTakesInfo (#819)
This is for better composability. We also cleaned up related logic and drop unnecessary support for cases that engine fields don't match engine type - which should not happen. This simplifies code and tests.
1 parent 6e6d28f commit 0783cdf

File tree

4 files changed

+87
-214
lines changed

4 files changed

+87
-214
lines changed

python/cocoindex/convert.py

Lines changed: 41 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
9494
def make_engine_value_decoder(
9595
field_path: list[str],
9696
src_type: dict[str, Any],
97-
dst_annotation: Any,
97+
dst_type_info: AnalyzedTypeInfo,
9898
) -> Callable[[Any], Any]:
9999
"""
100100
Make a decoder from an engine value to a Python value.
@@ -110,7 +110,6 @@ def make_engine_value_decoder(
110110

111111
src_type_kind = src_type["kind"]
112112

113-
dst_type_info = analyze_type_info(dst_annotation)
114113
dst_type_variant = dst_type_info.variant
115114

116115
if isinstance(dst_type_variant, AnalyzedUnknownType):
@@ -165,7 +164,9 @@ def decode(value: Any) -> Any | None:
165164
key_field_schema = engine_fields_schema[0]
166165
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
167166
key_decoder = make_engine_value_decoder(
168-
field_path, key_field_schema["type"], dst_type_variant.key_type
167+
field_path,
168+
key_field_schema["type"],
169+
analyze_type_info(dst_type_variant.key_type),
169170
)
170171
field_path.pop()
171172
value_decoder = make_engine_struct_decoder(
@@ -185,20 +186,20 @@ def decode(value: Any) -> Any | None:
185186
if isinstance(dst_type_variant, AnalyzedAnyType):
186187
return lambda value: value[1]
187188

188-
dst_type_variants = (
189-
dst_type_variant.variant_types
189+
dst_type_info_variants = (
190+
[analyze_type_info(t) for t in dst_type_variant.variant_types]
190191
if isinstance(dst_type_variant, AnalyzedUnionType)
191-
else [dst_annotation]
192+
else [dst_type_info]
192193
)
193194
src_type_variants = src_type["types"]
194195
decoders = []
195196
for i, src_type_variant in enumerate(src_type_variants):
196197
with ChildFieldPath(field_path, f"[{i}]"):
197198
decoder = None
198-
for dst_type_variant in dst_type_variants:
199+
for dst_type_info_variant in dst_type_info_variants:
199200
try:
200201
decoder = make_engine_value_decoder(
201-
field_path, src_type_variant, dst_type_variant
202+
field_path, src_type_variant, dst_type_info_variant
202203
)
203204
break
204205
except ValueError:
@@ -236,7 +237,7 @@ def decode(value: Any) -> Any | None:
236237
vec_elem_decoder = make_engine_value_decoder(
237238
field_path + ["[*]"],
238239
src_type["element_type"],
239-
dst_type_variant and dst_type_variant.elem_type,
240+
analyze_type_info(dst_type_variant and dst_type_variant.elem_type),
240241
)
241242

242243
def decode_vector(value: Any) -> Any | None:
@@ -267,7 +268,7 @@ def decode_vector(value: Any) -> Any | None:
267268
if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
268269
raise ValueError(
269270
f"Type mismatch for `{''.join(field_path)}`: "
270-
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_variant.kind})"
271+
f"passed in {src_type_kind}, declared {dst_type_info.core_type} ({dst_type_variant.kind})"
271272
)
272273

273274
if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
@@ -288,7 +289,7 @@ def decode_scalar(value: Any) -> Any | None:
288289

289290

290291
def _get_auto_default_for_type(
291-
annotation: Any, field_name: str, field_path: list[str]
292+
type_info: AnalyzedTypeInfo,
292293
) -> tuple[Any, bool]:
293294
"""
294295
Get an auto-default value for a type annotation if it's safe to do so.
@@ -298,52 +299,17 @@ def _get_auto_default_for_type(
298299
- default_value: The default value if auto-defaulting is supported
299300
- is_supported: True if auto-defaulting is supported for this type
300301
"""
301-
if annotation is None or annotation is inspect.Parameter.empty or annotation is Any:
302-
return None, False
303-
304-
try:
305-
type_info = analyze_type_info(annotation)
306-
307-
# Case 1: Nullable types (Optional[T] or T | None)
308-
if type_info.nullable:
309-
return None, True
310-
311-
# Case 2: Table types (KTable or LTable) - check if it's a list or dict type
312-
if isinstance(type_info.variant, AnalyzedListType):
313-
return [], True
314-
elif isinstance(type_info.variant, AnalyzedDictType):
315-
return {}, True
316-
317-
# For all other types, don't auto-default to avoid ambiguity
318-
return None, False
302+
# Case 1: Nullable types (Optional[T] or T | None)
303+
if type_info.nullable:
304+
return None, True
319305

320-
except (ValueError, TypeError):
321-
return None, False
322-
323-
324-
def _handle_missing_field_with_auto_default(
325-
param: inspect.Parameter, name: str, field_path: list[str]
326-
) -> Any:
327-
"""
328-
Handle missing field by trying auto-default or raising an error.
329-
330-
Returns the auto-default value if supported, otherwise raises ValueError.
331-
"""
332-
auto_default, is_supported = _get_auto_default_for_type(
333-
param.annotation, name, field_path
334-
)
335-
if is_supported:
336-
warnings.warn(
337-
f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
338-
f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
339-
UserWarning,
340-
stacklevel=4,
341-
)
342-
return auto_default
306+
# Case 2: Table types (KTable or LTable) - check if it's a list or dict type
307+
if isinstance(type_info.variant, AnalyzedListType):
308+
return [], True
309+
elif isinstance(type_info.variant, AnalyzedDictType):
310+
return {}, True
343311

344-
raise ValueError(
345-
f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
346-
)
312+
return None, False
347313

348314

349315
def make_engine_struct_decoder(
@@ -400,40 +366,39 @@ def make_engine_struct_decoder(
400366
else:
401367
raise ValueError(f"Unsupported struct type: {dst_struct_type}")
402368

403-
def make_closure_for_value(
369+
def make_closure_for_field(
404370
name: str, param: inspect.Parameter
405371
) -> Callable[[list[Any]], Any]:
406372
src_idx = src_name_to_idx.get(name)
373+
type_info = analyze_type_info(param.annotation)
374+
407375
with ChildFieldPath(field_path, f".{name}"):
408376
if src_idx is not None:
409377
field_decoder = make_engine_value_decoder(
410-
field_path, src_fields[src_idx]["type"], param.annotation
378+
field_path, src_fields[src_idx]["type"], type_info
411379
)
412-
413-
def field_value_getter(values: list[Any]) -> Any:
414-
if src_idx is not None and len(values) > src_idx:
415-
return field_decoder(values[src_idx])
416-
default_value = param.default
417-
if default_value is not inspect.Parameter.empty:
418-
return default_value
419-
420-
return _handle_missing_field_with_auto_default(
421-
param, name, field_path
422-
)
423-
424-
return field_value_getter
380+
return lambda values: field_decoder(values[src_idx])
425381

426382
default_value = param.default
427383
if default_value is not inspect.Parameter.empty:
428384
return lambda _: default_value
429385

430-
auto_default = _handle_missing_field_with_auto_default(
431-
param, name, field_path
386+
auto_default, is_supported = _get_auto_default_for_type(type_info)
387+
if is_supported:
388+
warnings.warn(
389+
f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
390+
f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
391+
UserWarning,
392+
stacklevel=4,
393+
)
394+
return lambda _: auto_default
395+
396+
raise ValueError(
397+
f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
432398
)
433-
return lambda _: auto_default
434399

435400
field_value_decoder = [
436-
make_closure_for_value(name, param) for (name, param) in parameters.items()
401+
make_closure_for_field(name, param) for (name, param) in parameters.items()
437402
]
438403

439404
return lambda values: dst_struct_type(
@@ -454,7 +419,7 @@ def _make_engine_struct_to_dict_decoder(
454419
field_decoder = make_engine_value_decoder(
455420
field_path,
456421
field_schema["type"],
457-
Any, # Use Any for recursive decoding
422+
analyze_type_info(Any), # Use Any for recursive decoding
458423
)
459424
field_decoders.append((field_name, field_decoder))
460425

@@ -514,7 +479,7 @@ def _make_engine_ktable_to_dict_dict_decoder(
514479
# Create decoders
515480
with ChildFieldPath(field_path, f".{key_field_schema.get('name', KEY_FIELD_NAME)}"):
516481
key_decoder = make_engine_value_decoder(
517-
field_path, key_field_schema["type"], Any
482+
field_path, key_field_schema["type"], analyze_type_info(Any)
518483
)
519484

520485
value_decoder = _make_engine_struct_to_dict_decoder(field_path, value_fields_schema)

python/cocoindex/flow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
validate_full_flow_name,
1717
validate_target_name,
1818
)
19+
from .typing import analyze_type_info
1920

2021
from dataclasses import dataclass
2122
from enum import Enum
@@ -1053,7 +1054,7 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
10531054
sig.return_annotation
10541055
)
10551056
result_decoder = make_engine_value_decoder(
1056-
[], engine_return_type["type"], python_return_type
1057+
[], engine_return_type["type"], analyze_type_info(python_return_type)
10571058
)
10581059

10591060
return TransformFlowInfo(engine_flow, result_decoder)

python/cocoindex/op.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,9 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
220220
)
221221
self._args_decoders.append(
222222
make_engine_value_decoder(
223-
[arg_name], arg.value_type["type"], arg_param.annotation
223+
[arg_name],
224+
arg.value_type["type"],
225+
analyze_type_info(arg_param.annotation),
224226
)
225227
)
226228
process_attribute(arg_name, arg)
@@ -252,7 +254,9 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
252254
)
253255
arg_param = expected_arg[1]
254256
self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
255-
[kwarg_name], kwarg.value_type["type"], arg_param.annotation
257+
[kwarg_name],
258+
kwarg.value_type["type"],
259+
analyze_type_info(arg_param.annotation),
256260
)
257261
process_attribute(kwarg_name, kwarg)
258262

@@ -505,7 +509,9 @@ def create_export_context(
505509

506510
if len(key_fields_schema) == 1:
507511
key_decoder = make_engine_value_decoder(
508-
["(key)"], key_fields_schema[0]["type"], key_annotation
512+
["(key)"],
513+
key_fields_schema[0]["type"],
514+
analyze_type_info(key_annotation),
509515
)
510516
else:
511517
key_decoder = make_engine_struct_decoder(

0 commit comments

Comments
 (0)