Skip to content

Commit 46c524c

Browse files
authored
Merge pull request #233 from adamreichold/fix-slice-box-type-confusion
Fix SliceBox-induced type confusion
2 parents 480bf55 + 83c8e19 commit 46c524c

File tree

4 files changed

+55
-20
lines changed

4 files changed

+55
-20
lines changed

src/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,8 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
448448
ID: IntoDimension<Dim = D>,
449449
{
450450
let dims = dims.into_dimension();
451+
let data_ptr = data_ptr.unwrap_or(boxed_slice.as_ptr());
451452
let container = SliceBox::new(boxed_slice);
452-
let data_ptr = data_ptr.unwrap_or_else(|| container.data.as_ptr());
453453
let cell = pyo3::PyClassInitializer::from(container)
454454
.create_cell(py)
455455
.expect("Object creation failed.");

src/npyiter.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ impl<'py, T: Element> NpySingleIterBuilder<'py, T, ReadWrite> {
186186

187187
impl<'py, T: Element, I: IterMode> NpySingleIterBuilder<'py, T, I> {
188188
/// Sets a flag to this builder, returning `self`.
189+
#[must_use]
189190
pub fn set(mut self, flag: NpyIterFlag) -> Self {
190191
self.flags |= flag.to_c_enum();
191192
self
@@ -388,6 +389,7 @@ impl<'py, T: Element> NpyMultiIterBuilder<'py, T, ()> {
388389
}
389390

390391
/// Set a flag to this builder, returning `self`.
392+
#[must_use]
391393
pub fn set(mut self, flag: NpyIterFlag) -> Self {
392394
self.flags |= flag.to_c_enum();
393395
self

src/slice_box.rs

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,58 @@
11
use pyo3::class::impl_::{PyClassImpl, ThreadCheckerStub};
22
use pyo3::pyclass::PyClass;
33
use pyo3::pyclass_slots::PyClassDummySlot;
4-
use pyo3::{ffi, type_object, types::PyAny, PyCell};
4+
use pyo3::type_object::{LazyStaticType, PyTypeInfo};
5+
use pyo3::{ffi, types::PyAny, PyCell};
56

6-
pub(crate) struct SliceBox<T> {
7-
pub(crate) data: Box<[T]>,
7+
pub(crate) struct SliceBox {
8+
ptr: *mut [u8],
9+
drop: unsafe fn(*mut [u8]),
810
}
911

10-
impl<T> SliceBox<T> {
11-
pub(crate) fn new(data: Box<[T]>) -> Self {
12-
Self { data }
12+
unsafe impl Send for SliceBox {}
13+
14+
impl SliceBox {
15+
pub(crate) fn new<T: Send>(data: Box<[T]>) -> Self {
16+
unsafe fn drop_boxed_slice<T>(ptr: *mut [u8]) {
17+
let _ = Box::from_raw(ptr as *mut [T]);
18+
}
19+
20+
let ptr = Box::into_raw(data) as *mut [u8];
21+
let drop = drop_boxed_slice::<T>;
22+
23+
Self { ptr, drop }
24+
}
25+
}
26+
27+
impl Drop for SliceBox {
28+
fn drop(&mut self) {
29+
unsafe {
30+
(self.drop)(self.ptr);
31+
}
1332
}
1433
}
1534

16-
impl<T> PyClass for SliceBox<T>
17-
where
18-
T: Send,
19-
{
35+
impl PyClass for SliceBox {
2036
type Dict = PyClassDummySlot;
2137
type WeakRef = PyClassDummySlot;
2238
type BaseNativeType = PyAny;
2339
}
2440

25-
impl<T> PyClassImpl for SliceBox<T>
26-
where
27-
T: Send,
28-
{
41+
impl PyClassImpl for SliceBox {
2942
const DOC: &'static str = "Memory store for PyArray using rust's Box<[T]> \0";
3043

3144
type BaseType = PyAny;
3245
type Layout = PyCell<Self>;
3346
type ThreadChecker = ThreadCheckerStub<Self>;
3447
}
3548

36-
unsafe impl<T> type_object::PyTypeInfo for SliceBox<T>
37-
where
38-
T: Send,
39-
{
49+
unsafe impl PyTypeInfo for SliceBox {
4050
type AsRefTarget = PyCell<Self>;
4151
const NAME: &'static str = "SliceBox";
4252
const MODULE: Option<&'static str> = Some("_rust_numpy");
4353

4454
#[inline]
4555
fn type_object_raw(py: pyo3::Python) -> *mut ffi::PyTypeObject {
46-
use pyo3::type_object::LazyStaticType;
4756
static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
4857
TYPE_OBJECT.get_or_init::<Self>(py)
4958
}

tests/to_py.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,27 @@ fn to_pyarray_object_array() {
232232
}
233233
})
234234
}
235+
236+
#[test]
237+
fn slice_box_type_confusion() {
238+
use ndarray::Array2;
239+
use pyo3::{
240+
types::{PyDict, PyString},
241+
ToPyObject,
242+
};
243+
244+
pyo3::Python::with_gil(|py| {
245+
let mut nd_arr = Array2::from_shape_fn((2, 3), |(_, _)| py.None());
246+
nd_arr[(0, 2)] = PyDict::new(py).to_object(py);
247+
nd_arr[(1, 0)] = PyString::new(py, "Hello:)").to_object(py);
248+
249+
let _py_arr = nd_arr.into_pyarray(py);
250+
251+
// Dropping `_arr` used to trigger a segmentation fault due to calling `Py_DECREF`
252+
// on 1, 2 and 3 interpreted as pointers into the Python heap
253+
// after having created a `SliceBox<PyObject>` backing `_py_arr`,
254+
// c.f. https://github.com/PyO3/rust-numpy/issues/232.
255+
let vec = vec![1, 2, 3];
256+
let _arr = vec.into_pyarray(py);
257+
});
258+
}

0 commit comments

Comments
 (0)