Skip to content

Commit 246cdd2

Browse files
committed
feat: handle numpy array vector in Python conversion
1 parent d935c17 commit 246cdd2

File tree

3 files changed

+175
-29
lines changed

3 files changed

+175
-29
lines changed

Cargo.lock

Lines changed: 71 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ name = "cocoindex_engine"
1414
crate-type = ["cdylib"]
1515

1616
[dependencies]
17-
pyo3 = { version = "0.24.1", features = ["chrono"] }
18-
pythonize = "0.24.0"
19-
pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] }
17+
pyo3 = { version = "0.25.0", features = ["chrono"] }
18+
pythonize = "0.25.0"
19+
pyo3-async-runtimes = { version = "0.25.0", features = ["tokio-runtime"] }
2020

2121
anyhow = { version = "1.0.97", features = ["std"] }
2222
async-trait = "0.1.88"
@@ -113,3 +113,4 @@ json5 = "0.4.1"
113113
aws-config = "1.6.2"
114114
aws-sdk-s3 = "1.85.0"
115115
aws-sdk-sqs = "1.67.0"
116+
numpy = "0.25.0"

src/py/convert.rs

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use bytes::Bytes;
2+
use numpy::{PyArray1, PyArrayDyn, PyArrayMethods};
23
use pyo3::IntoPyObjectExt;
4+
use pyo3::types::PyAny;
35
use pyo3::types::{PyList, PyTuple};
46
use pyo3::{exceptions::PyException, prelude::*};
57
use pythonize::{depythonize, pythonize};
@@ -72,11 +74,7 @@ fn basic_value_to_py_object<'py>(
7274
value::BasicValue::OffsetDateTime(v) => v.into_bound_py_any(py)?,
7375
value::BasicValue::TimeDelta(v) => v.into_bound_py_any(py)?,
7476
value::BasicValue::Json(v) => pythonize(py, v).into_py_result()?,
75-
value::BasicValue::Vector(v) => v
76-
.iter()
77-
.map(|v| basic_value_to_py_object(py, v))
78-
.collect::<PyResult<Vec<_>>>()?
79-
.into_bound_py_any(py)?,
77+
value::BasicValue::Vector(v) => handle_vector_to_py(py, v)?,
8078
};
8179
Ok(result)
8280
}
@@ -150,16 +148,107 @@ fn basic_value_from_py_object<'py>(
150148
schema::BasicValueType::Json => {
151149
value::BasicValue::Json(Arc::from(depythonize::<serde_json::Value>(v)?))
152150
}
153-
schema::BasicValueType::Vector(elem) => value::BasicValue::Vector(Arc::from(
154-
v.extract::<Vec<Bound<'py, PyAny>>>()?
155-
.into_iter()
156-
.map(|v| basic_value_from_py_object(&elem.element_type, &v))
157-
.collect::<PyResult<Vec<_>>>()?,
158-
)),
151+
schema::BasicValueType::Vector(elem) => {
152+
if let Some(vector) = handle_ndarray_from_py(&elem.element_type, v)? {
153+
vector
154+
} else {
155+
// Fallback to list
156+
value::BasicValue::Vector(Arc::from(
157+
v.extract::<Vec<Bound<'py, PyAny>>>()?
158+
.into_iter()
159+
.map(|v| basic_value_from_py_object(&elem.element_type, &v))
160+
.collect::<PyResult<Vec<_>>>()?,
161+
))
162+
}
163+
}
159164
};
160165
Ok(result)
161166
}
162167

168+
fn handle_ndarray_from_py<'py>(
169+
elem_type: &schema::BasicValueType,
170+
v: &Bound<'py, PyAny>,
171+
) -> PyResult<Option<value::BasicValue>> {
172+
macro_rules! try_convert {
173+
($t:ty, $cast:expr) => {
174+
if let Ok(array) = v.downcast::<PyArrayDyn<$t>>() {
175+
let data = array.readonly().as_slice()?.to_vec();
176+
let vec = data.into_iter().map($cast).collect::<Vec<_>>();
177+
return Ok(Some(value::BasicValue::Vector(Arc::from(vec))));
178+
}
179+
};
180+
}
181+
182+
match elem_type {
183+
&schema::BasicValueType::Float32 => try_convert!(f32, value::BasicValue::Float32),
184+
&schema::BasicValueType::Float64 => try_convert!(f64, value::BasicValue::Float64),
185+
&schema::BasicValueType::Int64 => {
186+
try_convert!(i32, |v| value::BasicValue::Int64(v as i64));
187+
try_convert!(i64, value::BasicValue::Int64);
188+
try_convert!(u8, |v| value::BasicValue::Int64(v as i64));
189+
try_convert!(u16, |v| value::BasicValue::Int64(v as i64));
190+
try_convert!(u32, |v| value::BasicValue::Int64(v as i64));
191+
try_convert!(u64, |v| value::BasicValue::Int64(v as i64));
192+
}
193+
_ => {}
194+
}
195+
196+
Ok(None)
197+
}
198+
199+
// Helper function to convert BasicValue::Vector to PyAny
200+
fn handle_vector_to_py<'py>(
201+
py: Python<'py>,
202+
v: &[value::BasicValue],
203+
) -> PyResult<Bound<'py, PyAny>> {
204+
match v.first() {
205+
Some(value::BasicValue::Float32(_)) => {
206+
let data = v
207+
.iter()
208+
.filter_map(|x| {
209+
if let value::BasicValue::Float32(f) = x {
210+
Some(*f)
211+
} else {
212+
None
213+
}
214+
})
215+
.collect::<Vec<_>>();
216+
Ok(PyArray1::from_vec(py, data).into_any())
217+
}
218+
Some(value::BasicValue::Float64(_)) => {
219+
let data = v
220+
.iter()
221+
.filter_map(|x| {
222+
if let value::BasicValue::Float64(f) = x {
223+
Some(*f)
224+
} else {
225+
None
226+
}
227+
})
228+
.collect::<Vec<_>>();
229+
Ok(PyArray1::from_vec(py, data).into_any())
230+
}
231+
Some(value::BasicValue::Int64(_)) => {
232+
let data = v
233+
.iter()
234+
.filter_map(|x| {
235+
if let value::BasicValue::Int64(i) = x {
236+
Some(*i)
237+
} else {
238+
None
239+
}
240+
})
241+
.collect::<Vec<_>>();
242+
Ok(PyArray1::from_vec(py, data).into_any())
243+
}
244+
_ => Ok(v
245+
.iter()
246+
.map(|v| basic_value_to_py_object(py, v))
247+
.collect::<PyResult<Vec<_>>>()?
248+
.into_bound_py_any(py)?),
249+
}
250+
}
251+
163252
fn field_values_from_py_object<'py>(
164253
schema: &schema::StructSchema,
165254
v: &Bound<'py, PyAny>,

0 commit comments

Comments
 (0)