Skip to content

Commit 82b8507

Browse files
committed
Optimize copying nalgebra matrices into NumPy arrays.
1 parent 5271427 commit 82b8507

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

src/convert.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,17 +189,21 @@ where
189189
N: nalgebra::Scalar + Element,
190190
R: nalgebra::Dim,
191191
C: nalgebra::Dim,
192-
S: nalgebra::storage::Storage<N, R, C>,
192+
S: nalgebra::Storage<N, R, C>,
193193
{
194194
type Item = N;
195195
type Dim = crate::Ix2;
196196

197197
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
198198
unsafe {
199-
let array = PyArray::new(py, (self.nrows(), self.ncols()), false);
200-
for r in 0..self.nrows() {
201-
for c in 0..self.ncols() {
202-
*array.uget_mut((r, c)) = self.get_unchecked((r, c)).clone();
199+
let array = PyArray::<N, _>::new(py, (self.nrows(), self.ncols()), true);
200+
let mut data_ptr = array.data();
201+
if self.data.is_contiguous() {
202+
ptr::copy_nonoverlapping(self.data.ptr(), data_ptr, self.len());
203+
} else {
204+
for item in self.iter() {
205+
data_ptr.write(item.clone());
206+
data_ptr = data_ptr.add(1);
203207
}
204208
}
205209
array

tests/to_py.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ fn slice_container_type_confusion() {
303303
#[test]
304304
fn matrix_to_numpy() {
305305
let matrix = nalgebra::Matrix3::<i32>::new(0, 1, 2, 3, 4, 5, 6, 7, 8);
306+
assert!(nalgebra::RawStorage::is_contiguous(&matrix.data));
306307

307308
Python::with_gil(|py| {
308309
let array = matrix.to_pyarray(py);
@@ -312,4 +313,13 @@ fn matrix_to_numpy() {
312313
array![[0, 1, 2], [3, 4, 5], [6, 7, 8]],
313314
);
314315
});
316+
317+
let matrix = matrix.row(0);
318+
assert!(!nalgebra::RawStorage::is_contiguous(&matrix.data));
319+
320+
Python::with_gil(|py| {
321+
let array = matrix.to_pyarray(py);
322+
323+
assert_eq!(array.readonly().as_array(), array![[0, 1, 2]]);
324+
});
315325
}

0 commit comments

Comments
 (0)