From 246cdd2c99e64f08ef4b2268b3edefff8de7a871 Mon Sep 17 00:00:00 2001 From: lemorage Date: Sat, 7 Jun 2025 11:43:22 +0200 Subject: [PATCH 1/4] feat: handle numpy array vector in Python conversion --- Cargo.lock | 86 ++++++++++++++++++++++++++++------- Cargo.toml | 7 +-- src/py/convert.rs | 111 +++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 175 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab77b1beb..94648c7ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1073,6 +1073,7 @@ dependencies = [ "json5", "log", "neo4rs", + "numpy", "owo-colors", "pgvector", "phf", @@ -2664,6 +2665,16 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.6" @@ -2731,6 +2742,21 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "neo4rs" version = "0.8.0" @@ -2796,6 +2822,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2851,6 +2886,22 @@ dependencies = [ "libc", ] +[[package]] +name = "numpy" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29f1dee9aa8d3f6f8e8b9af3803006101bb3653866ef056d530d53ae68587191" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "pyo3-build-config", + "rustc-hash 2.1.1", +] + [[package]] name = "object" version = "0.36.7" @@ -3228,11 +3279,10 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" +checksum = "f239d656363bcee73afef85277f1b281e8ac6212a1d42aa90e55b90ed43c47a4" dependencies = [ - "cfg-if", "chrono", "indoc", "libc", @@ -3247,9 +3297,9 @@ dependencies = [ [[package]] name = "pyo3-async-runtimes" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd0b83dc42f9d41f50d38180dad65f0c99763b65a3ff2a81bf351dd35a1df8bf" +checksum = "d73cc6b1b7d8b3cef02101d37390dbdfe7e450dfea14921cae80a9534ba59ef2" dependencies = [ "futures", "once_cell", @@ -3260,9 +3310,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" +checksum = "755ea671a1c34044fa165247aaf6f419ca39caa6003aee791a0df2713d8f1b6d" dependencies = [ "once_cell", "target-lexicon", @@ -3270,9 +3320,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" +checksum = "fc95a2e67091e44791d4ea300ff744be5293f394f1bafd9f78c080814d35956e" dependencies = [ "libc", "pyo3-build-config", @@ -3280,9 +3330,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" +checksum = "a179641d1b93920829a62f15e87c0ed791b6c8db2271ba0fd7c2686090510214" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -3292,9 +3342,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" +checksum = "9dff85ebcaab8c441b0e3f7ae40a6963ecea8a9f5e74f647e33fcf5ec9a1e89e" dependencies = [ "heck", "proc-macro2", @@ -3305,9 +3355,9 @@ dependencies = [ [[package]] name = "pythonize" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5bcac0d0b71821f0d69e42654f1e15e5c94b85196446c4de9588951a2117e7b" +checksum = "597907139a488b22573158793aa7539df36ae863eba300c75f3a0d65fc475e27" dependencies = [ "pyo3", "serde", @@ -3462,6 +3512,12 @@ dependencies = [ "getrandom 0.3.2", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.5.11" diff --git a/Cargo.toml b/Cargo.toml index 6841f630b..a86e97268 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,9 +14,9 @@ name = "cocoindex_engine" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.24.1", features = ["chrono"] } -pythonize = "0.24.0" -pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] } +pyo3 = { version = "0.25.0", features = ["chrono"] } +pythonize = "0.25.0" +pyo3-async-runtimes = { version = "0.25.0", features = ["tokio-runtime"] } anyhow = { version = "1.0.97", features = ["std"] } async-trait = "0.1.88" @@ -113,3 +113,4 @@ json5 = "0.4.1" aws-config = "1.6.2" aws-sdk-s3 = "1.85.0" aws-sdk-sqs = "1.67.0" +numpy = "0.25.0" diff --git a/src/py/convert.rs b/src/py/convert.rs index a09d3d483..1d4b5e2a2 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -1,5 +1,7 @@ use bytes::Bytes; +use numpy::{PyArray1, PyArrayDyn, PyArrayMethods}; use pyo3::IntoPyObjectExt; +use pyo3::types::PyAny; use pyo3::types::{PyList, PyTuple}; use pyo3::{exceptions::PyException, prelude::*}; use pythonize::{depythonize, pythonize}; @@ -72,11 +74,7 @@ fn basic_value_to_py_object<'py>( value::BasicValue::OffsetDateTime(v) => v.into_bound_py_any(py)?, value::BasicValue::TimeDelta(v) => v.into_bound_py_any(py)?, value::BasicValue::Json(v) => pythonize(py, v).into_py_result()?, - value::BasicValue::Vector(v) => v - .iter() - .map(|v| basic_value_to_py_object(py, v)) - .collect::>>()? - .into_bound_py_any(py)?, + value::BasicValue::Vector(v) => handle_vector_to_py(py, v)?, }; Ok(result) } @@ -150,16 +148,107 @@ fn basic_value_from_py_object<'py>( schema::BasicValueType::Json => { value::BasicValue::Json(Arc::from(depythonize::(v)?)) } - schema::BasicValueType::Vector(elem) => value::BasicValue::Vector(Arc::from( - v.extract::>>()? - .into_iter() - .map(|v| basic_value_from_py_object(&elem.element_type, &v)) - .collect::>>()?, - )), + schema::BasicValueType::Vector(elem) => { + if let Some(vector) = handle_ndarray_from_py(&elem.element_type, v)? { + vector + } else { + // Fallback to list + value::BasicValue::Vector(Arc::from( + v.extract::>>()? + .into_iter() + .map(|v| basic_value_from_py_object(&elem.element_type, &v)) + .collect::>>()?, + )) + } + } }; Ok(result) } +fn handle_ndarray_from_py<'py>( + elem_type: &schema::BasicValueType, + v: &Bound<'py, PyAny>, +) -> PyResult> { + macro_rules! try_convert { + ($t:ty, $cast:expr) => { + if let Ok(array) = v.downcast::>() { + let data = array.readonly().as_slice()?.to_vec(); + let vec = data.into_iter().map($cast).collect::>(); + return Ok(Some(value::BasicValue::Vector(Arc::from(vec)))); + } + }; + } + + match elem_type { + &schema::BasicValueType::Float32 => try_convert!(f32, value::BasicValue::Float32), + &schema::BasicValueType::Float64 => try_convert!(f64, value::BasicValue::Float64), + &schema::BasicValueType::Int64 => { + try_convert!(i32, |v| value::BasicValue::Int64(v as i64)); + try_convert!(i64, value::BasicValue::Int64); + try_convert!(u8, |v| value::BasicValue::Int64(v as i64)); + try_convert!(u16, |v| value::BasicValue::Int64(v as i64)); + try_convert!(u32, |v| value::BasicValue::Int64(v as i64)); + try_convert!(u64, |v| value::BasicValue::Int64(v as i64)); + } + _ => {} + } + + Ok(None) +} + +// Helper function to convert BasicValue::Vector to PyAny +fn handle_vector_to_py<'py>( + py: Python<'py>, + v: &[value::BasicValue], +) -> PyResult> { + match v.first() { + Some(value::BasicValue::Float32(_)) => { + let data = v + .iter() + .filter_map(|x| { + if let value::BasicValue::Float32(f) = x { + Some(*f) + } else { + None + } + }) + .collect::>(); + Ok(PyArray1::from_vec(py, data).into_any()) + } + Some(value::BasicValue::Float64(_)) => { + let data = v + .iter() + .filter_map(|x| { + if let value::BasicValue::Float64(f) = x { + Some(*f) + } else { + None + } + }) + .collect::>(); + Ok(PyArray1::from_vec(py, data).into_any()) + } + Some(value::BasicValue::Int64(_)) => { + let data = v + .iter() + .filter_map(|x| { + if let value::BasicValue::Int64(i) = x { + Some(*i) + } else { + None + } + }) + .collect::>(); + Ok(PyArray1::from_vec(py, data).into_any()) + } + _ => Ok(v + .iter() + .map(|v| basic_value_to_py_object(py, v)) + .collect::>>()? + .into_bound_py_any(py)?), + } +} + fn field_values_from_py_object<'py>( schema: &schema::StructSchema, v: &Bound<'py, PyAny>, From 0f4e84302126118ed3007c8d29e930690cb8a0fe Mon Sep 17 00:00:00 2001 From: lemorage Date: Mon, 9 Jun 2025 05:02:41 +0200 Subject: [PATCH 2/4] test: add roundtrip tests for numeric and string vector types --- python/cocoindex/tests/test_convert.py | 64 +++++++++++++++++++++++++- python/cocoindex/typing.py | 20 +++++++- src/py/convert.rs | 1 + 3 files changed, 81 insertions(+), 4 deletions(-) diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 32d5603f8..059c033c7 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -1,7 +1,7 @@ import uuid import datetime from dataclasses import dataclass, make_dataclass -from typing import NamedTuple, Literal, Any, Callable +from typing import NamedTuple, Literal, Any, Callable, Union import pytest import cocoindex from cocoindex.typing import ( @@ -91,7 +91,7 @@ def validate_full_roundtrip( decoded_value = build_engine_value_decoder(input_type or output_type, output_type)( value_from_engine ) - assert decoded_value == value + np.testing.assert_array_equal(decoded_value, value) def test_encode_engine_value_basic_types(): @@ -540,6 +540,11 @@ def test_vector_as_list() -> None: Float64VectorType = Vector[np.float64, Literal[3]] Int64VectorType = Vector[np.int64, Literal[3]] Int32VectorType = Vector[np.int32, Literal[3]] +UInt8VectorType = Vector[np.uint8, Literal[3]] +UInt16VectorType = Vector[np.uint16, Literal[3]] +UInt32VectorType = Vector[np.uint32, Literal[3]] +UInt64VectorType = Vector[np.uint64, Literal[3]] +StrVectorType = Vector[str] NDArrayFloat32Type = NDArray[np.float32] NDArrayFloat64Type = NDArray[np.float64] NDArrayInt64Type = NDArray[np.int64] @@ -765,3 +770,58 @@ def test_dump_vector_type_annotation_no_dim(): } } assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim + + +def test_roundtrip_vector_numeric_types() -> None: + """Test full roundtrip for numeric vector types using NDArray.""" + value_f32: Vector[np.float32, Literal[3]] = np.array( + [1.0, 2.0, 3.0], dtype=np.float32 + ) + validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]]) + value_f64: Vector[np.float64, Literal[3]] = np.array( + [1.0, 2.0, 3.0], dtype=np.float64 + ) + validate_full_roundtrip(value_f64, Vector[np.float64, Literal[3]]) + value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32) + validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]]) + value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64) + validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]]) + value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8) + validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]]) + value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16) + validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]]) + value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32) + validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]]) + value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64) + validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]]) + + +def test_roundtrip_vector_no_dimension() -> None: + """Test full roundtrip for vector types without dimension annotation.""" + value_f64: Vector[np.float64] = np.array([1.0, 2.0, 3.0], dtype=np.float64) + validate_full_roundtrip(value_f64, Vector[np.float64]) + + +def test_roundtrip_string_vector() -> None: + """Test full roundtrip for string vector using list.""" + value_str: Vector[str] = ["hello", "world"] + validate_full_roundtrip(value_str, Vector[str]) + + +def test_roundtrip_empty_vector() -> None: + """Test full roundtrip for empty numeric vector.""" + value_empty: Vector[np.float32] = np.array([], dtype=np.float32) + validate_full_roundtrip(value_empty, Vector[np.float32]) + + +def test_roundtrip_dimension_mismatch() -> None: + """Test that dimension mismatch raises an error during roundtrip.""" + value_f32: Vector[np.float32, Literal[3]] = np.array([1.0, 2.0], dtype=np.float32) + with pytest.raises(ValueError, match="Vector dimension mismatch"): + validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]]) + + +def test_roundtrip_list_backward_compatibility() -> None: + """Test full roundtrip for list-based vectors for backward compatibility.""" + value_list: list[int] = [1, 2, 3] + validate_full_roundtrip(value_list, list[int]) diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 9ed52ed1c..11769dc16 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -111,6 +111,11 @@ def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None: class DtypeRegistry: + """ + Registry for NumPy dtypes used in CocoIndex. + Provides mappings from NumPy dtypes to CocoIndex's type representation. + """ + _mappings: dict[type, DtypeInfo] = { np.float32: DtypeInfo(np.float32, "Float32", float), np.float64: DtypeInfo(np.float64, "Float64", float), @@ -124,6 +129,7 @@ class DtypeRegistry: @classmethod def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None: + """Get DtypeInfo by NumPy dtype.""" if dtype is Any: raise TypeError( "NDArray for Vector must use a concrete numpy dtype, got `Any`." @@ -132,13 +138,21 @@ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None: @staticmethod def get_by_kind(kind: str) -> DtypeInfo | None: + """Get DtypeInfo by kind.""" return next( (info for info in DtypeRegistry._mappings.values() if info.kind == kind), None, ) + @staticmethod + def rust_compatible_kind(kind: str) -> str: + """Map to a Rust-compatible kind for schema encoding.""" + # incompatible_integer_kinds = {"Int32", "UInt8", "UInt16", "UInt32", "UInt64"} + return "Int64" if "Int" in kind else kind + @staticmethod def supported_dtypes() -> KeysView[type]: + """Get a list of supported NumPy dtypes.""" return DtypeRegistry._mappings.keys() @@ -340,8 +354,10 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: raise ValueError("Vector type must have a vector info") if type_info.elem_type is None: raise ValueError("Vector type must have an element type") - encoded_type["element_type"] = _encode_type( - analyze_type_info(type_info.elem_type) + elem_type_info = analyze_type_info(type_info.elem_type) + encoded_type["element_type"] = _encode_type(elem_type_info) + encoded_type["element_type"]["kind"] = DtypeRegistry.rust_compatible_kind( + elem_type_info.kind ) encoded_type["dimension"] = type_info.vector_info.dim diff --git a/src/py/convert.rs b/src/py/convert.rs index 1d4b5e2a2..ac416cbf7 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -165,6 +165,7 @@ fn basic_value_from_py_object<'py>( Ok(result) } +// Helper function to convert PyAny to BasicValue for NDArray fn handle_ndarray_from_py<'py>( elem_type: &schema::BasicValueType, v: &Bound<'py, PyAny>, From 76d1304dd6d84f1bd7d72b264be47291228459ad Mon Sep 17 00:00:00 2001 From: lemorage Date: Mon, 9 Jun 2025 15:36:30 +0200 Subject: [PATCH 3/4] feat: numeric type uint64 is unsupported --- python/cocoindex/tests/test_convert.py | 14 ++------- python/cocoindex/typing.py | 1 - src/py/convert.rs | 43 +++++++++++++------------- 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 059c033c7..3e8cb52f5 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -640,15 +640,6 @@ def test_uint_support(): decoded = decoder(encoded) assert np.array_equal(decoded, value_uint32) assert decoded.dtype == np.uint32 - value_uint64 = np.array([1, 2, 3], dtype=np.uint64) - encoded = encode_engine_value(value_uint64) - assert np.array_equal(encoded, [1, 2, 3]) - decoder = make_engine_value_decoder( - [], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint64] - ) - decoded = decoder(encoded) - assert np.array_equal(decoded, value_uint64) - assert decoded.dtype == np.uint64 def test_ndarray_dimension_mismatch(): @@ -772,7 +763,7 @@ def test_dump_vector_type_annotation_no_dim(): assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim -def test_roundtrip_vector_numeric_types() -> None: +def test_full_roundtrip_vector_numeric_types() -> None: """Test full roundtrip for numeric vector types using NDArray.""" value_f32: Vector[np.float32, Literal[3]] = np.array( [1.0, 2.0, 3.0], dtype=np.float32 @@ -793,7 +784,8 @@ def test_roundtrip_vector_numeric_types() -> None: value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32) validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]]) value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64) - validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]]) + with pytest.raises(ValueError, match="type unsupported yet"): + validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]]) def test_roundtrip_vector_no_dimension() -> None: diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 11769dc16..f6ea75fa1 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -124,7 +124,6 @@ class DtypeRegistry: np.uint8: DtypeInfo(np.uint8, "UInt8", int), np.uint16: DtypeInfo(np.uint16, "UInt16", int), np.uint32: DtypeInfo(np.uint32, "UInt32", int), - np.uint64: DtypeInfo(np.uint64, "UInt64", int), } @classmethod diff --git a/src/py/convert.rs b/src/py/convert.rs index ac416cbf7..0c9d1cb51 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -189,7 +189,6 @@ fn handle_ndarray_from_py<'py>( try_convert!(u8, |v| value::BasicValue::Int64(v as i64)); try_convert!(u16, |v| value::BasicValue::Int64(v as i64)); try_convert!(u32, |v| value::BasicValue::Int64(v as i64)); - try_convert!(u64, |v| value::BasicValue::Int64(v as i64)); } _ => {} } @@ -206,40 +205,40 @@ fn handle_vector_to_py<'py>( Some(value::BasicValue::Float32(_)) => { let data = v .iter() - .filter_map(|x| { - if let value::BasicValue::Float32(f) = x { - Some(*f) - } else { - None - } + .map(|x| match x { + value::BasicValue::Float32(f) => Ok(*f), + _ => Err(PyErr::new::( + "Expected all elements to be Float32", + )), }) - .collect::>(); + .collect::>>()?; + Ok(PyArray1::from_vec(py, data).into_any()) } Some(value::BasicValue::Float64(_)) => { let data = v .iter() - .filter_map(|x| { - if let value::BasicValue::Float64(f) = x { - Some(*f) - } else { - None - } + .map(|x| match x { + value::BasicValue::Float64(f) => Ok(*f), + _ => Err(PyErr::new::( + "Expected all elements to be Float64", + )), }) - .collect::>(); + .collect::>>()?; + Ok(PyArray1::from_vec(py, data).into_any()) } Some(value::BasicValue::Int64(_)) => { let data = v .iter() - .filter_map(|x| { - if let value::BasicValue::Int64(i) = x { - Some(*i) - } else { - None - } + .map(|x| match x { + value::BasicValue::Int64(i) => Ok(*i), + _ => Err(PyErr::new::( + "Expected all elements to be Int64", + )), }) - .collect::>(); + .collect::>>()?; + Ok(PyArray1::from_vec(py, data).into_any()) } _ => Ok(v From 98026b527b65cdaf34f0ccb7be635be1625cae8a Mon Sep 17 00:00:00 2001 From: lemorage Date: Wed, 11 Jun 2025 16:09:52 +0200 Subject: [PATCH 4/4] feat: support dtype decoding by adding `np_number_type` to `AnalyzedTypeInfo` --- python/cocoindex/convert.py | 3 +-- python/cocoindex/tests/test_typing.py | 28 ++++++++++++++++++-- python/cocoindex/typing.py | 37 ++++++++++----------------- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 7527034e8..87600ddd9 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -127,8 +127,7 @@ def decode(value: Any) -> Any | None: return lambda value: uuid.UUID(bytes=value) if src_type_kind == "Vector": - elem_coco_type_info = analyze_type_info(dst_type_info.elem_type) - dtype_info = DtypeRegistry.get_by_kind(elem_coco_type_info.kind) + dtype_info = DtypeRegistry.get_by_dtype(dst_type_info.np_number_type) def decode_vector(value: Any) -> Any | None: if value is None: diff --git a/python/cocoindex/tests/test_typing.py b/python/cocoindex/tests/test_typing.py index 1714a7c7b..b2432d3f7 100644 --- a/python/cocoindex/tests/test_typing.py +++ b/python/cocoindex/tests/test_typing.py @@ -48,6 +48,7 @@ def test_ndarray_float32_no_dim(): elem_type=Float32, key_type=None, struct_type=None, + np_number_type=np.float32, attrs=None, nullable=False, ) @@ -62,6 +63,7 @@ def test_vector_float32_no_dim(): elem_type=Float32, key_type=None, struct_type=None, + np_number_type=np.float32, attrs=None, nullable=False, ) @@ -76,6 +78,7 @@ def test_ndarray_float64_with_dim(): elem_type=Float64, key_type=None, struct_type=None, + np_number_type=np.float64, attrs=None, nullable=False, ) @@ -90,6 +93,7 @@ def test_vector_float32_with_dim(): elem_type=Float32, key_type=None, struct_type=None, + np_number_type=np.float32, attrs=None, nullable=False, ) @@ -109,7 +113,7 @@ def test_ndarray_int32_with_dim(): result = analyze_type_info(typ) assert result.kind == "Vector" assert result.vector_info == VectorInfo(dim=10) - assert get_args(result.elem_type) == (int, TypeKind("Int32")) + assert get_args(result.elem_type) == (int, TypeKind("Int64")) assert not result.nullable @@ -118,7 +122,7 @@ def test_ndarray_uint8_no_dim(): result = analyze_type_info(typ) assert result.kind == "Vector" assert result.vector_info == VectorInfo(dim=None) - assert get_args(result.elem_type) == (int, TypeKind("UInt8")) + assert get_args(result.elem_type) == (int, TypeKind("Int64")) assert not result.nullable @@ -131,6 +135,7 @@ def test_nullable_ndarray(): elem_type=Float32, key_type=None, struct_type=None, + np_number_type=np.float32, attrs=None, nullable=True, ) @@ -177,6 +182,7 @@ def test_list_of_primitives(): elem_type=str, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -191,6 +197,7 @@ def test_list_of_structs(): elem_type=SimpleDataclass, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -205,6 +212,7 @@ def test_sequence_of_int(): elem_type=int, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -219,6 +227,7 @@ def test_list_with_vector_info(): elem_type=int, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -233,6 +242,7 @@ def test_dict_str_int(): elem_type=(str, int), key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -247,6 +257,7 @@ def test_mapping_str_dataclass(): elem_type=(str, SimpleDataclass), key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -261,6 +272,7 @@ def test_dataclass(): elem_type=None, key_type=None, struct_type=SimpleDataclass, + np_number_type=None, attrs=None, nullable=False, ) @@ -275,6 +287,7 @@ def test_named_tuple(): elem_type=None, key_type=None, struct_type=SimpleNamedTuple, + np_number_type=None, attrs=None, nullable=False, ) @@ -289,6 +302,7 @@ def test_tuple_key_value(): elem_type=None, key_type=str, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -303,6 +317,7 @@ def test_str(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -317,6 +332,7 @@ def test_bool(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -331,6 +347,7 @@ def test_bytes(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -345,6 +362,7 @@ def test_uuid(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -359,6 +377,7 @@ def test_date(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -373,6 +392,7 @@ def test_time(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -387,6 +407,7 @@ def test_timedelta(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -401,6 +422,7 @@ def test_float(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -415,6 +437,7 @@ def test_int(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs=None, nullable=False, ) @@ -429,6 +452,7 @@ def test_type_with_attributes(): elem_type=None, key_type=None, struct_type=None, + np_number_type=None, attrs={"key": "value"}, nullable=False, ) diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index f6ea75fa1..6e00682c2 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -119,11 +119,11 @@ class DtypeRegistry: _mappings: dict[type, DtypeInfo] = { np.float32: DtypeInfo(np.float32, "Float32", float), np.float64: DtypeInfo(np.float64, "Float64", float), - np.int32: DtypeInfo(np.int32, "Int32", int), + np.int32: DtypeInfo(np.int32, "Int64", int), np.int64: DtypeInfo(np.int64, "Int64", int), - np.uint8: DtypeInfo(np.uint8, "UInt8", int), - np.uint16: DtypeInfo(np.uint16, "UInt16", int), - np.uint32: DtypeInfo(np.uint32, "UInt32", int), + np.uint8: DtypeInfo(np.uint8, "Int64", int), + np.uint16: DtypeInfo(np.uint16, "Int64", int), + np.uint32: DtypeInfo(np.uint32, "Int64", int), } @classmethod @@ -135,20 +135,6 @@ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None: ) return cls._mappings.get(dtype) - @staticmethod - def get_by_kind(kind: str) -> DtypeInfo | None: - """Get DtypeInfo by kind.""" - return next( - (info for info in DtypeRegistry._mappings.values() if info.kind == kind), - None, - ) - - @staticmethod - def rust_compatible_kind(kind: str) -> str: - """Map to a Rust-compatible kind for schema encoding.""" - # incompatible_integer_kinds = {"Int32", "UInt8", "UInt16", "UInt32", "UInt64"} - return "Int64" if "Int" in kind else kind - @staticmethod def supported_dtypes() -> KeysView[type]: """Get a list of supported NumPy dtypes.""" @@ -167,6 +153,9 @@ class AnalyzedTypeInfo: key_type: type | None # For element of KTable struct_type: type | None # For Struct, a dataclass or namedtuple + np_number_type: ( + type | None + ) # NumPy dtype for the element type, if represented by numpy.ndarray or a NumPy scalar attrs: dict[str, Any] | None nullable: bool = False @@ -221,6 +210,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: struct_type: type | None = None elem_type: ElementType | None = None key_type: type | None = None + np_number_type: type | None = None if _is_struct_type(t): struct_type = t @@ -254,11 +244,11 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: if not dtype_args: raise ValueError("Invalid dtype specification for NDArray") - numpy_dtype = dtype_args[0] - dtype_info = DtypeRegistry.get_by_dtype(numpy_dtype) + np_number_type = dtype_args[0] + dtype_info = DtypeRegistry.get_by_dtype(np_number_type) if dtype_info is None: raise ValueError( - f"Unsupported numpy dtype for NDArray: {numpy_dtype}. " + f"Unsupported numpy dtype for NDArray: {np_number_type}. " f"Supported dtypes: {DtypeRegistry.supported_dtypes()}" ) elem_type = dtype_info.annotated_type @@ -272,6 +262,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: dtype_info = DtypeRegistry.get_by_dtype(t) if dtype_info is not None: kind = dtype_info.kind + np_number_type = dtype_info.numpy_dtype elif t is bytes: kind = "Bytes" elif t is str: @@ -301,6 +292,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: elem_type=elem_type, key_type=key_type, struct_type=struct_type, + np_number_type=np_number_type, attrs=attrs, nullable=nullable, ) @@ -355,9 +347,6 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: raise ValueError("Vector type must have an element type") elem_type_info = analyze_type_info(type_info.elem_type) encoded_type["element_type"] = _encode_type(elem_type_info) - encoded_type["element_type"]["kind"] = DtypeRegistry.rust_compatible_kind( - elem_type_info.kind - ) encoded_type["dimension"] = type_info.vector_info.dim elif type_info.kind in TABLE_TYPES: