Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 44 additions & 17 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
import datetime
import inspect
import uuid
from enum import Enum
from typing import Any, Callable, Mapping, get_origin

import numpy as np

from enum import Enum
from typing import Any, Callable, get_origin, Mapping
from .typing import (
KEY_FIELD_NAME,
TABLE_TYPES,
DtypeRegistry,
analyze_type_info,
encode_enriched_type,
extract_ndarray_scalar_dtype,
is_namedtuple_type,
TABLE_TYPES,
KEY_FIELD_NAME,
DtypeRegistry,
)


Expand All @@ -29,6 +31,8 @@ def encode_engine_value(value: Any) -> Any:
]
if is_namedtuple_type(type(value)):
return [encode_engine_value(getattr(value, name)) for name in value._fields]
if isinstance(value, np.number):
return value.item()
if isinstance(value, np.ndarray):
return value
if isinstance(value, (list, tuple)):
Expand Down Expand Up @@ -86,6 +90,20 @@ def make_engine_value_decoder(
field_path, src_type["fields"], dst_type_info.struct_type
)

if dst_type_info.np_number_type is not None and src_type_kind != "Vector":
numpy_type = dst_type_info.np_number_type

def decode_numpy_scalar(value: Any) -> Any | None:
if value is None:
if dst_type_info.nullable:
return None
raise ValueError(
f"Received null for non-nullable scalar `{''.join(field_path)}`"
)
return numpy_type(value)

return decode_numpy_scalar

if src_type_kind in TABLE_TYPES:
field_path.append("[*]")
elem_type_info = analyze_type_info(dst_type_info.elem_type)
Expand Down Expand Up @@ -127,33 +145,42 @@ def decode(value: Any) -> Any | None:
return lambda value: uuid.UUID(bytes=value)

if src_type_kind == "Vector":
dtype_info = DtypeRegistry.get_by_dtype(dst_type_info.np_number_type)

def decode_vector(value: Any) -> Any | None:
field_path_str = "".join(field_path)
expected_dim = (
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
)

if value is None:
if dst_type_info.nullable:
return None
raise ValueError(
f"Received null for non-nullable vector `{''.join(field_path)}`"
f"Received null for non-nullable vector `{field_path_str}`"
)

if not isinstance(value, (np.ndarray, list)):
raise TypeError(
f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}"
f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
)
expected_dim = (
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
)
if expected_dim is not None and len(value) != expected_dim:
raise ValueError(
f"Vector dimension mismatch for `{''.join(field_path)}`: "
f"Vector dimension mismatch for `{field_path_str}`: "
f"expected {expected_dim}, got {len(value)}"
)

# Use NDArray for supported numeric dtypes, else return list
if dtype_info is not None:
return np.array(value, dtype=dtype_info.numpy_dtype)
return value
if dst_type_info.np_number_type is None: # for Non-NDArray vector
elem_decoder = make_engine_value_decoder(
field_path + ["[*]"],
src_type["element_type"],
dst_type_info.elem_type,
)
return [elem_decoder(v) for v in value]
else: # for NDArray vector
scalar_dtype = extract_ndarray_scalar_dtype(
dst_type_info.np_number_type
)
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
return np.array(value, dtype=scalar_dtype)

return decode_vector

Expand Down
184 changes: 158 additions & 26 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import uuid
import datetime
import uuid
from dataclasses import dataclass, make_dataclass
from typing import NamedTuple, Literal, Any, Callable, Union
from typing import Annotated, Any, Callable, Literal, NamedTuple

import numpy as np
import pytest
from numpy.typing import NDArray

import cocoindex
from cocoindex.typing import (
encode_enriched_type,
Vector,
Float32,
Float64,
)
from cocoindex.convert import (
dump_engine_object,
encode_engine_value,
make_engine_value_decoder,
dump_engine_object,
)
import numpy as np
from numpy.typing import NDArray
from cocoindex.typing import (
Float32,
Float64,
TypeKind,
Vector,
encode_enriched_type,
)


@dataclass
Expand Down Expand Up @@ -128,6 +131,19 @@ def test_encode_engine_value_date_time_types() -> None:
assert encode_engine_value(dt) == dt


def test_encode_scalar_numpy_values() -> None:
"""Test encoding scalar NumPy values to engine-compatible values."""
test_cases = [
(np.int64(42), 42),
(np.float32(3.14), pytest.approx(3.14)),
(np.float64(2.718), pytest.approx(2.718)),
]
for np_value, expected in test_cases:
encoded = encode_engine_value(np_value)
assert encoded == expected
assert isinstance(encoded, (int, float))


def test_encode_engine_value_struct() -> None:
order = Order(order_id="O123", name="mixed nuts", price=25.0)
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
Expand Down Expand Up @@ -213,6 +229,47 @@ def test_roundtrip_basic_types() -> None:
)


def test_decode_scalar_numpy_values() -> None:
test_cases = [
({"kind": "Int64"}, np.int64, 42, np.int64(42)),
({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
]
for src_type, dst_type, input_value, expected in test_cases:
decoder = make_engine_value_decoder(["field"], src_type, dst_type)
result = decoder(input_value)
assert isinstance(result, dst_type)
assert result == expected


def test_non_ndarray_vector_decoding() -> None:
# Test list[np.float64]
src_type = {
"kind": "Vector",
"element_type": {"kind": "Float64"},
"dimension": None,
}
dst_type_float = list[np.float64]
decoder = make_engine_value_decoder(["field"], src_type, dst_type_float)
input_numbers = [1.0, 2.0, 3.0]
result = decoder(input_numbers)
assert isinstance(result, list)
assert all(isinstance(x, np.float64) for x in result)
assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]

# Test list[Uuid]
src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
dst_type_uuid = list[uuid.UUID]
decoder = make_engine_value_decoder(["field"], src_type, dst_type_uuid)
uuid1 = uuid.uuid4()
uuid2 = uuid.uuid4()
input_bytes = [uuid1.bytes, uuid2.bytes]
result = decoder(input_bytes)
assert isinstance(result, list)
assert all(isinstance(x, uuid.UUID) for x in result)
assert result == [uuid1, uuid2]


@pytest.mark.parametrize(
"data_type, engine_val, expected",
[
Expand Down Expand Up @@ -565,12 +622,6 @@ def test_vector_as_list() -> None:
Float32VectorType = Vector[np.float32, Literal[3]]
Float64VectorType = Vector[np.float64, Literal[3]]
Int64VectorType = Vector[np.int64, Literal[3]]
Int32VectorType = Vector[np.int32, Literal[3]]
UInt8VectorType = Vector[np.uint8, Literal[3]]
UInt16VectorType = Vector[np.uint16, Literal[3]]
UInt32VectorType = Vector[np.uint32, Literal[3]]
UInt64VectorType = Vector[np.uint64, Literal[3]]
StrVectorType = Vector[str]
NDArrayFloat32Type = NDArray[np.float32]
NDArrayFloat64Type = NDArray[np.float64]
NDArrayInt64Type = NDArray[np.int64]
Expand Down Expand Up @@ -767,19 +818,19 @@ def test_full_roundtrip_vector_numeric_types() -> None:
value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
with pytest.raises(ValueError, match="type unsupported yet"):
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
with pytest.raises(ValueError, match="type unsupported yet"):
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
with pytest.raises(ValueError, match="type unsupported yet"):
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
with pytest.raises(ValueError, match="type unsupported yet"):
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
with pytest.raises(ValueError, match="type unsupported yet"):
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])


Expand Down Expand Up @@ -808,7 +859,88 @@ def test_roundtrip_dimension_mismatch() -> None:
validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])


def test_roundtrip_list_backward_compatibility() -> None:
"""Test full roundtrip for list-based vectors for backward compatibility."""
value_list: list[int] = [1, 2, 3]
validate_full_roundtrip(value_list, list[int])
def test_full_roundtrip_scalar_numeric_types() -> None:
"""Test full roundtrip for scalar NumPy numeric types."""
# Test supported scalar types
validate_full_roundtrip(np.int64(42), np.int64)
validate_full_roundtrip(np.float32(3.14), np.float32)
validate_full_roundtrip(np.float64(2.718), np.float64)

# Test unsupported scalar types
for unsupported_type in [np.int32, np.uint8, np.uint16, np.uint32, np.uint64]:
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
validate_full_roundtrip(unsupported_type(1), unsupported_type)


def test_full_roundtrip_nullable_scalar() -> None:
"""Test full roundtrip for nullable scalar NumPy types."""
# Test with non-null values
validate_full_roundtrip(np.int64(42), np.int64 | None)
validate_full_roundtrip(np.float32(3.14), np.float32 | None)
validate_full_roundtrip(np.float64(2.718), np.float64 | None)

# Test with None
validate_full_roundtrip(None, np.int64 | None)
validate_full_roundtrip(None, np.float32 | None)
validate_full_roundtrip(None, np.float64 | None)


def test_full_roundtrip_scalar_in_struct() -> None:
"""Test full roundtrip for scalar NumPy types in a dataclass."""

@dataclass
class NumericStruct:
int_field: np.int64
float32_field: np.float32
float64_field: np.float64

instance = NumericStruct(
int_field=np.int64(42),
float32_field=np.float32(3.14),
float64_field=np.float64(2.718),
)
validate_full_roundtrip(instance, NumericStruct)


def test_full_roundtrip_scalar_in_nested_struct() -> None:
"""Test full roundtrip for scalar NumPy types in a nested struct."""

@dataclass
class InnerStruct:
value: np.float64

@dataclass
class OuterStruct:
inner: InnerStruct
count: np.int64

instance = OuterStruct(
inner=InnerStruct(value=np.float64(2.718)),
count=np.int64(1),
)
validate_full_roundtrip(instance, OuterStruct)


def test_full_roundtrip_scalar_with_python_types() -> None:
"""Test full roundtrip for structs mixing NumPy and Python scalar types."""

@dataclass
class MixedStruct:
numpy_int: np.int64
python_int: int
numpy_float: np.float64
python_float: float
string: str
annotated_int: Annotated[np.int64, TypeKind("int")]
annotated_float: Float32

instance = MixedStruct(
numpy_int=np.int64(42),
python_int=43,
numpy_float=np.float64(2.718),
python_float=3.14,
string="hello, world",
annotated_int=np.int64(42),
annotated_float=2.0,
)
validate_full_roundtrip(instance, MixedStruct)
Loading
Loading