Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion python/cocoindex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
from .setting import get_app_namespace
from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json
from .typing import (
Int64,
Float32,
Float64,
LocalDateTime,
OffsetDateTime,
Range,
Vector,
Json,
)

__all__ = [
# Submodules
Expand Down Expand Up @@ -64,6 +73,7 @@
"ServerSettings",
"get_app_namespace",
# Typing
"Int64",
"Float32",
"Float64",
"LocalDateTime",
Expand Down
88 changes: 69 additions & 19 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,15 @@ def build_engine_value_decoder(
return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)


def validate_full_roundtrip(
def validate_full_roundtrip_to(
value: Any,
value_type: Any = None,
*other_decoded_values: tuple[Any, Any],
value_type: Any,
*decoded_values: tuple[Any, Any],
) -> None:
"""
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
Validate the given value becomes specific values after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).

`other_decoded_values` is a tuple of (value, type) pairs.
If provided, also validate the value can be decoded to the other types.
`decoded_values` is a tuple of (value, type) pairs.
"""
from cocoindex import _engine # type: ignore

Expand All @@ -102,15 +101,27 @@ def eq(a: Any, b: Any) -> bool:
value_from_engine = _engine.testutil.seder_roundtrip(
encoded_value, encoded_output_type
)
decoder = make_engine_value_decoder([], encoded_output_type, value_type)
decoded_value = decoder(value_from_engine)
assert eq(decoded_value, value), f"{decoded_value} != {value}"

if other_decoded_values is not None:
for other_value, other_type in other_decoded_values:
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
other_decoded_value = decoder(value_from_engine)
assert eq(other_decoded_value, other_value)
for other_value, other_type in decoded_values:
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
other_decoded_value = decoder(value_from_engine)
assert eq(other_decoded_value, other_value)


def validate_full_roundtrip(
value: Any,
value_type: Any,
*other_decoded_values: tuple[Any, Any],
) -> None:
"""
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).

`other_decoded_values` is a tuple of (value, type) pairs.
If provided, also validate the value can be decoded to the other types.
"""
validate_full_roundtrip_to(
value, value_type, (value, value_type), *other_decoded_values
)


def test_encode_engine_value_basic_types() -> None:
Expand Down Expand Up @@ -218,17 +229,33 @@ def test_encode_engine_value_none() -> None:


def test_roundtrip_basic_types() -> None:
validate_full_roundtrip(42, int, (42, None))
validate_full_roundtrip(3.25, float, (3.25, Float64))
validate_full_roundtrip(
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64)
42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, None)
)
validate_full_roundtrip(42, int, (42, cocoindex.Int64))
validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64))

validate_full_roundtrip(
3.25, Float32, (3.25, float), (np.float32(3.25), np.float32)
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, None)
)
validate_full_roundtrip(3.25, float, (3.25, Float64))
validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64))

validate_full_roundtrip(
3.25,
Float32,
(3.25, float),
(np.float32(3.25), np.float32),
(np.float64(3.25), np.float64),
(3.25, Float64),
(3.25, None),
)
validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))

validate_full_roundtrip("hello", str, ("hello", None))
validate_full_roundtrip(True, bool, (True, None))
validate_full_roundtrip(False, bool, (False, None))
validate_full_roundtrip((1, 2), cocoindex.Range, ((1, 2), None))
validate_full_roundtrip(
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None)
)
Expand All @@ -238,14 +265,37 @@ def test_roundtrip_basic_types() -> None:
cocoindex.LocalDateTime,
(datetime.datetime(2025, 1, 2, 3, 4, 5, 123456), datetime.datetime),
)

tz = datetime.timezone(datetime.timedelta(hours=5))
validate_full_roundtrip(
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz),
cocoindex.OffsetDateTime,
(
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz),
datetime.datetime,
),
)
validate_full_roundtrip(
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz),
datetime.datetime,
(datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz), cocoindex.OffsetDateTime),
)
validate_full_roundtrip_to(
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456),
cocoindex.OffsetDateTime,
(
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
datetime.datetime,
),
)
validate_full_roundtrip_to(
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456),
datetime.datetime,
(
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
cocoindex.OffsetDateTime,
),
)

uuid_value = uuid.uuid4()
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None))
Expand Down
1 change: 1 addition & 0 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, key: str, value: Any):

Annotation = TypeKind | TypeAttr | VectorInfo

Int64 = Annotated[int, TypeKind("Int64")]
Float32 = Annotated[float, TypeKind("Float32")]
Float64 = Annotated[float, TypeKind("Float64")]
Range = Annotated[tuple[int, int], TypeKind("Range")]
Expand Down
24 changes: 19 additions & 5 deletions src/py/convert.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::prelude::*;

use bytes::Bytes;
use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
use pyo3::IntoPyObjectExt;
Expand All @@ -6,14 +8,10 @@ use pyo3::types::PyAny;
use pyo3::types::{PyList, PyTuple};
use pyo3::{exceptions::PyException, prelude::*};
use pythonize::{depythonize, pythonize};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::collections::BTreeMap;
use std::ops::Deref;
use std::sync::Arc;

use super::IntoPyResult;
use crate::base::{schema, value};

#[derive(Debug)]
pub struct Pythonized<T>(pub T);
Expand Down Expand Up @@ -143,7 +141,23 @@ fn basic_value_from_py_object<'py>(
value::BasicValue::LocalDateTime(v.extract::<chrono::NaiveDateTime>()?)
}
schema::BasicValueType::OffsetDateTime => {
value::BasicValue::OffsetDateTime(v.extract::<chrono::DateTime<chrono::FixedOffset>>()?)
if v.getattr_opt("tzinfo")?
.ok_or_else(|| {
PyErr::new::<PyTypeError, _>(format!(
"expecting a datetime.datetime value, got {}",
v.get_type()
))
})?
.is_none()
{
value::BasicValue::OffsetDateTime(
v.extract::<chrono::NaiveDateTime>()?.and_utc().into(),
)
} else {
value::BasicValue::OffsetDateTime(
v.extract::<chrono::DateTime<chrono::FixedOffset>>()?,
)
}
}
schema::BasicValueType::TimeDelta => {
value::BasicValue::TimeDelta(v.extract::<chrono::TimeDelta>()?)
Expand Down
Loading