Skip to content

Commit 81b0ab2

Browse files
authored
chore(types): consolidate binding supports for various basic types (#692)
1 parent 2cef6a6 commit 81b0ab2

File tree

4 files changed

+100
-25
lines changed

4 files changed

+100
-25
lines changed

python/cocoindex/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,16 @@
1717
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
1818
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
1919
from .setting import get_app_namespace
20-
from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json
20+
from .typing import (
21+
Int64,
22+
Float32,
23+
Float64,
24+
LocalDateTime,
25+
OffsetDateTime,
26+
Range,
27+
Vector,
28+
Json,
29+
)
2130

2231
__all__ = [
2332
# Submodules
@@ -64,6 +73,7 @@
6473
"ServerSettings",
6574
"get_app_namespace",
6675
# Typing
76+
"Int64",
6777
"Float32",
6878
"Float64",
6979
"LocalDateTime",

python/cocoindex/tests/test_convert.py

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,15 @@ def build_engine_value_decoder(
7878
return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
7979

8080

81-
def validate_full_roundtrip(
81+
def validate_full_roundtrip_to(
8282
value: Any,
83-
value_type: Any = None,
84-
*other_decoded_values: tuple[Any, Any],
83+
value_type: Any,
84+
*decoded_values: tuple[Any, Any],
8585
) -> None:
8686
"""
87-
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
87+
Validate the given value becomes specific values after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
8888
89-
`other_decoded_values` is a tuple of (value, type) pairs.
90-
If provided, also validate the value can be decoded to the other types.
89+
`decoded_values` is a tuple of (value, type) pairs.
9190
"""
9291
from cocoindex import _engine # type: ignore
9392

@@ -102,15 +101,27 @@ def eq(a: Any, b: Any) -> bool:
102101
value_from_engine = _engine.testutil.seder_roundtrip(
103102
encoded_value, encoded_output_type
104103
)
105-
decoder = make_engine_value_decoder([], encoded_output_type, value_type)
106-
decoded_value = decoder(value_from_engine)
107-
assert eq(decoded_value, value), f"{decoded_value} != {value}"
108104

109-
if other_decoded_values is not None:
110-
for other_value, other_type in other_decoded_values:
111-
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
112-
other_decoded_value = decoder(value_from_engine)
113-
assert eq(other_decoded_value, other_value)
105+
for other_value, other_type in decoded_values:
106+
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
107+
other_decoded_value = decoder(value_from_engine)
108+
assert eq(other_decoded_value, other_value)
109+
110+
111+
def validate_full_roundtrip(
112+
value: Any,
113+
value_type: Any,
114+
*other_decoded_values: tuple[Any, Any],
115+
) -> None:
116+
"""
117+
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
118+
119+
`other_decoded_values` is a tuple of (value, type) pairs.
120+
If provided, also validate the value can be decoded to the other types.
121+
"""
122+
validate_full_roundtrip_to(
123+
value, value_type, (value, value_type), *other_decoded_values
124+
)
114125

115126

116127
def test_encode_engine_value_basic_types() -> None:
@@ -218,17 +229,33 @@ def test_encode_engine_value_none() -> None:
218229

219230

220231
def test_roundtrip_basic_types() -> None:
221-
validate_full_roundtrip(42, int, (42, None))
222-
validate_full_roundtrip(3.25, float, (3.25, Float64))
223232
validate_full_roundtrip(
224-
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64)
233+
42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, None)
225234
)
235+
validate_full_roundtrip(42, int, (42, cocoindex.Int64))
236+
validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64))
237+
226238
validate_full_roundtrip(
227-
3.25, Float32, (3.25, float), (np.float32(3.25), np.float32)
239+
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, None)
228240
)
241+
validate_full_roundtrip(3.25, float, (3.25, Float64))
242+
validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64))
243+
244+
validate_full_roundtrip(
245+
3.25,
246+
Float32,
247+
(3.25, float),
248+
(np.float32(3.25), np.float32),
249+
(np.float64(3.25), np.float64),
250+
(3.25, Float64),
251+
(3.25, None),
252+
)
253+
validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))
254+
229255
validate_full_roundtrip("hello", str, ("hello", None))
230256
validate_full_roundtrip(True, bool, (True, None))
231257
validate_full_roundtrip(False, bool, (False, None))
258+
validate_full_roundtrip((1, 2), cocoindex.Range, ((1, 2), None))
232259
validate_full_roundtrip(
233260
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None)
234261
)
@@ -238,14 +265,37 @@ def test_roundtrip_basic_types() -> None:
238265
cocoindex.LocalDateTime,
239266
(datetime.datetime(2025, 1, 2, 3, 4, 5, 123456), datetime.datetime),
240267
)
268+
269+
tz = datetime.timezone(datetime.timedelta(hours=5))
241270
validate_full_roundtrip(
242-
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
271+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz),
272+
cocoindex.OffsetDateTime,
273+
(
274+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz),
275+
datetime.datetime,
276+
),
277+
)
278+
validate_full_roundtrip(
279+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz),
280+
datetime.datetime,
281+
(datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, tz), cocoindex.OffsetDateTime),
282+
)
283+
validate_full_roundtrip_to(
284+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456),
243285
cocoindex.OffsetDateTime,
244286
(
245287
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
246288
datetime.datetime,
247289
),
248290
)
291+
validate_full_roundtrip_to(
292+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456),
293+
datetime.datetime,
294+
(
295+
datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC),
296+
cocoindex.OffsetDateTime,
297+
),
298+
)
249299

250300
uuid_value = uuid.uuid4()
251301
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None))

python/cocoindex/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self, key: str, value: Any):
4040

4141
Annotation = TypeKind | TypeAttr | VectorInfo
4242

43+
Int64 = Annotated[int, TypeKind("Int64")]
4344
Float32 = Annotated[float, TypeKind("Float32")]
4445
Float64 = Annotated[float, TypeKind("Float64")]
4546
Range = Annotated[tuple[int, int], TypeKind("Range")]

src/py/convert.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::prelude::*;
2+
13
use bytes::Bytes;
24
use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
35
use pyo3::IntoPyObjectExt;
@@ -6,14 +8,10 @@ use pyo3::types::PyAny;
68
use pyo3::types::{PyList, PyTuple};
79
use pyo3::{exceptions::PyException, prelude::*};
810
use pythonize::{depythonize, pythonize};
9-
use serde::Serialize;
1011
use serde::de::DeserializeOwned;
11-
use std::collections::BTreeMap;
1212
use std::ops::Deref;
13-
use std::sync::Arc;
1413

1514
use super::IntoPyResult;
16-
use crate::base::{schema, value};
1715

1816
#[derive(Debug)]
1917
pub struct Pythonized<T>(pub T);
@@ -143,7 +141,23 @@ fn basic_value_from_py_object<'py>(
143141
value::BasicValue::LocalDateTime(v.extract::<chrono::NaiveDateTime>()?)
144142
}
145143
schema::BasicValueType::OffsetDateTime => {
146-
value::BasicValue::OffsetDateTime(v.extract::<chrono::DateTime<chrono::FixedOffset>>()?)
144+
if v.getattr_opt("tzinfo")?
145+
.ok_or_else(|| {
146+
PyErr::new::<PyTypeError, _>(format!(
147+
"expecting a datetime.datetime value, got {}",
148+
v.get_type()
149+
))
150+
})?
151+
.is_none()
152+
{
153+
value::BasicValue::OffsetDateTime(
154+
v.extract::<chrono::NaiveDateTime>()?.and_utc().into(),
155+
)
156+
} else {
157+
value::BasicValue::OffsetDateTime(
158+
v.extract::<chrono::DateTime<chrono::FixedOffset>>()?,
159+
)
160+
}
147161
}
148162
schema::BasicValueType::TimeDelta => {
149163
value::BasicValue::TimeDelta(v.extract::<chrono::TimeDelta>()?)

0 commit comments

Comments
 (0)