diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py index 567bde63..dd9ecc66 100644 --- a/python/cocoindex/__init__.py +++ b/python/cocoindex/__init__.py @@ -17,7 +17,16 @@ from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions from .setting import DatabaseConnectionSpec, Settings, ServerSettings from .setting import get_app_namespace -from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json +from .typing import ( + Int64, + Float32, + Float64, + LocalDateTime, + OffsetDateTime, + Range, + Vector, + Json, +) __all__ = [ # Submodules @@ -64,6 +73,7 @@ "ServerSettings", "get_app_namespace", # Typing + "Int64", "Float32", "Float64", "LocalDateTime", diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 00850416..b390b7bb 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -78,16 +78,15 @@ def build_engine_value_decoder( return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py) -def validate_full_roundtrip( +def validate_full_roundtrip_to( value: Any, - value_type: Any = None, - *other_decoded_values: tuple[Any, Any], + value_type: Any, + *decoded_values: tuple[Any, Any], ) -> None: """ - Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type). + Validate the given value becomes specific values after encoding, sending to engine (using output_type), receiving back and decoding (using input_type). - `other_decoded_values` is a tuple of (value, type) pairs. - If provided, also validate the value can be decoded to the other types. + `decoded_values` is a tuple of (value, type) pairs. """ from cocoindex import _engine # type: ignore @@ -102,15 +101,27 @@ def eq(a: Any, b: Any) -> bool: value_from_engine = _engine.testutil.seder_roundtrip( encoded_value, encoded_output_type ) - decoder = make_engine_value_decoder([], encoded_output_type, value_type) - decoded_value = decoder(value_from_engine) - assert eq(decoded_value, value), f"{decoded_value} != {value}" - if other_decoded_values is not None: - for other_value, other_type in other_decoded_values: - decoder = make_engine_value_decoder([], encoded_output_type, other_type) - other_decoded_value = decoder(value_from_engine) - assert eq(other_decoded_value, other_value) + for other_value, other_type in decoded_values: + decoder = make_engine_value_decoder([], encoded_output_type, other_type) + other_decoded_value = decoder(value_from_engine) + assert eq(other_decoded_value, other_value) + + +def validate_full_roundtrip( + value: Any, + value_type: Any, + *other_decoded_values: tuple[Any, Any], +) -> None: + """ + Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type). + + `other_decoded_values` is a tuple of (value, type) pairs. + If provided, also validate the value can be decoded to the other types. + """ + validate_full_roundtrip_to( + value, value_type, (value, value_type), *other_decoded_values + ) def test_encode_engine_value_basic_types() -> None: @@ -218,17 +229,33 @@ def test_encode_engine_value_none() -> None: def test_roundtrip_basic_types() -> None: - validate_full_roundtrip(42, int, (42, None)) - validate_full_roundtrip(3.25, float, (3.25, Float64)) validate_full_roundtrip( - 3.25, Float64, (3.25, float), (np.float64(3.25), np.float64) + 42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, None) ) + validate_full_roundtrip(42, int, (42, cocoindex.Int64)) + validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64)) + validate_full_roundtrip( - 3.25, Float32, (3.25, float), (np.float32(3.25), np.float32) + 3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, None) ) + validate_full_roundtrip(3.25, float, (3.25, Float64)) + validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64)) + + validate_full_roundtrip( + 3.25, + Float32, + (3.25, float), + (np.float32(3.25), np.float32), + (np.float64(3.25), np.float64), + (3.25, Float64), + (3.25, None), + ) + validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32)) + validate_full_roundtrip("hello", str, ("hello", None)) validate_full_roundtrip(True, bool, (True, None)) validate_full_roundtrip(False, bool, (False, None)) + validate_full_roundtrip((1, 2), cocoindex.Range, ((1, 2), None)) validate_full_roundtrip( datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None) ) @@ -238,14 +265,37 @@ def test_roundtrip_basic_types() -> None: cocoindex.LocalDateTime, (datetime.datetime(2025, 1, 2, 3, 4, 5, 123456), datetime.datetime), ) + + tz = datetime.timezone(datetime.timedelta(hours=5)) validate_full_roundtrip( - datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC), + datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz), + cocoindex.OffsetDateTime, + ( + datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz), + datetime.datetime, + ), + ) + validate_full_roundtrip( + datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz), + datetime.datetime, + (datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz), cocoindex.OffsetDateTime), + ) + validate_full_roundtrip_to( + datetime.datetime(2025, 1, 2, 3, 4, 5, 123456), cocoindex.OffsetDateTime, ( datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC), datetime.datetime, ), ) + validate_full_roundtrip_to( + datetime.datetime(2025, 1, 2, 3, 4, 5, 123456), + datetime.datetime, + ( + datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC), + cocoindex.OffsetDateTime, + ), + ) uuid_value = uuid.uuid4() validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None)) diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 86125594..25dcd414 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -40,6 +40,7 @@ def __init__(self, key: str, value: Any): Annotation = TypeKind | TypeAttr | VectorInfo +Int64 = Annotated[int, TypeKind("Int64")] Float32 = Annotated[float, TypeKind("Float32")] Float64 = Annotated[float, TypeKind("Float64")] Range = Annotated[tuple[int, int], TypeKind("Range")] diff --git a/src/py/convert.rs b/src/py/convert.rs index d7db53db..6dcbaad0 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -1,3 +1,5 @@ +use crate::prelude::*; + use bytes::Bytes; use numpy::{PyArray1, PyArrayDyn, PyArrayMethods}; use pyo3::IntoPyObjectExt; @@ -6,14 +8,10 @@ use pyo3::types::PyAny; use pyo3::types::{PyList, PyTuple}; use pyo3::{exceptions::PyException, prelude::*}; use pythonize::{depythonize, pythonize}; -use serde::Serialize; use serde::de::DeserializeOwned; -use std::collections::BTreeMap; use std::ops::Deref; -use std::sync::Arc; use super::IntoPyResult; -use crate::base::{schema, value}; #[derive(Debug)] pub struct Pythonized(pub T); @@ -143,7 +141,23 @@ fn basic_value_from_py_object<'py>( value::BasicValue::LocalDateTime(v.extract::()?) } schema::BasicValueType::OffsetDateTime => { - value::BasicValue::OffsetDateTime(v.extract::>()?) + if v.getattr_opt("tzinfo")? + .ok_or_else(|| { + PyErr::new::(format!( + "expecting a datetime.datetime value, got {}", + v.get_type() + )) + })? + .is_none() + { + value::BasicValue::OffsetDateTime( + v.extract::()?.and_utc().into(), + ) + } else { + value::BasicValue::OffsetDateTime( + v.extract::>()?, + ) + } } schema::BasicValueType::TimeDelta => { value::BasicValue::TimeDelta(v.extract::()?)