Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 67 additions & 7 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dataclasses
import datetime
import inspect
import warnings
from enum import Enum
from typing import Any, Callable, Mapping, get_origin

Expand Down Expand Up @@ -129,7 +130,9 @@ def make_engine_value_decoder(
for dst_type_variant in dst_type_variants:
try:
decoder = make_engine_value_decoder(
src_field_path, src_type_variant, dst_type_variant
src_field_path,
src_type_variant,
dst_type_variant,
)
break
except ValueError:
Expand Down Expand Up @@ -206,7 +209,9 @@ def decode_vector(value: Any) -> Any | None:

if dst_type_info.struct_type is not None:
return _make_engine_struct_value_decoder(
field_path, src_type["fields"], dst_type_info.struct_type
field_path,
src_type["fields"],
dst_type_info.struct_type,
)

if src_type_kind in TABLE_TYPES:
Expand All @@ -222,11 +227,15 @@ def decode_vector(value: Any) -> Any | None:
key_field_schema = engine_fields_schema[0]
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
key_decoder = make_engine_value_decoder(
field_path, key_field_schema["type"], elem_type_info.key_type
field_path,
key_field_schema["type"],
elem_type_info.key_type,
)
field_path.pop()
value_decoder = _make_engine_struct_value_decoder(
field_path, engine_fields_schema[1:], elem_type_info.struct_type
field_path,
engine_fields_schema[1:],
elem_type_info.struct_type,
)

def decode(value: Any) -> Any | None:
Expand All @@ -235,7 +244,9 @@ def decode(value: Any) -> Any | None:
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
else:
elem_decoder = _make_engine_struct_value_decoder(
field_path, engine_fields_schema, elem_type_info.struct_type
field_path,
engine_fields_schema,
elem_type_info.struct_type,
)

def decode(value: Any) -> Any | None:
Expand All @@ -249,6 +260,41 @@ def decode(value: Any) -> Any | None:
return lambda value: value


def _get_auto_default_for_type(
    annotation: Any, field_name: str, field_path: list[str]
) -> tuple[Any, bool]:
    """
    Decide whether a field missing from the input may be filled in automatically.

    Only annotations with an unambiguous "empty" value qualify: nullable
    types (Optional[T] or T | None) default to None, LTable types default to
    an empty list and KTable types default to an empty dict. Everything else
    is reported as unsupported.

    `field_name` and `field_path` are accepted as part of the signature but
    are not read by this function.

    Returns:
        A `(default_value, is_supported)` pair. `default_value` is only
        meaningful when `is_supported` is True; otherwise it is None.

    NOTE(review): a new list/dict is built on every call, but a caller that
    keeps the returned object (e.g. inside a closure) will hand out that same
    mutable object each time -- confirm this sharing is acceptable.
    """
    unsupported: tuple[Any, bool] = (None, False)

    # No usable annotation at all: nothing can be inferred about a default.
    if any(
        annotation is sentinel
        for sentinel in (None, inspect.Parameter.empty, Any)
    ):
        return unsupported

    try:
        info = analyze_type_info(annotation)

        # Nullable types: None is the natural stand-in.
        if info.nullable:
            return None, True

        # Table types: fall back to the matching empty container.
        kind = info.kind
        if kind in TABLE_TYPES:
            if kind == "LTable":
                return [], True
            if kind == "KTable":
                return {}, True
    except (ValueError, TypeError):
        # The annotation could not be analyzed; treat it as unsupported.
        return unsupported

    # Any other type is left alone to avoid guessing an ambiguous default.
    return unsupported


def _make_engine_struct_value_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
Expand Down Expand Up @@ -285,7 +331,9 @@ def make_closure_for_value(
if src_idx is not None:
field_path.append(f".{name}")
field_decoder = make_engine_value_decoder(
field_path, src_fields[src_idx]["type"], param.annotation
field_path,
src_fields[src_idx]["type"],
param.annotation,
)
field_path.pop()
return (
Expand All @@ -296,8 +344,20 @@ def make_closure_for_value(

default_value = param.default
if default_value is inspect.Parameter.empty:
auto_default, is_supported = _get_auto_default_for_type(
param.annotation, name, field_path
)
if is_supported:
warnings.warn(
f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
UserWarning,
stacklevel=3,
)
return lambda _: auto_default

raise ValueError(
f"Field without default value is missing in input: {''.join(field_path)}"
f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
)

return lambda _: default_value
Expand Down
38 changes: 37 additions & 1 deletion python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import uuid
from dataclasses import dataclass, make_dataclass
from dataclasses import dataclass, make_dataclass, field
from typing import Annotated, Any, Callable, Literal, NamedTuple

import numpy as np
Expand Down Expand Up @@ -1468,3 +1468,39 @@ class Team:

# Test Any annotation
validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))


def test_auto_default_supported_and_unsupported() -> None:
from dataclasses import dataclass, field

@dataclass
class Base:
a: int
b: int

@dataclass
class ExtraFieldSupported:
a: int
b: int
c: list[int] = field(default_factory=list)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This uses an explicit default value, which is not the auto-default case this PR starts to support. I think we want to test the following cases instead: a field of type int | None, a field of type list[Base] (as LTable), and a field of type dict[str, Base] (as KTable) — none of which should have an explicit default value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dataclass
class Base:
    a: int

@dataclass
class NullableField:
    a: int
    b: int | None

@dataclass
class LTableField:
    a: int
    b: list[Base]

@dataclass
class KTableField:
    a: int
    b: dict[str, Base]

@dataclass
class UnsupportedField:
    a: int
    b: int

would these dataclasses be fine?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These look good. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the KTableField case I am getting an error:

TypeError: 'dict' object cannot be converted to 'Sequence' when calling validate_full_roundtrip(KTableField(1, {}), KTableField). Could you help me out with this? The pytest output is attached below.

=================================== FAILURES ===================================
____________ test_auto_default_for_supported_and_unsupported_types _____________

def test_auto_default_for_supported_and_unsupported_types() -> None:
    @dataclass
    class Base:
        a: int

    @dataclass
    class NullableField:
        a: int
        b: int | None

    @dataclass
    class LTableField:
        a: int
        b: list[Base]

    @dataclass
    class KTableField:
        a: int
        b: dict[str, Base]

    @dataclass
    class UnsupportedField:
        a: int
        b: int

    engine_val = [1]

    validate_full_roundtrip(NullableField(1, None), NullableField)

    validate_full_roundtrip(LTableField(1, []), LTableField)

    decoder = build_engine_value_decoder(KTableField)
    result = decoder(engine_val)
    assert result == KTableField(1, {})
  validate_full_roundtrip(KTableField(1, {}), KTableField)

python/cocoindex/tests/test_convert.py:1508:


python/cocoindex/tests/test_convert.py:122: in validate_full_roundtrip
validate_full_roundtrip_to(


value = test_auto_default_for_supported_and_unsupported_types..KTableField(a=1, b={})
value_type = <class 'cocoindex.tests.test_convert.test_auto_default_for_supported_and_unsupported_types..KTableField'>
decoded_values = ((test_auto_default_for_supported_and_unsupported_types..KTableField(a=1, b={}), <class 'cocoindex.tests.test_convert.test_auto_default_for_supported_and_unsupported_types..KTableField'>),)
_engine = <module 'cocoindex._engine' from '/home/kushal/Desktop/Open-Source/cocoindex/python/cocoindex/_engine.cpython-312-x86_64-linux-gnu.so'>
eq = <function validate_full_roundtrip_to..eq at 0x7d7a6d6d9940>
encoded_value = [1, {}]

def validate_full_roundtrip_to(
    value: Any,
    value_type: Any,
    *decoded_values: tuple[Any, Any],
) -> None:
    """
    Validate the given value becomes specific values after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).

    `decoded_values` is a tuple of (value, type) pairs.
    """
    from cocoindex import _engine  # type: ignore

    def eq(a: Any, b: Any) -> bool:
        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            return np.array_equal(a, b)
        return type(a) is type(b) and not not (a == b)

    encoded_value = encode_engine_value(value)
    value_type = value_type or type(value)
    encoded_output_type = encode_enriched_type(value_type)["type"]
  value_from_engine = _engine.testutil.seder_roundtrip(
        encoded_value, encoded_output_type
    )

E TypeError: 'dict' object cannot be converted to 'Sequence'

python/cocoindex/tests/test_convert.py:101: TypeError

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. The test case is validate_full_roundtrip(KTableField(1, {}), KTableField), so an empty dict is being passed. When encoding struct fields that contain empty dicts (KTables), the encoding function preserves them as {}, whereas the engine format expected for KTables is [].

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you both! Since the fix for this issue (PR #807) may still need more time to land, we can disable this specific test case (e.g. comment out the specific line) and merge this PR first. Once the fix is in, we can re-enable it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks guys, I've commented out the line for now and have added a test case just for the decoder.


@dataclass
class ExtraFieldUnsupported:
a: int
b: int
c: int

engine_val = [1, 2]

# Should succeed: c is a list (LTable), auto-defaults to []
validate_full_roundtrip(
Base(1, 2), Base, (ExtraFieldSupported(1, 2, []), ExtraFieldSupported)
)

# Should fail: c is a non-nullable int, no default, not supported
with pytest.raises(
ValueError,
match=r"Field 'c' \(type <class 'int'>\) without default value is missing in input: ",
):
decoder = build_engine_value_decoder(Base, ExtraFieldUnsupported)
decoder(engine_val)
Loading