Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ license = "BSD-2-Clause"
half = { version = "2.0", default-features = false, optional = true }
libc = "0.2"
nalgebra = { version = ">=0.30, <0.34", default-features = false, optional = true }
faer = { version = "0.21", optional = true }
num-complex = ">= 0.2, < 0.5"
num-integer = "0.1"
num-traits = "0.2"
ndarray = ">= 0.15, < 0.17"
pyo3 = { version = "0.23.4", default-features = false, features = ["macros"] }
rustc-hash = "2.0"

[features]
faer = ["dep:faer"]

[dev-dependencies]
pyo3 = { version = "0.23.3", default-features = false, features = ["auto-initialize"] }
nalgebra = { version = ">=0.30, <0.34", default-features = false, features = ["std"] }
Expand Down
27 changes: 26 additions & 1 deletion src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::{mem, os::raw::c_int, ptr};

use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, OwnedRepr};
use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, Ix2, OwnedRepr};
use pyo3::{Bound, Python};

use crate::array::{PyArray, PyArrayMethods};
Expand Down Expand Up @@ -90,6 +90,31 @@ impl<T: Element> IntoPyArray for Vec<T> {
}
}

#[cfg(feature = "faer")]
impl<T: Element> IntoPyArray for faer::Mat<T> {
type Item = T;
type Dim = Ix2;

fn into_pyarray<'py>(mut self, py: Python<'py>) -> Bound<'py, PyArray<Self::Item, Self::Dim>> {
let dims = Dim([self.nrows(), self.ncols()]);
let rstride = self.row_stride();
let cstride = self.col_stride();
// let strides = [mem::size_of::<T>() as npy_intp, mem::size_of::<T>() as npy_intp];
let strides = [rstride*mem::size_of::<T>() as npy_intp, cstride*mem::size_of::<T>() as npy_intp];
let data_ptr = self.as_ptr_mut();
unsafe {
PyArray::from_raw_parts(
py,
dims,
strides.as_ptr(),
data_ptr,
PySliceContainer::from(self),
)
}
}
}


impl<A, D> IntoPyArray for ArrayBase<OwnedRepr<A>, D>
where
A: Element,
Expand Down
27 changes: 26 additions & 1 deletion src/slice_container.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{mem, ptr};
use std::{mem, ptr, slice::from_raw_parts_mut};

use ndarray::{ArrayBase, Dimension, OwnedRepr};
use pyo3::pyclass;
Expand Down Expand Up @@ -71,6 +71,31 @@ impl<T: Send + Sync> From<Vec<T>> for PySliceContainer {
}
}

#[cfg(feature = "faer")]
impl<T: Send + Sync> From<faer::Mat<T>> for PySliceContainer {
fn from(data: faer::Mat<T>) -> Self {
unsafe fn drop_faer_mat<T>(ptr: *mut u8, len: usize, _cap: usize) {
let _ = faer::mat::MatMut::from_raw_parts_mut(ptr as *mut T, len, 1, 1, 1);
}

// FIXME(adamreichold): Use `Vec::into_raw_parts`
// when it becomes stable and compatible with our MSRV.
let mut data = mem::ManuallyDrop::new(data);

let ptr = data.as_ptr_mut() as *mut u8;
let len = data.nrows() * data.ncols();
let cap = 0;
let drop = drop_faer_mat::<T>;

Self {
ptr,
len,
cap,
drop,
}
}
}

impl<A, D> From<ArrayBase<OwnedRepr<A>, D>> for PySliceContainer
where
A: Send + Sync,
Expand Down
26 changes: 26 additions & 0 deletions tests/to_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,32 @@ fn slice_container_type_confusion() {
let _py_arr = vec![1, 2, 3].into_pyarray(py);
});
}
#[cfg(feature = "faer")]
#[test]
fn faer_mat_to_numpy() {
let faer_mat: faer::Mat<f64> = faer::Scale(2.0)*faer::mat::Mat::<f64>::identity(2, 2);
let faer_mat_wide: faer::Mat<f64> = faer::mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let faer_mat_tall: faer::Mat<f64> = faer_mat_wide.transpose().to_owned();
Python::with_gil(|py| {
let mat_pyarray = faer_mat.into_pyarray(py);
let mat_wide_pyarray = faer_mat_wide.into_pyarray(py);
let mat_tall_pyarray = faer_mat_tall.into_pyarray(py);
assert_eq!(
mat_pyarray.readonly().as_array(),
array![[2.0f64, 0.0f64], [0.0f64, 2.0f64]]
);
assert_eq!(
mat_wide_pyarray.readonly().as_array(),
array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]]
);
assert_eq!(
mat_tall_pyarray.readonly().as_array(),
array![[1.0f64, 4.0],
[2.0, 5.0],
[3.0, 6.0]]
);
});
}

#[cfg(feature = "nalgebra")]
#[test]
Expand Down
Loading