diff --git a/newsfragments/4742.fixed.md b/newsfragments/4742.fixed.md new file mode 100644 index 00000000000..7bfe38b0c38 --- /dev/null +++ b/newsfragments/4742.fixed.md @@ -0,0 +1 @@ +Use critical section in `PyByteArray::to_vec` on freethreaded build to replicate GIL-enabled "soundness". diff --git a/src/types/bytearray.rs b/src/types/bytearray.rs index 5aa7acceb86..da3eb3b2b20 100644 --- a/src/types/bytearray.rs +++ b/src/types/bytearray.rs @@ -2,6 +2,7 @@ use crate::err::{PyErr, PyResult}; use crate::ffi_ptr_ext::FfiPtrExt; use crate::instance::{Borrowed, Bound}; use crate::py_result_ext::PyResultExt; +use crate::sync::with_critical_section; use crate::types::any::PyAnyMethods; use crate::{ffi, PyAny, Python}; use std::slice; @@ -125,28 +126,29 @@ pub trait PyByteArrayMethods<'py>: crate::sealed::Sealed { /// /// As a result, this slice should only be used for short-lived operations without executing any /// Python code, such as copying into a Vec. + /// For free-threaded Python support see also [`with_critical_section`]. /// /// # Examples /// /// ```rust /// use pyo3::prelude::*; /// use pyo3::exceptions::PyRuntimeError; + /// use pyo3::sync::with_critical_section; /// use pyo3::types::PyByteArray; /// /// #[pyfunction] /// fn a_valid_function(bytes: &Bound<'_, PyByteArray>) -> PyResult<()> { - /// let section = { - /// // SAFETY: We promise to not let the interpreter regain control + /// let section = with_critical_section(bytes, || { + /// // SAFETY: We promise to not let the interpreter regain control over the bytearray /// // or invoke any PyO3 APIs while using the slice. /// let slice = unsafe { bytes.as_bytes() }; /// /// // Copy only a section of `bytes` while avoiding /// // `to_vec` which copies the entire thing. - /// let section = slice - /// .get(6..11) - /// .ok_or_else(|| PyRuntimeError::new_err("input is not long enough"))?; - /// Vec::from(section) - /// }; + /// slice.get(6..11) + /// .map(Vec::from) + /// .ok_or_else(|| PyRuntimeError::new_err("input is not long enough")) + /// })?; /// /// // Now we can do things with `section` and call PyO3 APIs again. /// // ... @@ -188,6 +190,9 @@ pub trait PyByteArrayMethods<'py>: crate::sealed::Sealed { /// # #[allow(dead_code)] /// #[pyfunction] /// fn bug(py: Python<'_>, bytes: &Bound<'_, PyByteArray>) { + /// // No critical section is being used. + /// // This means that for no-gil Python another thread could be modifying the + /// // bytearray concurrently and thus invalidate `slice` any time. /// let slice = unsafe { bytes.as_bytes() }; /// /// // This explicitly yields control back to the Python interpreter... @@ -267,7 +272,14 @@ impl<'py> PyByteArrayMethods<'py> for Bound<'py, PyByteArray> { } fn to_vec(&self) -> Vec { - unsafe { self.as_bytes() }.to_vec() + with_critical_section(self, || { + // SAFETY: + // * `self` is a `Bound` object, which guarantees that the Python GIL is held. + // * For no-gil Python, a critical section is used in lieu of the GIL. + // * We don't interact with the interpreter + // * We don't mutate the underlying slice + unsafe { self.as_bytes() }.to_vec() + }) } fn resize(&self, len: usize) -> PyResult<()> { @@ -444,4 +456,129 @@ mod tests { .is_instance_of::(py)); }) } + + // * wasm has no threading support + // * CPython 3.13t is unsound => test fails + #[cfg(all( + not(target_family = "wasm"), + any(Py_3_14, not(all(Py_3_13, Py_GIL_DISABLED))) + ))] + #[test] + fn test_data_integrity_in_critical_section() { + use crate::instance::Py; + use crate::sync::{with_critical_section, MutexExt}; + + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Mutex; + use std::thread; + use std::thread::ScopedJoinHandle; + use std::time::Duration; + + const SIZE: usize = 1_000_000; + const DATA_VALUE: u8 = 42; + + fn make_byte_array(py: Python<'_>, size: usize, value: u8) -> Bound<'_, PyByteArray> { + PyByteArray::new_with(py, size, |b| { + b.fill(value); + Ok(()) + }) + .unwrap() + } + + let data: Mutex> = Mutex::new(Python::attach(|py| { + make_byte_array(py, SIZE, DATA_VALUE).unbind() + })); + + fn get_data<'py>( + data: &Mutex>, + py: Python<'py>, + ) -> Bound<'py, PyByteArray> { + data.lock_py_attached(py).unwrap().bind(py).clone() + } + + fn set_data(data: &Mutex>, new: Bound<'_, PyByteArray>) { + let py = new.py(); + *data.lock_py_attached(py).unwrap() = new.unbind() + } + + let running = AtomicBool::new(true); + let extending = AtomicBool::new(false); + + // continuously extends and resets the bytearray in data + let worker1 = || { + let mut rounds = 0; + while running.load(Ordering::SeqCst) && rounds < 50 { + Python::attach(|py| { + let byte_array = get_data(&data, py); + extending.store(true, Ordering::SeqCst); + byte_array + .call_method("extend", (&byte_array,), None) + .unwrap(); + extending.store(false, Ordering::SeqCst); + set_data(&data, make_byte_array(py, SIZE, DATA_VALUE)); + rounds += 1; + }); + } + }; + + // continuously checks the integrity of bytearray in data + let worker2 = || { + while running.load(Ordering::SeqCst) { + if !extending.load(Ordering::SeqCst) { + // wait until we have a chance to read inconsistent state + continue; + } + Python::attach(|py| { + let read = get_data(&data, py); + if read.len() == SIZE { + // extend is still not done => wait even more + return; + } + with_critical_section(&read, || { + // SAFETY: we are in a critical section + // This is the whole point of the test: make sure that a + // critical section is sufficient to ensure that the data + // read is consistent. + unsafe { + let bytes = read.as_bytes(); + assert!(bytes.iter().rev().take(50).all(|v| *v == DATA_VALUE + && bytes.iter().take(50).all(|v| *v == DATA_VALUE))); + } + }); + }); + } + }; + + thread::scope(|s| { + let mut handle1 = Some(s.spawn(worker1)); + let mut handle2 = Some(s.spawn(worker2)); + let mut handles = [&mut handle1, &mut handle2]; + + let t0 = std::time::Instant::now(); + while t0.elapsed() < Duration::from_secs(1) { + for handle in &mut handles { + if handle + .as_ref() + .map(ScopedJoinHandle::is_finished) + .unwrap_or(false) + { + let res = handle.take().unwrap().join(); + if res.is_err() { + running.store(false, Ordering::SeqCst); + } + res.unwrap(); + } + } + if handles.iter().any(|handle| handle.is_none()) { + break; + } + } + running.store(false, Ordering::SeqCst); + for handle in &mut handles { + if let Some(handle) = handle.take() { + handle.join().unwrap() + } + } + }); + } }