From beedf1620c375853c7c160c4cd76e0ef53bb5ea5 Mon Sep 17 00:00:00 2001 From: lemorage Date: Sun, 6 Jul 2025 10:16:39 +0200 Subject: [PATCH 1/3] test(convert): cover roundtrip validation for all data types --- python/cocoindex/tests/test_convert.py | 146 +++++++++++++++++++++++-- python/cocoindex/typing.py | 2 +- 2 files changed, 137 insertions(+), 11 deletions(-) diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 008504169..26b88db16 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -94,7 +94,7 @@ def validate_full_roundtrip( def eq(a: Any, b: Any) -> bool: if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): return np.array_equal(a, b) - return type(a) == type(b) and not not (a == b) + return type(a) is type(b) and not not (a == b) encoded_value = encode_engine_value(value) value_type = value_type or type(value) @@ -218,6 +218,11 @@ def test_encode_engine_value_none() -> None: def test_roundtrip_basic_types() -> None: + validate_full_roundtrip(b"hello world", bytes, (b"hello world", None)) + validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes) + validate_full_roundtrip("hello", str, ("hello", None)) + validate_full_roundtrip(True, bool, (True, None)) + validate_full_roundtrip(False, bool, (False, None)) validate_full_roundtrip(42, int, (42, None)) validate_full_roundtrip(3.25, float, (3.25, Float64)) validate_full_roundtrip( @@ -226,9 +231,30 @@ def test_roundtrip_basic_types() -> None: validate_full_roundtrip( 3.25, Float32, (3.25, float), (np.float32(3.25), np.float32) ) - validate_full_roundtrip("hello", str, ("hello", None)) - validate_full_roundtrip(True, bool, (True, None)) - validate_full_roundtrip(False, bool, (False, None)) + + +def test_roundtrip_uuid() -> None: + uuid_value = uuid.uuid4() + validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None)) + + +def test_roundtrip_range() -> None: + r1 = (0, 100) + validate_full_roundtrip(r1, cocoindex.Range) + r2 = (50, 50) + validate_full_roundtrip(r2, cocoindex.Range) + r3 = (0, 1_000_000_000) + validate_full_roundtrip(r3, cocoindex.Range) + + +def test_roundtrip_time() -> None: + t1 = datetime.time(10, 30, 50, 123456) + validate_full_roundtrip(t1, datetime.time, (t1, None)) + t2 = datetime.time(23, 59, 59) + validate_full_roundtrip(t2, datetime.time, (t2, None)) + t3 = datetime.time(0, 0, 0) + validate_full_roundtrip(t3, datetime.time, (t3, None)) + validate_full_roundtrip( datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None) ) @@ -247,8 +273,38 @@ def test_roundtrip_basic_types() -> None: ), ) - uuid_value = uuid.uuid4() - validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None)) + +def test_roundtrip_timedelta() -> None: + td1 = datetime.timedelta( + days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2 + ) + validate_full_roundtrip(td1, datetime.timedelta, (td1, None)) + td2 = datetime.timedelta(days=-5, hours=-2) + validate_full_roundtrip(td2, datetime.timedelta, (td2, None)) + td3 = datetime.timedelta(0) + validate_full_roundtrip(td3, datetime.timedelta, (td3, None)) + + +def test_roundtrip_json() -> None: + simple_dict = {"key": "value", "number": 123, "bool": True, "float": 1.23} + validate_full_roundtrip(simple_dict, cocoindex.Json, (simple_dict, dict)) + + simple_list = [1, "string", False, None, 4.56] + validate_full_roundtrip(simple_list, cocoindex.Json, (simple_list, list)) + + nested_structure = { + "name": "Test Json", + "version": 1.0, + "items": [ + {"id": 1, "value": "item1"}, + {"id": 2, "value": None, "props": {"active": True}}, + ], + "metadata": None, + } + validate_full_roundtrip(nested_structure, cocoindex.Json, (nested_structure, dict)) + + validate_full_roundtrip({}, cocoindex.Json, ({}, dict)) + validate_full_roundtrip([], cocoindex.Json, ([], list)) def test_decode_scalar_numpy_values() -> None: @@ -625,6 +681,21 @@ def test_roundtrip_union_with_vector() -> None: validate_full_roundtrip(value, t, ([1.0, 2.0, 3.0], list[float] | str)) +def test_roundtrip_union_with_misc_types() -> None: + t_bytes_union = int | bytes | str + validate_full_roundtrip(b"test_bytes", t_bytes_union) + validate_full_roundtrip(123, t_bytes_union) + + t_range_union = cocoindex.Range | str | bool + validate_full_roundtrip((100, 200), t_range_union) + validate_full_roundtrip("test_string", t_range_union) + + t_json_union = cocoindex.Json | int | bytes + json_dict = {"a": 1, "b": [2, 3]} + validate_full_roundtrip(json_dict, t_json_union) + validate_full_roundtrip(b"another_byte_string", t_json_union) + + def test_roundtrip_ltable() -> None: t = list[Order] value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)] @@ -638,10 +709,26 @@ def test_roundtrip_ltable() -> None: validate_full_roundtrip(value_nt, t_nt) -def test_roundtrip_ktable_str_key() -> None: - t = dict[str, Order] - value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)} - validate_full_roundtrip(value, t) +def test_roundtrip_ktable_various_key_types() -> None: + @dataclass + class SimpleValue: + data: str + + t_bytes_key = dict[bytes, SimpleValue] + value_bytes_key = {b"key1": SimpleValue("val1"), b"key2": SimpleValue("val2")} + validate_full_roundtrip(value_bytes_key, t_bytes_key) + + t_int_key = dict[int, SimpleValue] + value_int_key = {1: SimpleValue("val1"), 2: SimpleValue("val2")} + validate_full_roundtrip(value_int_key, t_int_key) + + t_bool_key = dict[bool, SimpleValue] + value_bool_key = {True: SimpleValue("val_true"), False: SimpleValue("val_false")} + validate_full_roundtrip(value_bool_key, t_bool_key) + + t_str_key = dict[str, Order] + value_str_key = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)} + validate_full_roundtrip(value_str_key, t_str_key) t_nt = dict[str, OrderNamedTuple] value_nt = { @@ -650,6 +737,27 @@ def test_roundtrip_ktable_str_key() -> None: } validate_full_roundtrip(value_nt, t_nt) + t_range_key = dict[cocoindex.Range, SimpleValue] + value_range_key = { + (1, 10): SimpleValue("val_range1"), + (20, 30): SimpleValue("val_range2"), + } + validate_full_roundtrip(value_range_key, t_range_key) + + t_date_key = dict[datetime.date, SimpleValue] + value_date_key = { + datetime.date(2023, 1, 1): SimpleValue("val_date1"), + datetime.date(2024, 2, 2): SimpleValue("val_date2"), + } + validate_full_roundtrip(value_date_key, t_date_key) + + t_uuid_key = dict[uuid.UUID, SimpleValue] + value_uuid_key = { + uuid.uuid4(): SimpleValue("val_uuid1"), + uuid.uuid4(): SimpleValue("val_uuid2"), + } + validate_full_roundtrip(value_uuid_key, t_uuid_key) + def test_roundtrip_ktable_struct_key() -> None: @dataclass(frozen=True) @@ -940,6 +1048,24 @@ def test_full_roundtrip_vector_numeric_types() -> None: validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]]) +def test_full_roundtrip_vector_other_types() -> None: + """Test full roundtrip for Vector with non-numeric basic types.""" + uuid_list = [uuid.uuid4(), uuid.uuid4()] + validate_full_roundtrip(uuid_list, Vector[uuid.UUID], (uuid_list, list[uuid.UUID])) + + date_list = [datetime.date(2023, 1, 1), datetime.date(2024, 10, 5)] + validate_full_roundtrip( + date_list, Vector[datetime.date], (date_list, list[datetime.date]) + ) + + bool_list = [True, False, True, False] + validate_full_roundtrip(bool_list, Vector[bool], (bool_list, list[bool])) + + validate_full_roundtrip([], Vector[uuid.UUID], ([], list[uuid.UUID])) + validate_full_roundtrip([], Vector[datetime.date], ([], list[datetime.date])) + validate_full_roundtrip([], Vector[bool], ([], list[bool])) + + def test_roundtrip_vector_no_dimension() -> None: """Test full roundtrip for vector types without dimension annotation.""" value_f64 = np.array([1.0, 2.0, 3.0], dtype=np.float64) diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 861255947..de76fb513 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -242,7 +242,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: args = typing.get_args(t) elem_type = (args[0], args[1]) kind = "KTable" - elif base_type is types.UnionType: + elif base_type in (types.UnionType, typing.Union): possible_types = typing.get_args(t) non_none_types = [ arg for arg in possible_types if arg not in (None, types.NoneType) From 539fb212edf3e8bc59fee24bd4fdd425df3faa11 Mon Sep 17 00:00:00 2001 From: lemorage Date: Sun, 6 Jul 2025 14:03:35 +0200 Subject: [PATCH 2/3] fix(convert): update dictionary encoding to distinguish KTable and JSON types --- python/cocoindex/convert.py | 17 ++++++++++++++--- python/cocoindex/tests/test_convert.py | 10 +++++----- python/cocoindex/typing.py | 6 +++--- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 3d3dca8c1..8e49dbcfa 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -18,6 +18,7 @@ encode_enriched_type, extract_ndarray_scalar_dtype, is_namedtuple_type, + is_struct_type, ) @@ -37,9 +38,19 @@ def encode_engine_value(value: Any) -> Any: if isinstance(value, (list, tuple)): return [encode_engine_value(v) for v in value] if isinstance(value, dict): - return [ - [encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items() - ] + if not value: + return {} + + first_val = next(iter(value.values())) + if is_struct_type(type(first_val)): # KTable + return [ + [encode_engine_value(k)] + encode_engine_value(v) + for k, v in value.items() + ] + else: # JSON + return { + encode_engine_value(k): encode_engine_value(v) for k, v in value.items() + } return value diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 26b88db16..507901e95 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -287,10 +287,10 @@ def test_roundtrip_timedelta() -> None: def test_roundtrip_json() -> None: simple_dict = {"key": "value", "number": 123, "bool": True, "float": 1.23} - validate_full_roundtrip(simple_dict, cocoindex.Json, (simple_dict, dict)) + validate_full_roundtrip(simple_dict, cocoindex.Json) simple_list = [1, "string", False, None, 4.56] - validate_full_roundtrip(simple_list, cocoindex.Json, (simple_list, list)) + validate_full_roundtrip(simple_list, cocoindex.Json) nested_structure = { "name": "Test Json", @@ -301,10 +301,10 @@ def test_roundtrip_json() -> None: ], "metadata": None, } - validate_full_roundtrip(nested_structure, cocoindex.Json, (nested_structure, dict)) + validate_full_roundtrip(nested_structure, cocoindex.Json) - validate_full_roundtrip({}, cocoindex.Json, ({}, dict)) - validate_full_roundtrip([], cocoindex.Json, ([], list)) + validate_full_roundtrip({}, cocoindex.Json) + validate_full_roundtrip([], cocoindex.Json) def test_decode_scalar_numpy_values() -> None: diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index de76fb513..8758ff48f 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -107,7 +107,7 @@ def is_namedtuple_type(t: type) -> bool: return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields") -def _is_struct_type(t: ElementType | None) -> bool: +def is_struct_type(t: ElementType | None) -> bool: return isinstance(t, type) and ( dataclasses.is_dataclass(t) or is_namedtuple_type(t) ) @@ -204,7 +204,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: union_variant_types: typing.List[ElementType] | None = None key_type: type | None = None np_number_type: type | None = None - if _is_struct_type(t): + if is_struct_type(t): struct_type = t if kind is None: @@ -219,7 +219,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: elem_type = args[0] if kind is None: - if _is_struct_type(elem_type): + if is_struct_type(elem_type): kind = "LTable" if vector_info is not None: raise ValueError( From 3bde4aa1a153ac26308dfa1c01186675a69d15a3 Mon Sep 17 00:00:00 2001 From: lemorage Date: Mon, 7 Jul 2025 13:14:53 +0200 Subject: [PATCH 3/3] fix(convert): handle raw JSON dicts as leaf values --- python/cocoindex/convert.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 8e49dbcfa..52422d295 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -47,10 +47,6 @@ def encode_engine_value(value: Any) -> Any: [encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items() ] - else: # JSON - return { - encode_engine_value(k): encode_engine_value(v) for k, v in value.items() - } return value