Skip to content

Commit fb39083

Browse files
authored
test(convert): cover roundtrip validation for all data types (#703)
* test(convert): cover roundtrip validation for all data types * fix(convert): update dictionary encoding to distinguish KTable and JSON types * fix(convert): handle raw JSON dicts as leaf values
1 parent 591db49 commit fb39083

File tree

3 files changed

+149
-18
lines changed

3 files changed

+149
-18
lines changed

python/cocoindex/convert.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
encode_enriched_type,
1919
extract_ndarray_scalar_dtype,
2020
is_namedtuple_type,
21+
is_struct_type,
2122
)
2223

2324

@@ -37,9 +38,15 @@ def encode_engine_value(value: Any) -> Any:
3738
if isinstance(value, (list, tuple)):
3839
return [encode_engine_value(v) for v in value]
3940
if isinstance(value, dict):
40-
return [
41-
[encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items()
42-
]
41+
if not value:
42+
return {}
43+
44+
first_val = next(iter(value.values()))
45+
if is_struct_type(type(first_val)): # KTable
46+
return [
47+
[encode_engine_value(k)] + encode_engine_value(v)
48+
for k, v in value.items()
49+
]
4350
return value
4451

4552

python/cocoindex/tests/test_convert.py

Lines changed: 135 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def validate_full_roundtrip_to(
9393
def eq(a: Any, b: Any) -> bool:
9494
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
9595
return np.array_equal(a, b)
96-
return type(a) == type(b) and not not (a == b)
96+
return type(a) is type(b) and not not (a == b)
9797

9898
encoded_value = encode_engine_value(value)
9999
value_type = value_type or type(value)
@@ -229,6 +229,11 @@ def test_encode_engine_value_none() -> None:
229229

230230

231231
def test_roundtrip_basic_types() -> None:
232+
validate_full_roundtrip(b"hello world", bytes, (b"hello world", None))
233+
validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes)
234+
validate_full_roundtrip("hello", str, ("hello", None))
235+
validate_full_roundtrip(True, bool, (True, None))
236+
validate_full_roundtrip(False, bool, (False, None))
232237
validate_full_roundtrip(
233238
42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, None)
234239
)
@@ -252,10 +257,29 @@ def test_roundtrip_basic_types() -> None:
252257
)
253258
validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))
254259

255-
validate_full_roundtrip("hello", str, ("hello", None))
256-
validate_full_roundtrip(True, bool, (True, None))
257-
validate_full_roundtrip(False, bool, (False, None))
258-
validate_full_roundtrip((1, 2), cocoindex.Range, ((1, 2), None))
260+
261+
def test_roundtrip_uuid() -> None:
262+
uuid_value = uuid.uuid4()
263+
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None))
264+
265+
266+
def test_roundtrip_range() -> None:
267+
r1 = (0, 100)
268+
validate_full_roundtrip(r1, cocoindex.Range, (r1, None))
269+
r2 = (50, 50)
270+
validate_full_roundtrip(r2, cocoindex.Range, (r2, None))
271+
r3 = (0, 1_000_000_000)
272+
validate_full_roundtrip(r3, cocoindex.Range, (r3, None))
273+
274+
275+
def test_roundtrip_time() -> None:
276+
t1 = datetime.time(10, 30, 50, 123456)
277+
validate_full_roundtrip(t1, datetime.time, (t1, None))
278+
t2 = datetime.time(23, 59, 59)
279+
validate_full_roundtrip(t2, datetime.time, (t2, None))
280+
t3 = datetime.time(0, 0, 0)
281+
validate_full_roundtrip(t3, datetime.time, (t3, None))
282+
259283
validate_full_roundtrip(
260284
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None)
261285
)
@@ -297,8 +321,38 @@ def test_roundtrip_basic_types() -> None:
297321
),
298322
)
299323

300-
uuid_value = uuid.uuid4()
301-
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None))
324+
325+
def test_roundtrip_timedelta() -> None:
326+
td1 = datetime.timedelta(
327+
days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2
328+
)
329+
validate_full_roundtrip(td1, datetime.timedelta, (td1, None))
330+
td2 = datetime.timedelta(days=-5, hours=-2)
331+
validate_full_roundtrip(td2, datetime.timedelta, (td2, None))
332+
td3 = datetime.timedelta(0)
333+
validate_full_roundtrip(td3, datetime.timedelta, (td3, None))
334+
335+
336+
def test_roundtrip_json() -> None:
337+
simple_dict = {"key": "value", "number": 123, "bool": True, "float": 1.23}
338+
validate_full_roundtrip(simple_dict, cocoindex.Json)
339+
340+
simple_list = [1, "string", False, None, 4.56]
341+
validate_full_roundtrip(simple_list, cocoindex.Json)
342+
343+
nested_structure = {
344+
"name": "Test Json",
345+
"version": 1.0,
346+
"items": [
347+
{"id": 1, "value": "item1"},
348+
{"id": 2, "value": None, "props": {"active": True}},
349+
],
350+
"metadata": None,
351+
}
352+
validate_full_roundtrip(nested_structure, cocoindex.Json)
353+
354+
validate_full_roundtrip({}, cocoindex.Json)
355+
validate_full_roundtrip([], cocoindex.Json)
302356

303357

304358
def test_decode_scalar_numpy_values() -> None:
@@ -675,6 +729,21 @@ def test_roundtrip_union_with_vector() -> None:
675729
validate_full_roundtrip(value, t, ([1.0, 2.0, 3.0], list[float] | str))
676730

677731

732+
def test_roundtrip_union_with_misc_types() -> None:
733+
t_bytes_union = int | bytes | str
734+
validate_full_roundtrip(b"test_bytes", t_bytes_union)
735+
validate_full_roundtrip(123, t_bytes_union)
736+
737+
t_range_union = cocoindex.Range | str | bool
738+
validate_full_roundtrip((100, 200), t_range_union)
739+
validate_full_roundtrip("test_string", t_range_union)
740+
741+
t_json_union = cocoindex.Json | int | bytes
742+
json_dict = {"a": 1, "b": [2, 3]}
743+
validate_full_roundtrip(json_dict, t_json_union)
744+
validate_full_roundtrip(b"another_byte_string", t_json_union)
745+
746+
678747
def test_roundtrip_ltable() -> None:
679748
t = list[Order]
680749
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
@@ -688,10 +757,26 @@ def test_roundtrip_ltable() -> None:
688757
validate_full_roundtrip(value_nt, t_nt)
689758

690759

691-
def test_roundtrip_ktable_str_key() -> None:
692-
t = dict[str, Order]
693-
value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
694-
validate_full_roundtrip(value, t)
760+
def test_roundtrip_ktable_various_key_types() -> None:
761+
@dataclass
762+
class SimpleValue:
763+
data: str
764+
765+
t_bytes_key = dict[bytes, SimpleValue]
766+
value_bytes_key = {b"key1": SimpleValue("val1"), b"key2": SimpleValue("val2")}
767+
validate_full_roundtrip(value_bytes_key, t_bytes_key)
768+
769+
t_int_key = dict[int, SimpleValue]
770+
value_int_key = {1: SimpleValue("val1"), 2: SimpleValue("val2")}
771+
validate_full_roundtrip(value_int_key, t_int_key)
772+
773+
t_bool_key = dict[bool, SimpleValue]
774+
value_bool_key = {True: SimpleValue("val_true"), False: SimpleValue("val_false")}
775+
validate_full_roundtrip(value_bool_key, t_bool_key)
776+
777+
t_str_key = dict[str, Order]
778+
value_str_key = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
779+
validate_full_roundtrip(value_str_key, t_str_key)
695780

696781
t_nt = dict[str, OrderNamedTuple]
697782
value_nt = {
@@ -700,6 +785,27 @@ def test_roundtrip_ktable_str_key() -> None:
700785
}
701786
validate_full_roundtrip(value_nt, t_nt)
702787

788+
t_range_key = dict[cocoindex.Range, SimpleValue]
789+
value_range_key = {
790+
(1, 10): SimpleValue("val_range1"),
791+
(20, 30): SimpleValue("val_range2"),
792+
}
793+
validate_full_roundtrip(value_range_key, t_range_key)
794+
795+
t_date_key = dict[datetime.date, SimpleValue]
796+
value_date_key = {
797+
datetime.date(2023, 1, 1): SimpleValue("val_date1"),
798+
datetime.date(2024, 2, 2): SimpleValue("val_date2"),
799+
}
800+
validate_full_roundtrip(value_date_key, t_date_key)
801+
802+
t_uuid_key = dict[uuid.UUID, SimpleValue]
803+
value_uuid_key = {
804+
uuid.uuid4(): SimpleValue("val_uuid1"),
805+
uuid.uuid4(): SimpleValue("val_uuid2"),
806+
}
807+
validate_full_roundtrip(value_uuid_key, t_uuid_key)
808+
703809

704810
def test_roundtrip_ktable_struct_key() -> None:
705811
@dataclass(frozen=True)
@@ -990,6 +1096,24 @@ def test_full_roundtrip_vector_numeric_types() -> None:
9901096
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
9911097

9921098

1099+
def test_full_roundtrip_vector_other_types() -> None:
1100+
"""Test full roundtrip for Vector with non-numeric basic types."""
1101+
uuid_list = [uuid.uuid4(), uuid.uuid4()]
1102+
validate_full_roundtrip(uuid_list, Vector[uuid.UUID], (uuid_list, list[uuid.UUID]))
1103+
1104+
date_list = [datetime.date(2023, 1, 1), datetime.date(2024, 10, 5)]
1105+
validate_full_roundtrip(
1106+
date_list, Vector[datetime.date], (date_list, list[datetime.date])
1107+
)
1108+
1109+
bool_list = [True, False, True, False]
1110+
validate_full_roundtrip(bool_list, Vector[bool], (bool_list, list[bool]))
1111+
1112+
validate_full_roundtrip([], Vector[uuid.UUID], ([], list[uuid.UUID]))
1113+
validate_full_roundtrip([], Vector[datetime.date], ([], list[datetime.date]))
1114+
validate_full_roundtrip([], Vector[bool], ([], list[bool]))
1115+
1116+
9931117
def test_roundtrip_vector_no_dimension() -> None:
9941118
"""Test full roundtrip for vector types without dimension annotation."""
9951119
value_f64 = np.array([1.0, 2.0, 3.0], dtype=np.float64)

python/cocoindex/typing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def is_namedtuple_type(t: type) -> bool:
108108
return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields")
109109

110110

111-
def _is_struct_type(t: ElementType | None) -> bool:
111+
def is_struct_type(t: ElementType | None) -> bool:
112112
return isinstance(t, type) and (
113113
dataclasses.is_dataclass(t) or is_namedtuple_type(t)
114114
)
@@ -205,7 +205,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
205205
union_variant_types: typing.List[ElementType] | None = None
206206
key_type: type | None = None
207207
np_number_type: type | None = None
208-
if _is_struct_type(t):
208+
if is_struct_type(t):
209209
struct_type = t
210210

211211
if kind is None:
@@ -220,7 +220,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
220220
elem_type = args[0]
221221

222222
if kind is None:
223-
if _is_struct_type(elem_type):
223+
if is_struct_type(elem_type):
224224
kind = "LTable"
225225
if vector_info is not None:
226226
raise ValueError(
@@ -243,7 +243,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
243243
args = typing.get_args(t)
244244
elem_type = (args[0], args[1])
245245
kind = "KTable"
246-
elif base_type is types.UnionType:
246+
elif base_type in (types.UnionType, typing.Union):
247247
possible_types = typing.get_args(t)
248248
non_none_types = [
249249
arg for arg in possible_types if arg not in (None, types.NoneType)

0 commit comments

Comments
 (0)