Skip to content

Commit afe61af

Browse files
authored
feat(encode-types): encode coco type's python classes and test roundtrip (#1030)
1 parent 183b41e commit afe61af

File tree

3 files changed

+188
-2
lines changed

3 files changed

+188
-2
lines changed

python/cocoindex/convert.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AnalyzedTypeInfo,
2323
AnalyzedUnionType,
2424
AnalyzedUnknownType,
25+
EnrichedValueType,
2526
analyze_type_info,
2627
encode_enriched_type,
2728
is_namedtuple_type,
@@ -611,6 +612,10 @@ def dump_engine_object(v: Any) -> Any:
611612
"""Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
612613
if v is None:
613614
return None
615+
elif isinstance(v, EnrichedValueType):
616+
return v.encode()
617+
elif isinstance(v, FieldSchema):
618+
return v.encode()
614619
elif isinstance(v, type) or get_origin(v) is not None:
615620
return encode_enriched_type(v)
616621
elif isinstance(v, Enum):
@@ -660,6 +665,11 @@ def load_engine_object(expected_type: Any, v: Any) -> Any:
660665
type_info = analyze_type_info(expected_type)
661666
variant = type_info.variant
662667

668+
if type_info.core_type is EnrichedValueType:
669+
return EnrichedValueType.decode(v)
670+
if type_info.core_type is FieldSchema:
671+
return FieldSchema.decode(v)
672+
663673
# Any or unknown → return as-is
664674
if isinstance(variant, AnalyzedAnyType) or type_info.base_type is Any:
665675
return v

python/cocoindex/tests/test_typing.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Annotated, Any, Literal, NamedTuple, get_args, get_origin
66

77
import numpy as np
8-
import pytest
98
from numpy.typing import NDArray
109

1110
from cocoindex.typing import (
@@ -20,7 +19,10 @@
2019
Vector,
2120
VectorInfo,
2221
analyze_type_info,
22+
decode_engine_value_type,
2323
encode_enriched_type,
24+
encode_enriched_type_info,
25+
encode_engine_value_type,
2426
)
2527

2628

@@ -32,7 +34,7 @@ class SimpleDataclass:
3234

3335
class SimpleNamedTuple(NamedTuple):
3436
name: str
35-
value: Any
37+
value: int
3638

3739

3840
def test_ndarray_float32_no_dim() -> None:
@@ -427,3 +429,125 @@ def test_unknown_type() -> None:
427429
typ = set
428430
result = analyze_type_info(typ)
429431
assert isinstance(result.variant, AnalyzedUnknownType)
432+
433+
434+
# ========================= Encode/Decode Tests =========================
435+
436+
437+
def encode_type_from_annotation(t: Any) -> dict[str, Any]:
    """Encode a Python type annotation to its dictionary representation.

    The annotation is analyzed with ``analyze_type_info`` and the resulting
    type info is handed to ``encode_enriched_type_info``; whatever that call
    returns is returned unchanged.
    """
    return encode_enriched_type_info(analyze_type_info(t))
440+
441+
442+
def test_basic_types_encode_decode() -> None:
    """Test encode/decode roundtrip for basic Python types."""
    test_cases = [
        str,
        int,
        float,
        bool,
        bytes,
        uuid.UUID,
        datetime.date,
        datetime.time,
        datetime.datetime,
        datetime.timedelta,
    ]

    for typ in test_cases:
        encoded = encode_type_from_annotation(typ)
        decoded = decode_engine_value_type(encoded["type"])
        reencoded = encode_engine_value_type(decoded)
        # Name the offending type so a failure inside the loop is traceable.
        assert reencoded == encoded["type"], f"roundtrip mismatch for {typ!r}"
462+
463+
464+
def test_vector_types_encode_decode() -> None:
    """Test encode/decode roundtrip for vector types."""
    test_cases = [
        NDArray[np.float32],
        NDArray[np.float64],
        NDArray[np.int64],
        Vector[np.float32],
        Vector[np.float32, Literal[128]],
        Vector[str],
    ]

    for typ in test_cases:
        encoded = encode_type_from_annotation(typ)
        decoded = decode_engine_value_type(encoded["type"])
        reencoded = encode_engine_value_type(decoded)
        # Name the offending type so a failure inside the loop is traceable.
        assert reencoded == encoded["type"], f"roundtrip mismatch for {typ!r}"
480+
481+
482+
def test_struct_types_encode_decode() -> None:
    """Test encode/decode roundtrip for struct types."""
    test_cases = [
        SimpleDataclass,
        SimpleNamedTuple,
    ]

    for typ in test_cases:
        encoded = encode_type_from_annotation(typ)
        decoded = decode_engine_value_type(encoded["type"])
        reencoded = encode_engine_value_type(decoded)
        # Name the offending type so a failure inside the loop is traceable.
        assert reencoded == encoded["type"], f"roundtrip mismatch for {typ!r}"
494+
495+
496+
def test_table_types_encode_decode() -> None:
    """Test encode/decode roundtrip for table types."""
    test_cases = [
        list[SimpleDataclass],  # LTable
        dict[str, SimpleDataclass],  # KTable
    ]

    for typ in test_cases:
        encoded = encode_type_from_annotation(typ)
        decoded = decode_engine_value_type(encoded["type"])
        reencoded = encode_engine_value_type(decoded)
        # Name the offending type so a failure inside the loop is traceable.
        assert reencoded == encoded["type"], f"roundtrip mismatch for {typ!r}"
508+
509+
510+
def test_nullable_types_encode_decode() -> None:
    """Test encode/decode roundtrip for nullable types."""
    test_cases = [
        str | None,
        int | None,
        NDArray[np.float32] | None,
    ]

    for typ in test_cases:
        encoded = encode_type_from_annotation(typ)
        decoded = decode_engine_value_type(encoded["type"])
        reencoded = encode_engine_value_type(decoded)
        # Name the offending type so a failure inside the loop is traceable.
        assert reencoded == encoded["type"], f"roundtrip mismatch for {typ!r}"
523+
524+
525+
def test_annotated_types_encode_decode() -> None:
    """Test encode/decode roundtrip for annotated types."""
    test_cases = [
        Annotated[str, TypeAttr("key", "value")],
        Annotated[NDArray[np.float32], VectorInfo(dim=256)],
        Annotated[list[int], VectorInfo(dim=10)],
    ]

    for typ in test_cases:
        encoded = encode_type_from_annotation(typ)
        decoded = decode_engine_value_type(encoded["type"])
        reencoded = encode_engine_value_type(decoded)
        # Name the offending type so a failure inside the loop is traceable.
        assert reencoded == encoded["type"], f"roundtrip mismatch for {typ!r}"
538+
539+
540+
def test_complex_nested_encode_decode() -> None:
    """Test complex nested structure encode/decode roundtrip."""

    # Create a complex nested structure using Python type annotations:
    # an ndarray field, a nullable field and an attribute-annotated field.
    @dataclasses.dataclass
    class ComplexStruct:
        embedding: NDArray[np.float32]
        metadata: str | None
        score: Annotated[float, TypeAttr("indexed", True)]

    encoded = encode_type_from_annotation(ComplexStruct)
    decoded = decode_engine_value_type(encoded["type"])
    reencoded = encode_engine_value_type(decoded)
    assert reencoded == encoded["type"], "roundtrip mismatch for ComplexStruct"

python/cocoindex/typing.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,12 @@ def decode(obj: dict[str, Any]) -> "VectorTypeSchema":
489489
dimension=obj.get("dimension"),
490490
)
491491

492+
def encode(self) -> dict[str, Any]:
    """Encode this vector schema to its dictionary representation.

    Returns a dict with two keys: ``"element_type"`` holds the encoded
    element type (via its own ``encode()``), and ``"dimension"`` holds
    ``self.dimension`` as-is, including when it is ``None``.
    """
    return {
        "element_type": self.element_type.encode(),
        "dimension": self.dimension,
    }
497+
492498

493499
@dataclasses.dataclass
494500
class UnionTypeSchema:
@@ -500,6 +506,9 @@ def decode(obj: dict[str, Any]) -> "UnionTypeSchema":
500506
variants=[BasicValueType.decode(t) for t in obj["types"]]
501507
)
502508

509+
def encode(self) -> dict[str, Any]:
    """Encode this union schema to its dictionary representation.

    Returns a dict with a single ``"types"`` key listing each variant's
    encoded form, in the order the variants are stored.
    """
    return {"types": [variant.encode() for variant in self.variants]}
511+
503512

504513
@dataclasses.dataclass
505514
class BasicValueType:
@@ -545,6 +554,14 @@ def decode(obj: dict[str, Any]) -> "BasicValueType":
545554
)
546555
return BasicValueType(kind=kind) # type: ignore[arg-type]
547556

557+
def encode(self) -> dict[str, Any]:
    """Encode this basic value type to its dictionary representation.

    The result always carries ``"kind"``. For a ``"Vector"`` kind with a
    vector schema set, or a ``"Union"`` kind with a union schema set, the
    entries produced by that schema's own ``encode()`` are merged into the
    same flat dict. Any other combination yields just ``{"kind": ...}``.
    """
    result: dict[str, Any] = {"kind": self.kind}
    if self.kind == "Vector" and self.vector is not None:
        result.update(self.vector.encode())
    elif self.kind == "Union" and self.union is not None:
        result.update(self.union.encode())
    return result
564+
548565

549566
@dataclasses.dataclass
550567
class EnrichedValueType:
@@ -560,6 +577,14 @@ def decode(obj: dict[str, Any]) -> "EnrichedValueType":
560577
attrs=obj.get("attrs"),
561578
)
562579

580+
def encode(self) -> dict[str, Any]:
    """Encode this enriched value type to its dictionary representation.

    ``"type"`` is always present and holds the encoded underlying type.
    ``"nullable"`` is emitted (as ``True``) only when ``self.nullable`` is
    truthy, and ``"attrs"`` only when ``self.attrs`` is not ``None``.
    """
    result: dict[str, Any] = {"type": self.type.encode()}
    if self.nullable:
        result["nullable"] = True
    if self.attrs is not None:
        result["attrs"] = self.attrs
    return result
587+
563588

564589
@dataclasses.dataclass
565590
class FieldSchema:
@@ -570,6 +595,11 @@ class FieldSchema:
570595
def decode(obj: dict[str, Any]) -> "FieldSchema":
571596
return FieldSchema(name=obj["name"], value_type=EnrichedValueType.decode(obj))
572597

598+
def encode(self) -> dict[str, Any]:
    """Encode this field schema to its dictionary representation.

    The field is encoded flat: the entries of the encoded value type come
    first, and ``"name"`` is added to the same dict (overriding any
    ``"name"`` entry the value type might have produced).
    """
    # Build a new dict rather than mutating the one the nested encoder returned.
    return {**self.value_type.encode(), "name": self.name}
602+
573603

574604
@dataclasses.dataclass
575605
class StructSchema:
@@ -583,11 +613,22 @@ def decode(cls, obj: dict[str, Any]) -> Self:
583613
description=obj.get("description"),
584614
)
585615

616+
def encode(self) -> dict[str, Any]:
    """Encode this struct schema to its dictionary representation.

    ``"fields"`` is always present and lists each field's encoded form in
    order. ``"description"`` is emitted only when it is not ``None``.
    """
    result: dict[str, Any] = {"fields": [field.encode() for field in self.fields]}
    if self.description is not None:
        result["description"] = self.description
    return result
621+
586622

587623
@dataclasses.dataclass
class StructType(StructSchema):
    """A struct schema tagged with ``kind == "Struct"``."""

    # Discriminator written into the encoded dict by ``encode``.
    kind: Literal["Struct"] = "Struct"

    def encode(self) -> dict[str, Any]:
        """Encode this struct type to its dictionary representation.

        Starts from the parent schema's encoding and adds the ``"kind"``
        discriminator to the same dict.
        """
        result = super().encode()
        result["kind"] = self.kind
        return result
631+
591632

592633
@dataclasses.dataclass
593634
class TableType:
@@ -608,6 +649,12 @@ def decode(obj: dict[str, Any]) -> "TableType":
608649
num_key_parts=obj.get("num_key_parts"),
609650
)
610651

652+
def encode(self) -> dict[str, Any]:
    """Encode this table type to its dictionary representation.

    ``"kind"`` and ``"row"`` (the encoded row schema) are always present.
    ``"num_key_parts"`` is emitted only when it is not ``None``.
    """
    result: dict[str, Any] = {"kind": self.kind, "row": self.row.encode()}
    if self.num_key_parts is not None:
        result["num_key_parts"] = self.num_key_parts
    return result
657+
611658

612659
ValueType = BasicValueType | StructType | TableType
613660

@@ -626,3 +673,8 @@ def decode_engine_value_type(obj: dict[str, Any]) -> ValueType:
626673

627674
# Otherwise it's a basic value
628675
return BasicValueType.decode(obj)
676+
677+
678+
def encode_engine_value_type(value_type: ValueType) -> dict[str, Any]:
    """Encode a ValueType to its dictionary representation.

    Delegates to the value's own ``encode()`` method and returns its result
    unchanged.
    """
    return value_type.encode()

0 commit comments

Comments
 (0)