|
1 | 1 | use bytes::Bytes; |
| 2 | +use numpy::{PyArray1, PyArrayDyn, PyArrayMethods}; |
2 | 3 | use pyo3::IntoPyObjectExt; |
| 4 | +use pyo3::types::PyAny; |
3 | 5 | use pyo3::types::{PyList, PyTuple}; |
4 | 6 | use pyo3::{exceptions::PyException, prelude::*}; |
5 | 7 | use pythonize::{depythonize, pythonize}; |
@@ -72,11 +74,7 @@ fn basic_value_to_py_object<'py>( |
72 | 74 | value::BasicValue::OffsetDateTime(v) => v.into_bound_py_any(py)?, |
73 | 75 | value::BasicValue::TimeDelta(v) => v.into_bound_py_any(py)?, |
74 | 76 | value::BasicValue::Json(v) => pythonize(py, v).into_py_result()?, |
75 | | - value::BasicValue::Vector(v) => v |
76 | | - .iter() |
77 | | - .map(|v| basic_value_to_py_object(py, v)) |
78 | | - .collect::<PyResult<Vec<_>>>()? |
79 | | - .into_bound_py_any(py)?, |
| 77 | + value::BasicValue::Vector(v) => handle_vector_to_py(py, v)?, |
80 | 78 | }; |
81 | 79 | Ok(result) |
82 | 80 | } |
@@ -150,16 +148,107 @@ fn basic_value_from_py_object<'py>( |
150 | 148 | schema::BasicValueType::Json => { |
151 | 149 | value::BasicValue::Json(Arc::from(depythonize::<serde_json::Value>(v)?)) |
152 | 150 | } |
153 | | - schema::BasicValueType::Vector(elem) => value::BasicValue::Vector(Arc::from( |
154 | | - v.extract::<Vec<Bound<'py, PyAny>>>()? |
155 | | - .into_iter() |
156 | | - .map(|v| basic_value_from_py_object(&elem.element_type, &v)) |
157 | | - .collect::<PyResult<Vec<_>>>()?, |
158 | | - )), |
| 151 | + schema::BasicValueType::Vector(elem) => { |
| 152 | + if let Some(vector) = handle_ndarray_from_py(&elem.element_type, v)? { |
| 153 | + vector |
| 154 | + } else { |
| 155 | + // Fallback to list |
| 156 | + value::BasicValue::Vector(Arc::from( |
| 157 | + v.extract::<Vec<Bound<'py, PyAny>>>()? |
| 158 | + .into_iter() |
| 159 | + .map(|v| basic_value_from_py_object(&elem.element_type, &v)) |
| 160 | + .collect::<PyResult<Vec<_>>>()?, |
| 161 | + )) |
| 162 | + } |
| 163 | + } |
159 | 164 | }; |
160 | 165 | Ok(result) |
161 | 166 | } |
162 | 167 |
|
| 168 | +fn handle_ndarray_from_py<'py>( |
| 169 | + elem_type: &schema::BasicValueType, |
| 170 | + v: &Bound<'py, PyAny>, |
| 171 | +) -> PyResult<Option<value::BasicValue>> { |
| 172 | + macro_rules! try_convert { |
| 173 | + ($t:ty, $cast:expr) => { |
| 174 | + if let Ok(array) = v.downcast::<PyArrayDyn<$t>>() { |
| 175 | + let data = array.readonly().as_slice()?.to_vec(); |
| 176 | + let vec = data.into_iter().map($cast).collect::<Vec<_>>(); |
| 177 | + return Ok(Some(value::BasicValue::Vector(Arc::from(vec)))); |
| 178 | + } |
| 179 | + }; |
| 180 | + } |
| 181 | + |
| 182 | + match elem_type { |
| 183 | + &schema::BasicValueType::Float32 => try_convert!(f32, value::BasicValue::Float32), |
| 184 | + &schema::BasicValueType::Float64 => try_convert!(f64, value::BasicValue::Float64), |
| 185 | + &schema::BasicValueType::Int64 => { |
| 186 | + try_convert!(i32, |v| value::BasicValue::Int64(v as i64)); |
| 187 | + try_convert!(i64, value::BasicValue::Int64); |
| 188 | + try_convert!(u8, |v| value::BasicValue::Int64(v as i64)); |
| 189 | + try_convert!(u16, |v| value::BasicValue::Int64(v as i64)); |
| 190 | + try_convert!(u32, |v| value::BasicValue::Int64(v as i64)); |
| 191 | + try_convert!(u64, |v| value::BasicValue::Int64(v as i64)); |
| 192 | + } |
| 193 | + _ => {} |
| 194 | + } |
| 195 | + |
| 196 | + Ok(None) |
| 197 | +} |
| 198 | + |
| 199 | +// Helper function to convert BasicValue::Vector to PyAny |
| 200 | +fn handle_vector_to_py<'py>( |
| 201 | + py: Python<'py>, |
| 202 | + v: &[value::BasicValue], |
| 203 | +) -> PyResult<Bound<'py, PyAny>> { |
| 204 | + match v.first() { |
| 205 | + Some(value::BasicValue::Float32(_)) => { |
| 206 | + let data = v |
| 207 | + .iter() |
| 208 | + .filter_map(|x| { |
| 209 | + if let value::BasicValue::Float32(f) = x { |
| 210 | + Some(*f) |
| 211 | + } else { |
| 212 | + None |
| 213 | + } |
| 214 | + }) |
| 215 | + .collect::<Vec<_>>(); |
| 216 | + Ok(PyArray1::from_vec(py, data).into_any()) |
| 217 | + } |
| 218 | + Some(value::BasicValue::Float64(_)) => { |
| 219 | + let data = v |
| 220 | + .iter() |
| 221 | + .filter_map(|x| { |
| 222 | + if let value::BasicValue::Float64(f) = x { |
| 223 | + Some(*f) |
| 224 | + } else { |
| 225 | + None |
| 226 | + } |
| 227 | + }) |
| 228 | + .collect::<Vec<_>>(); |
| 229 | + Ok(PyArray1::from_vec(py, data).into_any()) |
| 230 | + } |
| 231 | + Some(value::BasicValue::Int64(_)) => { |
| 232 | + let data = v |
| 233 | + .iter() |
| 234 | + .filter_map(|x| { |
| 235 | + if let value::BasicValue::Int64(i) = x { |
| 236 | + Some(*i) |
| 237 | + } else { |
| 238 | + None |
| 239 | + } |
| 240 | + }) |
| 241 | + .collect::<Vec<_>>(); |
| 242 | + Ok(PyArray1::from_vec(py, data).into_any()) |
| 243 | + } |
| 244 | + _ => Ok(v |
| 245 | + .iter() |
| 246 | + .map(|v| basic_value_to_py_object(py, v)) |
| 247 | + .collect::<PyResult<Vec<_>>>()? |
| 248 | + .into_bound_py_any(py)?), |
| 249 | + } |
| 250 | +} |
| 251 | + |
163 | 252 | fn field_values_from_py_object<'py>( |
164 | 253 | schema: &schema::StructSchema, |
165 | 254 | v: &Bound<'py, PyAny>, |
|
0 commit comments