Skip to content

Commit ba758a1

Browse files
authored
feat: python type binding convertibility for basic types (#649)
1 parent 65cbf27 commit ba758a1

File tree

4 files changed

+198
-103
lines changed

4 files changed

+198
-103
lines changed

python/cocoindex/convert.py

Lines changed: 78 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .typing import (
1515
KEY_FIELD_NAME,
1616
TABLE_TYPES,
17+
AnalyzedTypeInfo,
1718
DtypeRegistry,
1819
analyze_type_info,
1920
encode_enriched_type,
@@ -46,6 +47,19 @@ def encode_engine_value(value: Any) -> Any:
4647
return value
4748

4849

50+
_CONVERTIBLE_KINDS = {
51+
("Float32", "Float64"),
52+
("LocalDateTime", "OffsetDateTime"),
53+
}
54+
55+
56+
def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool:
57+
return (
58+
src_type_kind == dst_type_kind
59+
or (src_type_kind, dst_type_kind) in _CONVERTIBLE_KINDS
60+
)
61+
62+
4963
def make_engine_value_decoder(
5064
field_path: list[str],
5165
src_type: dict[str, Any],
@@ -65,44 +79,90 @@ def make_engine_value_decoder(
6579

6680
src_type_kind = src_type["kind"]
6781

82+
dst_type_info: AnalyzedTypeInfo | None = None
6883
if (
69-
dst_annotation is None
70-
or dst_annotation is inspect.Parameter.empty
71-
or dst_annotation is Any
84+
dst_annotation is not None
85+
and dst_annotation is not inspect.Parameter.empty
86+
and dst_annotation is not Any
7287
):
88+
dst_type_info = analyze_type_info(dst_annotation)
89+
if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
90+
raise ValueError(
91+
f"Type mismatch for `{''.join(field_path)}`: "
92+
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
93+
)
94+
95+
if src_type_kind == "Uuid":
96+
return lambda value: uuid.UUID(bytes=value)
97+
98+
if dst_type_info is None:
7399
if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES:
74100
raise ValueError(
75101
f"Missing type annotation for `{''.join(field_path)}`."
76102
f"It's required for {src_type_kind} type."
77103
)
78104
return lambda value: value
79105

80-
dst_type_info = analyze_type_info(dst_annotation)
106+
if dst_type_info.kind in ("Float32", "Float64", "Int64"):
107+
dst_core_type = dst_type_info.core_type
81108

82-
if src_type_kind != dst_type_info.kind:
83-
raise ValueError(
84-
f"Type mismatch for `{''.join(field_path)}`: "
85-
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
86-
)
109+
def decode_scalar(value: Any) -> Any | None:
110+
if value is None:
111+
if dst_type_info.nullable:
112+
return None
113+
raise ValueError(
114+
f"Received null for non-nullable scalar `{''.join(field_path)}`"
115+
)
116+
return dst_core_type(value)
87117

88-
if dst_type_info.struct_type is not None:
89-
return _make_engine_struct_value_decoder(
90-
field_path, src_type["fields"], dst_type_info.struct_type
118+
return decode_scalar
119+
120+
if src_type_kind == "Vector":
121+
field_path_str = "".join(field_path)
122+
expected_dim = (
123+
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
91124
)
92125

93-
if dst_type_info.np_number_type is not None and src_type_kind != "Vector":
94-
numpy_type = dst_type_info.np_number_type
126+
elem_decoder = None
127+
scalar_dtype = None
128+
if dst_type_info.np_number_type is None: # for Non-NDArray vector
129+
elem_decoder = make_engine_value_decoder(
130+
field_path + ["[*]"],
131+
src_type["element_type"],
132+
dst_type_info.elem_type,
133+
)
134+
else: # for NDArray vector
135+
scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type)
136+
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
95137

96-
def decode_numpy_scalar(value: Any) -> Any | None:
138+
def decode_vector(value: Any) -> Any | None:
97139
if value is None:
98140
if dst_type_info.nullable:
99141
return None
100142
raise ValueError(
101-
f"Received null for non-nullable scalar `{''.join(field_path)}`"
143+
f"Received null for non-nullable vector `{field_path_str}`"
144+
)
145+
if not isinstance(value, (np.ndarray, list)):
146+
raise TypeError(
147+
f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
148+
)
149+
if expected_dim is not None and len(value) != expected_dim:
150+
raise ValueError(
151+
f"Vector dimension mismatch for `{field_path_str}`: "
152+
f"expected {expected_dim}, got {len(value)}"
102153
)
103-
return numpy_type(value)
104154

105-
return decode_numpy_scalar
155+
if elem_decoder is not None: # for Non-NDArray vector
156+
return [elem_decoder(v) for v in value]
157+
else: # for NDArray vector
158+
return np.array(value, dtype=scalar_dtype)
159+
160+
return decode_vector
161+
162+
if dst_type_info.struct_type is not None:
163+
return _make_engine_struct_value_decoder(
164+
field_path, src_type["fields"], dst_type_info.struct_type
165+
)
106166

107167
if src_type_kind in TABLE_TYPES:
108168
field_path.append("[*]")
@@ -141,49 +201,6 @@ def decode(value: Any) -> Any | None:
141201
field_path.pop()
142202
return decode
143203

144-
if src_type_kind == "Uuid":
145-
return lambda value: uuid.UUID(bytes=value)
146-
147-
if src_type_kind == "Vector":
148-
149-
def decode_vector(value: Any) -> Any | None:
150-
field_path_str = "".join(field_path)
151-
expected_dim = (
152-
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
153-
)
154-
155-
if value is None:
156-
if dst_type_info.nullable:
157-
return None
158-
raise ValueError(
159-
f"Received null for non-nullable vector `{field_path_str}`"
160-
)
161-
if not isinstance(value, (np.ndarray, list)):
162-
raise TypeError(
163-
f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
164-
)
165-
if expected_dim is not None and len(value) != expected_dim:
166-
raise ValueError(
167-
f"Vector dimension mismatch for `{field_path_str}`: "
168-
f"expected {expected_dim}, got {len(value)}"
169-
)
170-
171-
if dst_type_info.np_number_type is None: # for Non-NDArray vector
172-
elem_decoder = make_engine_value_decoder(
173-
field_path + ["[*]"],
174-
src_type["element_type"],
175-
dst_type_info.elem_type,
176-
)
177-
return [elem_decoder(v) for v in value]
178-
else: # for NDArray vector
179-
scalar_dtype = extract_ndarray_scalar_dtype(
180-
dst_type_info.np_number_type
181-
)
182-
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
183-
return np.array(value, dtype=scalar_dtype)
184-
185-
return decode_vector
186-
187204
if src_type_kind == "Union":
188205
return lambda value: value[1]
189206

python/cocoindex/tests/test_convert.py

Lines changed: 92 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,26 @@ def validate_full_roundtrip(
9191
"""
9292
from cocoindex import _engine # type: ignore
9393

94+
def eq(a: Any, b: Any) -> bool:
    """Strict equality used to compare a decoded value with the expected one.

    Two NumPy arrays are compared with `np.array_equal`. Any other pair is
    equal only when both operands have exactly the same type and `a == b`
    is truthy, so e.g. `np.float32(3.25)` and `3.25` are NOT considered equal.
    """
    both_arrays = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
    if both_arrays:
        return np.array_equal(a, b)

    # Exact type match on purpose (not isinstance): the decoder's output type
    # is part of what is being checked.
    same_type = type(a) == type(b)
    return same_type and bool(a == b)
98+
9499
encoded_value = encode_engine_value(value)
95100
value_type = value_type or type(value)
96101
encoded_output_type = encode_enriched_type(value_type)["type"]
97102
value_from_engine = _engine.testutil.seder_roundtrip(
98103
encoded_value, encoded_output_type
99104
)
100-
decoded_value = build_engine_value_decoder(value_type, value_type)(
101-
value_from_engine
102-
)
103-
np.testing.assert_array_equal(decoded_value, value)
105+
decoder = make_engine_value_decoder([], encoded_output_type, value_type)
106+
decoded_value = decoder(value_from_engine)
107+
assert eq(decoded_value, value)
104108

105109
if other_decoded_values is not None:
106110
for other_value, other_type in other_decoded_values:
107-
other_decoded_value = build_engine_value_decoder(other_type, other_type)(
108-
value_from_engine
109-
)
110-
np.testing.assert_array_equal(other_decoded_value, other_value)
111+
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
112+
other_decoded_value = decoder(value_from_engine)
113+
assert eq(other_decoded_value, other_value)
111114

112115

113116
def test_encode_engine_value_basic_types() -> None:
@@ -215,19 +218,38 @@ def test_encode_engine_value_none() -> None:
215218

216219

217220
def test_roundtrip_basic_types() -> None:
218-
validate_full_roundtrip(42, int)
221+
validate_full_roundtrip(42, int, (42, None))
219222
validate_full_roundtrip(3.25, float, (3.25, Float64))
220-
validate_full_roundtrip(3.25, Float64, (3.25, float))
221-
validate_full_roundtrip(3.25, Float32)
222-
validate_full_roundtrip("hello", str)
223-
validate_full_roundtrip(True, bool)
224-
validate_full_roundtrip(False, bool)
225-
validate_full_roundtrip(datetime.date(2025, 1, 1), datetime.date)
226-
validate_full_roundtrip(datetime.datetime.now(), cocoindex.LocalDateTime)
227223
validate_full_roundtrip(
228-
datetime.datetime.now(datetime.UTC), cocoindex.OffsetDateTime
224+
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64)
225+
)
226+
validate_full_roundtrip(
227+
3.25, Float32, (3.25, float), (np.float32(3.25), np.float32)
228+
)
229+
validate_full_roundtrip("hello", str, ("hello", None))
230+
validate_full_roundtrip(True, bool, (True, None))
231+
validate_full_roundtrip(False, bool, (False, None))
232+
validate_full_roundtrip(
233+
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None)
229234
)
230235

236+
validate_full_roundtrip(
237+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456),
238+
cocoindex.LocalDateTime,
239+
(datetime.datetime(2025, 1, 2, 3, 4, 5, 123456), datetime.datetime),
240+
)
241+
validate_full_roundtrip(
242+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
243+
cocoindex.OffsetDateTime,
244+
(
245+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
246+
datetime.datetime,
247+
),
248+
)
249+
250+
uuid_value = uuid.uuid4()
251+
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None))
252+
231253

232254
def test_decode_scalar_numpy_values() -> None:
233255
test_cases = [
@@ -849,37 +871,72 @@ def test_dump_vector_type_annotation_no_dim() -> None:
849871

850872
def test_full_roundtrip_vector_numeric_types() -> None:
851873
"""Test full roundtrip for numeric vector types using NDArray."""
852-
value_f32: Vector[np.float32, Literal[3]] = np.array(
853-
[1.0, 2.0, 3.0], dtype=np.float32
874+
value_f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
875+
validate_full_roundtrip(
876+
value_f32,
877+
Vector[np.float32, Literal[3]],
878+
([np.float32(1.0), np.float32(2.0), np.float32(3.0)], list[np.float32]),
879+
([1.0, 2.0, 3.0], list[cocoindex.Float32]),
880+
([1.0, 2.0, 3.0], list[float]),
881+
)
882+
validate_full_roundtrip(
883+
value_f32,
884+
np.typing.NDArray[np.float32],
885+
([np.float32(1.0), np.float32(2.0), np.float32(3.0)], list[np.float32]),
886+
([1.0, 2.0, 3.0], list[cocoindex.Float32]),
887+
([1.0, 2.0, 3.0], list[float]),
888+
)
889+
validate_full_roundtrip(
890+
value_f32.tolist(),
891+
list[np.float32],
892+
(value_f32, Vector[np.float32, Literal[3]]),
893+
([1.0, 2.0, 3.0], list[cocoindex.Float32]),
894+
([1.0, 2.0, 3.0], list[float]),
895+
)
896+
897+
value_f64 = np.array([1.0, 2.0, 3.0], dtype=np.float64)
898+
validate_full_roundtrip(
899+
value_f64,
900+
Vector[np.float64, Literal[3]],
901+
([np.float64(1.0), np.float64(2.0), np.float64(3.0)], list[np.float64]),
902+
([1.0, 2.0, 3.0], list[cocoindex.Float64]),
903+
([1.0, 2.0, 3.0], list[float]),
854904
)
855-
validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
856-
value_f64: Vector[np.float64, Literal[3]] = np.array(
857-
[1.0, 2.0, 3.0], dtype=np.float64
905+
906+
value_i64 = np.array([1, 2, 3], dtype=np.int64)
907+
validate_full_roundtrip(
908+
value_i64,
909+
Vector[np.int64, Literal[3]],
910+
([np.int64(1), np.int64(2), np.int64(3)], list[np.int64]),
911+
([1, 2, 3], list[int]),
858912
)
859-
validate_full_roundtrip(value_f64, Vector[np.float64, Literal[3]])
860-
value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
861-
validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
862-
value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
913+
914+
value_i32 = np.array([1, 2, 3], dtype=np.int32)
863915
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
864916
validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
865-
value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
917+
value_u8 = np.array([1, 2, 3], dtype=np.uint8)
866918
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
867919
validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
868-
value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
920+
value_u16 = np.array([1, 2, 3], dtype=np.uint16)
869921
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
870922
validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
871-
value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
923+
value_u32 = np.array([1, 2, 3], dtype=np.uint32)
872924
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
873925
validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
874-
value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
926+
value_u64 = np.array([1, 2, 3], dtype=np.uint64)
875927
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
876928
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
877929

878930

879931
def test_roundtrip_vector_no_dimension() -> None:
880932
"""Test full roundtrip for vector types without dimension annotation."""
881-
value_f64: Vector[np.float64] = np.array([1.0, 2.0, 3.0], dtype=np.float64)
882-
validate_full_roundtrip(value_f64, Vector[np.float64])
933+
value_f64 = np.array([1.0, 2.0, 3.0], dtype=np.float64)
934+
validate_full_roundtrip(
935+
value_f64,
936+
Vector[np.float64],
937+
([1.0, 2.0, 3.0], list[float]),
938+
(np.array([1.0, 2.0, 3.0], dtype=np.float64), np.typing.NDArray[np.float64]),
939+
)
883940

884941

885942
def test_roundtrip_string_vector() -> None:
@@ -904,9 +961,9 @@ def test_roundtrip_dimension_mismatch() -> None:
904961
def test_full_roundtrip_scalar_numeric_types() -> None:
905962
"""Test full roundtrip for scalar NumPy numeric types."""
906963
# Test supported scalar types
907-
validate_full_roundtrip(np.int64(42), np.int64)
908-
validate_full_roundtrip(np.float32(3.14), np.float32)
909-
validate_full_roundtrip(np.float64(2.718), np.float64)
964+
validate_full_roundtrip(np.int64(42), np.int64, (42, int))
965+
validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, cocoindex.Float32))
966+
validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, cocoindex.Float64))
910967

911968
# Test unsupported scalar types
912969
for unsupported_type in [np.int32, np.uint8, np.uint16, np.uint32, np.uint64]:

0 commit comments

Comments
 (0)