Skip to content

Commit 34fbd84

Browse files
authored
snipe off some unsafe code (#922)
1 parent 882b57f commit 34fbd84

File tree

5 files changed

+123
-155
lines changed

5 files changed

+123
-155
lines changed

src/errors/validation_exception.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ use std::fmt::{Display, Write};
33
use std::str::from_utf8;
44

55
use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
6-
use pyo3::ffi::Py_ssize_t;
6+
use pyo3::intern;
77
use pyo3::once_cell::GILOnceCell;
88
use pyo3::prelude::*;
99
use pyo3::types::{PyDict, PyList, PyString};
10-
use pyo3::{ffi, intern};
1110
use serde::ser::{Error, SerializeMap, SerializeSeq};
1211
use serde::{Serialize, Serializer};
1312

@@ -173,22 +172,27 @@ impl ValidationError {
173172
#[pyo3(signature = (*, include_url = true, include_context = true))]
174173
pub fn errors(&self, py: Python, include_url: bool, include_context: bool) -> PyResult<Py<PyList>> {
175174
let url_prefix = get_url_prefix(py, include_url);
176-
// taken approximately from the pyo3, but modified to return the error during iteration
177-
// https://github.com/PyO3/pyo3/blob/a3edbf4fcd595f0e234c87d4705eb600a9779130/src/types/list.rs#L27-L55
178-
unsafe {
179-
let ptr = ffi::PyList_New(self.line_errors.len() as Py_ssize_t);
180-
181-
// We create the `Py` pointer here for two reasons:
182-
// - panics if the ptr is null
183-
// - its Drop cleans up the list if user code or the asserts panic.
184-
let list: Py<PyList> = Py::from_owned_ptr(py, ptr);
185-
186-
for (index, line_error) in (0_isize..).zip(&self.line_errors) {
187-
let item = line_error.as_dict(py, url_prefix, include_context, &self.error_mode)?;
188-
ffi::PyList_SET_ITEM(ptr, index, item.into_ptr());
189-
}
190-
191-
Ok(list)
175+
let mut iteration_error = None;
176+
let list = PyList::new(
177+
py,
178+
// PyList::new takes ExactSizeIterator, so if an error occurs during iteration we
179+
// fill the list with None before returning the error; the list will then be thrown
180+
// away safely.
181+
self.line_errors.iter().map(|e| -> PyObject {
182+
if iteration_error.is_some() {
183+
return py.None();
184+
}
185+
e.as_dict(py, url_prefix, include_context, &self.error_mode)
186+
.unwrap_or_else(|err| {
187+
iteration_error = Some(err);
188+
py.None()
189+
})
190+
}),
191+
);
192+
if let Some(err) = iteration_error {
193+
Err(err)
194+
} else {
195+
Ok(list.into())
192196
}
193197
}
194198

src/input/input_python.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ impl<'a> Input<'a> for PyAny {
187187
let str = py_str.to_str()?;
188188
serde_json::from_str(str).map_err(|e| map_json_err(self, e))
189189
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
190+
// Safety: from_slice does not run arbitrary Python code and the GIL is held so the
191+
// bytes array will not be mutated while from_slice is reading it
190192
serde_json::from_slice(unsafe { py_byte_array.as_bytes() }).map_err(|e| map_json_err(self, e))
191193
} else {
192194
Err(ValError::new(ErrorTypeDefaults::JsonType, self))
@@ -235,13 +237,15 @@ impl<'a> Input<'a> for PyAny {
235237
};
236238
Ok(str.into())
237239
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
238-
// see https://docs.rs/pyo3/latest/pyo3/types/struct.PyByteArray.html#method.as_bytes
239-
// for why this is marked unsafe
240-
let str = match from_utf8(unsafe { py_byte_array.as_bytes() }) {
241-
Ok(s) => s,
240+
// Safety: the GIL is held while from_utf8 is running so py_byte_array is not mutated,
241+
// and we immediately copy the bytes into a new Python string
242+
let s = match from_utf8(unsafe { py_byte_array.as_bytes() }) {
243+
// Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the
244+
// final output needs to be Python anyway.
245+
Ok(s) => PyString::new(self.py(), s),
242246
Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)),
243247
};
244-
Ok(str.into())
248+
Ok(s.into())
245249
} else {
246250
Err(ValError::new(ErrorTypeDefaults::StringType, self))
247251
}
@@ -337,9 +341,8 @@ impl<'a> Input<'a> for PyAny {
337341
}
338342
}
339343
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
340-
if PyFloat::is_exact_type_of(self) {
341-
// Safety: self is PyFloat
342-
Ok(EitherFloat::Py(unsafe { self.downcast_unchecked::<PyFloat>() }))
344+
if let Ok(py_float) = self.downcast_exact::<PyFloat>() {
345+
Ok(EitherFloat::Py(py_float))
343346
} else if let Ok(float) = self.extract::<f64>() {
344347
// bools are cast to floats as either 0.0 or 1.0, so check for bool type in this specific case
345348
if (float == 0.0 || float == 1.0) && PyBool::is_exact_type_of(self) {
@@ -353,9 +356,8 @@ impl<'a> Input<'a> for PyAny {
353356
}
354357

355358
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
356-
if PyFloat::is_exact_type_of(self) {
357-
// Safety: self is PyFloat
358-
Ok(EitherFloat::Py(unsafe { self.downcast_unchecked::<PyFloat>() }))
359+
if let Ok(py_float) = self.downcast_exact() {
360+
Ok(EitherFloat::Py(py_float))
359361
} else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? {
360362
str_as_float(self, &cow_str)
361363
} else if let Ok(float) = self.extract::<f64>() {

src/input/return_enums.rs

Lines changed: 46 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,11 @@ impl BuildSet for &PySet {
224224

225225
impl BuildSet for &PyFrozenSet {
226226
fn build_add(&self, item: PyObject) -> PyResult<()> {
227-
unsafe {
228-
py_error_on_minusone(
229-
self.py(),
230-
ffi::PySet_Add(self.as_ptr(), item.to_object(self.py()).as_ptr()),
231-
)
232-
}
227+
py_error_on_minusone(self.py(), unsafe {
228+
// Safety: self.as_ptr() is the _only_ pointer to the `frozenset`, and it's allowed
229+
// to mutate this via the C API when nothing else can refer to it.
230+
ffi::PySet_Add(self.as_ptr(), item.to_object(self.py()).as_ptr())
231+
})
233232
}
234233

235234
fn build_len(&self) -> usize {
@@ -492,57 +491,32 @@ impl<'py> Iterator for MappingGenericIterator<'py> {
492491
type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>;
493492

494493
fn next(&mut self) -> Option<Self::Item> {
495-
let item = match self.iter.next() {
496-
Some(Err(e)) => return Some(Err(mapping_err(e, self.iter.py(), self.input))),
497-
Some(Ok(item)) => item,
498-
None => return None,
499-
};
500-
let tuple: &PyTuple = match item.downcast() {
501-
Ok(tuple) => tuple,
502-
Err(_) => {
503-
return Some(Err(ValError::new(
494+
Some(match self.iter.next()? {
495+
Ok(item) => item.extract().map_err(|_| {
496+
ValError::new(
504497
ErrorType::MappingType {
505498
error: MAPPING_TUPLE_ERROR.into(),
506499
context: None,
507500
},
508501
self.input,
509-
)))
510-
}
511-
};
512-
if tuple.len() != 2 {
513-
return Some(Err(ValError::new(
514-
ErrorType::MappingType {
515-
error: MAPPING_TUPLE_ERROR.into(),
516-
context: None,
517-
},
518-
self.input,
519-
)));
520-
};
521-
#[cfg(PyPy)]
522-
let key = tuple.get_item(0).unwrap();
523-
#[cfg(PyPy)]
524-
let value = tuple.get_item(1).unwrap();
525-
#[cfg(not(PyPy))]
526-
let key = unsafe { tuple.get_item_unchecked(0) };
527-
#[cfg(not(PyPy))]
528-
let value = unsafe { tuple.get_item_unchecked(1) };
529-
Some(Ok((key, value)))
502+
)
503+
}),
504+
Err(e) => Err(mapping_err(e, self.iter.py(), self.input)),
505+
})
530506
}
531-
// size_hint is omitted as it isn't needed
532507
}
533508

534509
pub struct AttributesGenericIterator<'py> {
535510
object: &'py PyAny,
536-
attributes: &'py PyList,
537-
index: usize,
511+
// PyO3 should export this type upstream
512+
attributes_iterator: <&'py PyList as IntoIterator>::IntoIter,
538513
}
539514

540515
impl<'py> AttributesGenericIterator<'py> {
541516
pub fn new(py_any: &'py PyAny) -> ValResult<'py, Self> {
542517
Ok(Self {
543518
object: py_any,
544-
attributes: py_any.dir(),
545-
index: 0,
519+
attributes_iterator: py_any.dir().into_iter(),
546520
})
547521
}
548522
}
@@ -553,37 +527,31 @@ impl<'py> Iterator for AttributesGenericIterator<'py> {
553527
fn next(&mut self) -> Option<Self::Item> {
554528
// loop until we find an attribute whose name does not start with underscore,
555529
// or we get to the end of the list of attributes
556-
while self.index < self.attributes.len() {
557-
#[cfg(PyPy)]
558-
let name: &PyAny = self.attributes.get_item(self.index).unwrap();
559-
#[cfg(not(PyPy))]
560-
let name: &PyAny = unsafe { self.attributes.get_item_unchecked(self.index) };
561-
self.index += 1;
562-
// from benchmarks this is 14x faster than using the python `startswith` method
563-
let name_cow = match name.downcast::<PyString>() {
564-
Ok(name) => name.to_string_lossy(),
565-
Err(e) => return Some(Err(e.into())),
566-
};
567-
if !name_cow.as_ref().starts_with('_') {
568-
// getattr is most likely to fail due to an exception in a @property, skip
569-
if let Ok(attr) = self.object.getattr(name_cow.as_ref()) {
570-
// we don't want bound methods to be included, is there a better way to check?
571-
// ref https://stackoverflow.com/a/18955425/949890
572-
let is_bound = matches!(attr.hasattr(intern!(attr.py(), "__self__")), Ok(true));
573-
// the PyFunction::is_type_of(attr) catches `staticmethod`, but also any other function,
574-
// I think that's better than including static methods in the yielded attributes,
575-
// if someone really wants fields, they can use an explicit field, or a function to modify input
576-
#[cfg(not(PyPy))]
577-
if !is_bound && !PyFunction::is_type_of(attr) {
578-
return Some(Ok((name, attr)));
579-
}
580-
// MASSIVE HACK! PyFunction doesn't exist for PyPy,
581-
// is_instance_of::<PyFunction> crashes with a null pointer, hence this hack, see
582-
// https://github.com/pydantic/pydantic-core/pull/161#discussion_r917257635
583-
#[cfg(PyPy)]
584-
if !is_bound && attr.get_type().to_string() != "<class 'function'>" {
585-
return Some(Ok((name, attr)));
586-
}
530+
let name = self.attributes_iterator.next()?;
531+
// from benchmarks this is 14x faster than using the python `startswith` method
532+
let name_cow = match name.downcast::<PyString>() {
533+
Ok(name) => name.to_string_lossy(),
534+
Err(e) => return Some(Err(e.into())),
535+
};
536+
if !name_cow.as_ref().starts_with('_') {
537+
// getattr is most likely to fail due to an exception in a @property, skip
538+
if let Ok(attr) = self.object.getattr(name_cow.as_ref()) {
539+
// we don't want bound methods to be included, is there a better way to check?
540+
// ref https://stackoverflow.com/a/18955425/949890
541+
let is_bound = matches!(attr.hasattr(intern!(attr.py(), "__self__")), Ok(true));
542+
// the PyFunction::is_type_of(attr) catches `staticmethod`, but also any other function,
543+
// I think that's better than including static methods in the yielded attributes,
544+
// if someone really wants fields, they can use an explicit field, or a function to modify input
545+
#[cfg(not(PyPy))]
546+
if !is_bound && !PyFunction::is_type_of(attr) {
547+
return Some(Ok((name, attr)));
548+
}
549+
// MASSIVE HACK! PyFunction doesn't exist for PyPy,
550+
// is_instance_of::<PyFunction> crashes with a null pointer, hence this hack, see
551+
// https://github.com/pydantic/pydantic-core/pull/161#discussion_r917257635
552+
#[cfg(PyPy)]
553+
if !is_bound && attr.get_type().to_string() != "<class 'function'>" {
554+
return Some(Ok((name, attr)));
587555
}
588556
}
589557
}
@@ -621,12 +589,7 @@ pub enum GenericIterator {
621589

622590
impl From<JsonArray> for GenericIterator {
623591
fn from(array: JsonArray) -> Self {
624-
let length = array.len();
625-
let json_iter = GenericJsonIterator {
626-
array,
627-
length,
628-
index: 0,
629-
};
592+
let json_iter = GenericJsonIterator { array, index: 0 };
630593
Self::JsonArray(json_iter)
631594
}
632595
}
@@ -674,14 +637,15 @@ impl GenericPyIterator {
674637
#[derive(Debug, Clone)]
675638
pub struct GenericJsonIterator {
676639
array: JsonArray,
677-
length: usize,
678640
index: usize,
679641
}
680642

681643
impl GenericJsonIterator {
682644
pub fn next(&mut self, _py: Python) -> PyResult<Option<(&JsonInput, usize)>> {
683-
if self.index < self.length {
684-
let next = unsafe { self.array.get_unchecked(self.index) };
645+
if self.index < self.array.len() {
646+
// panic here is impossible due to the bounds check above; the compiler should even
647+
// be able to optimize it away
648+
let next = &self.array[self.index];
685649
let a = (next, self.index);
686650
self.index += 1;
687651
Ok(Some(a))
@@ -940,13 +904,7 @@ impl<'a> EitherFloat<'a> {
940904
pub fn as_f64(self) -> f64 {
941905
match self {
942906
EitherFloat::F64(f) => f,
943-
944-
EitherFloat::Py(f) => {
945-
{
946-
// Safety: known to be a python float
947-
unsafe { ffi::PyFloat_AS_DOUBLE(f.as_ptr()) }
948-
}
949-
}
907+
EitherFloat::Py(f) => f.value(),
950908
}
951909
}
952910
}

src/serializers/infer.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ pub(crate) fn infer_to_python_known(
133133
.map(|s| s.into_py(py))?,
134134
ObType::Bytearray => {
135135
let py_byte_array: &PyByteArray = value.downcast()?;
136-
// see https://docs.rs/pyo3/latest/pyo3/types/struct.PyByteArray.html#method.as_bytes
137-
// for why this is marked unsafe
136+
// Safety: the GIL is held while bytes_to_string is running; it doesn't run
137+
// arbitrary Python code, so py_byte_array cannot be mutated.
138138
let bytes = unsafe { py_byte_array.as_bytes() };
139139
extra
140140
.config
@@ -428,8 +428,12 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
428428
}
429429
ObType::Bytearray => {
430430
let py_byte_array: &PyByteArray = value.downcast().map_err(py_err_se_err)?;
431-
let bytes = unsafe { py_byte_array.as_bytes() };
432-
extra.config.bytes_mode.serialize_bytes(bytes, serializer)
431+
// Safety: the GIL is held while serialize_bytes is running; it doesn't run
432+
// arbitrary Python code, so py_byte_array cannot be mutated.
433+
extra
434+
.config
435+
.bytes_mode
436+
.serialize_bytes(unsafe { py_byte_array.as_bytes() }, serializer)
433437
}
434438
ObType::Dict => serialize_dict!(value.downcast::<PyDict>().map_err(py_err_se_err)?),
435439
ObType::List => serialize_seq_filter!(PyList),
@@ -581,8 +585,15 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra:
581585
.bytes_to_string(key.py(), key.downcast::<PyBytes>()?.as_bytes()),
582586
ObType::Bytearray => {
583587
let py_byte_array: &PyByteArray = key.downcast()?;
584-
let bytes = unsafe { py_byte_array.as_bytes() };
585-
extra.config.bytes_mode.bytes_to_string(key.py(), bytes)
588+
// Safety: the GIL is held while bytes_to_string is running; it doesn't run
589+
// arbitrary Python code, so py_byte_array cannot be mutated during the call.
590+
//
591+
// We copy the bytes into a new buffer immediately afterwards
592+
extra
593+
.config
594+
.bytes_mode
595+
.bytes_to_string(key.py(), unsafe { py_byte_array.as_bytes() })
596+
.map(|cow| Cow::Owned(cow.into_owned()))
586597
}
587598
ObType::Datetime => {
588599
let py_dt: &PyDateTime = key.downcast()?;

0 commit comments

Comments
 (0)