Skip to content

Commit bd1d494

Browse files
committed
adds into_pyarray to faer mat
1 parent 8dd4f8e commit bd1d494

File tree

4 files changed

+82
-2
lines changed

4 files changed

+82
-2
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ license = "BSD-2-Clause"
1818
half = { version = "2.0", default-features = false, optional = true }
1919
libc = "0.2"
2020
nalgebra = { version = ">=0.30, <0.34", default-features = false, optional = true }
21+
faer = { version = "0.21", optional = true }
2122
num-complex = ">= 0.2, < 0.5"
2223
num-integer = "0.1"
2324
num-traits = "0.2"
2425
ndarray = ">= 0.15, < 0.17"
2526
pyo3 = { version = "0.23.4", default-features = false, features = ["macros"] }
2627
rustc-hash = "2.0"
2728

29+
[features]
30+
faer = ["dep:faer"]
31+
2832
[dev-dependencies]
2933
pyo3 = { version = "0.23.3", default-features = false, features = ["auto-initialize"] }
3034
nalgebra = { version = ">=0.30, <0.34", default-features = false, features = ["std"] }

src/convert.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use std::{mem, os::raw::c_int, ptr};
44

5-
use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, OwnedRepr};
5+
use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, Ix2, OwnedRepr};
66
use pyo3::{Bound, Python};
77

88
use crate::array::{PyArray, PyArrayMethods};
@@ -90,6 +90,31 @@ impl<T: Element> IntoPyArray for Vec<T> {
9090
}
9191
}
9292

93+
#[cfg(feature = "faer")]
94+
impl<T: Element> IntoPyArray for faer::Mat<T> {
95+
type Item = T;
96+
type Dim = Ix2;
97+
98+
fn into_pyarray<'py>(mut self, py: Python<'py>) -> Bound<'py, PyArray<Self::Item, Self::Dim>> {
99+
let dims = Dim([self.nrows(), self.ncols()]);
100+
let rstride = self.row_stride();
101+
let cstride = self.col_stride();
102+
// let strides = [mem::size_of::<T>() as npy_intp, mem::size_of::<T>() as npy_intp];
103+
let strides = [rstride*mem::size_of::<T>() as npy_intp, cstride*mem::size_of::<T>() as npy_intp];
104+
let data_ptr = self.as_ptr_mut();
105+
unsafe {
106+
PyArray::from_raw_parts(
107+
py,
108+
dims,
109+
strides.as_ptr(),
110+
data_ptr,
111+
PySliceContainer::from(self),
112+
)
113+
}
114+
}
115+
}
116+
117+
93118
impl<A, D> IntoPyArray for ArrayBase<OwnedRepr<A>, D>
94119
where
95120
A: Element,

src/slice_container.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{mem, ptr};
1+
use std::{mem, ptr, slice::from_raw_parts_mut};
22

33
use ndarray::{ArrayBase, Dimension, OwnedRepr};
44
use pyo3::pyclass;
@@ -71,6 +71,31 @@ impl<T: Send + Sync> From<Vec<T>> for PySliceContainer {
7171
}
7272
}
7373

74+
#[cfg(feature = "faer")]
75+
impl<T: Send + Sync> From<faer::Mat<T>> for PySliceContainer {
76+
fn from(data: faer::Mat<T>) -> Self {
77+
unsafe fn drop_faer_mat<T>(ptr: *mut u8, len: usize, _cap: usize) {
78+
faer::mat::MatMut::from_raw_parts_mut(ptr as *mut T, len, 1, 1, 1);
79+
}
80+
81+
// FIXME(adamreichold): Use `Vec::into_raw_parts`
82+
// when it becomes stable and compatible with our MSRV.
83+
let mut data = mem::ManuallyDrop::new(data);
84+
85+
let ptr = data.as_ptr_mut() as *mut u8;
86+
let len = data.nrows() * data.ncols();
87+
let cap = 0;
88+
let drop = drop_faer_mat::<T>;
89+
90+
Self {
91+
ptr,
92+
len,
93+
cap,
94+
drop,
95+
}
96+
}
97+
}
98+
7499
impl<A, D> From<ArrayBase<OwnedRepr<A>, D>> for PySliceContainer
75100
where
76101
A: Send + Sync,

tests/to_py.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,32 @@ fn slice_container_type_confusion() {
287287
let _py_arr = vec![1, 2, 3].into_pyarray(py);
288288
});
289289
}
290+
#[cfg(feature = "faer")]
291+
#[test]
292+
fn faer_mat_to_numpy() {
293+
let faer_mat: faer::Mat<f64> = faer::Scale(2.0)*faer::mat::Mat::<f64>::identity(2, 2);
294+
let faer_mat_wide: faer::Mat<f64> = faer::mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
295+
let faer_mat_tall: faer::Mat<f64> = faer_mat_wide.transpose().to_owned();
296+
Python::with_gil(|py| {
297+
let mat_pyarray = faer_mat.into_pyarray(py);
298+
let mat_wide_pyarray = faer_mat_wide.into_pyarray(py);
299+
let mat_tall_pyarray = faer_mat_tall.into_pyarray(py);
300+
assert_eq!(
301+
mat_pyarray.readonly().as_array(),
302+
array![[2.0f64, 0.0f64], [0.0f64, 2.0f64]]
303+
);
304+
assert_eq!(
305+
mat_wide_pyarray.readonly().as_array(),
306+
array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]]
307+
);
308+
assert_eq!(
309+
mat_tall_pyarray.readonly().as_array(),
310+
array![[1.0f64, 4.0],
311+
[2.0, 5.0],
312+
[3.0, 6.0]]
313+
);
314+
});
315+
}
290316

291317
#[cfg(feature = "nalgebra")]
292318
#[test]

0 commit comments

Comments
 (0)