Skip to content

Commit c66a76e

Browse files
authored
Split standard_b64encode_impl (python#17)
Also make PyObject an UnsafeCell<ffi::_object> so it can be passed around by & reference
1 parent 071b7f1 commit c66a76e

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

Modules/_base64/src/lib.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ struct BorrowedBuffer {
8383
}
8484

8585
impl BorrowedBuffer {
86-
unsafe fn from_object(obj: *mut PyObject) -> Result<Self, ()> {
86+
fn from_object(obj: &PyObject) -> Result<Self, ()> {
8787
let mut view = MaybeUninit::<Py_buffer>::uninit();
88-
if unsafe { PyObject_GetBuffer(obj, view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 {
88+
if unsafe { PyObject_GetBuffer(obj.as_raw(), view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 {
8989
return Err(());
9090
}
9191
Ok(Self {
@@ -110,6 +110,9 @@ impl Drop for BorrowedBuffer {
110110
}
111111
}
112112

113+
/// # Safety
114+
/// `module` must be a valid pointer of PyObject representing the module.
115+
/// `args` must be a valid pointer to an array of valid PyObject pointers with length `nargs`.
113116
#[unsafe(no_mangle)]
114117
pub unsafe extern "C" fn standard_b64encode(
115118
_module: *mut PyObject,
@@ -126,10 +129,19 @@ pub unsafe extern "C" fn standard_b64encode(
126129
return ptr::null_mut();
127130
}
128131

129-
let source = unsafe { *args };
130-
let buffer = match unsafe { BorrowedBuffer::from_object(source) } {
132+
let source = unsafe { &**args };
133+
134+
// Safe cast by Safety
135+
match standard_b64encode_impl(source) {
136+
Ok(result) => result,
137+
Err(_) => ptr::null_mut(),
138+
}
139+
}
140+
141+
fn standard_b64encode_impl(source: &PyObject) -> Result<*mut PyObject, ()> {
142+
let buffer = match BorrowedBuffer::from_object(source) {
131143
Ok(buf) => buf,
132-
Err(_) => return ptr::null_mut(),
144+
Err(_) => return Err(()),
133145
};
134146

135147
let view_len = buffer.len();
@@ -140,44 +152,43 @@ pub unsafe extern "C" fn standard_b64encode(
140152
c"standard_b64encode() argument has negative length".as_ptr(),
141153
);
142154
}
143-
return ptr::null_mut();
155+
return Err(());
144156
}
157+
145158
let input_len = view_len as usize;
146159
let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) };
147160

148161
let Some(output_len) = encoded_output_len(input_len) else {
149162
unsafe {
150163
PyErr_NoMemory();
151164
}
152-
return ptr::null_mut();
165+
return Err(());
153166
};
154167

155168
if output_len > isize::MAX as usize {
156169
unsafe {
157170
PyErr_NoMemory();
158171
}
159-
return ptr::null_mut();
172+
return Err(());
160173
}
161174

162-
let result = unsafe {
163-
PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t)
164-
};
175+
let result = unsafe { PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) };
165176
if result.is_null() {
166-
return ptr::null_mut();
177+
return Err(());
167178
}
168179

169180
let dest_ptr = unsafe { PyBytes_AsString(result) };
170181
if dest_ptr.is_null() {
171182
unsafe {
172183
Py_DecRef(result);
173184
}
174-
return ptr::null_mut();
185+
return Err(());
175186
}
176187
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };
177188

178189
let written = encode_into(input, dest);
179190
debug_assert_eq!(written, output_len);
180-
result
191+
Ok(result)
181192
}
182193

183194
#[unsafe(no_mangle)]

Modules/cpython-sys/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ fn generate_c_api_bindings(srcdir: &Path, builddir: Option<&str>, out_path: &Pat
5555
.allowlist_type("_?Py.*")
5656
.allowlist_var("_?Py.*")
5757
.blocklist_type("^PyMethodDef$")
58+
.blocklist_type("PyObject")
5859
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
5960
.generate()
6061
.expect("Unable to generate bindings");

Modules/cpython-sys/src/lib.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@ pub const _Py_STATIC_IMMORTAL_INITIAL_REFCNT: Py_ssize_t =
5656
#[cfg(not(target_pointer_width = "64"))]
5757
pub const _Py_STATIC_IMMORTAL_INITIAL_REFCNT: Py_ssize_t = 7u32 << 28;
5858

59+
#[repr(transparent)]
60+
pub struct PyObject(std::cell::UnsafeCell<_object>);
61+
62+
impl PyObject {
63+
#[inline]
64+
pub fn as_raw(&self) -> *mut Self {
65+
self.0.get() as *mut Self
66+
}
67+
}
68+
69+
5970
#[repr(C)]
6071
pub union PyMethodDefFuncPointer {
6172
pub PyCFunction: unsafe extern "C" fn(slf: *mut PyObject, args: *mut PyObject) -> *mut PyObject,
@@ -113,18 +124,18 @@ unsafe impl Send for PyMethodDef {}
113124

114125
#[cfg(py_gil_disabled)]
115126
pub const PyObject_HEAD_INIT: PyObject = {
116-
let mut obj: PyObject = unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
127+
let mut obj: _object = unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
117128
obj.ob_flags = _Py_STATICALLY_ALLOCATED_FLAG as _;
118-
obj
129+
PyObject(std::cell::UnsafeCell::new(obj))
119130
};
120131

121132
#[cfg(not(py_gil_disabled))]
122-
pub const PyObject_HEAD_INIT: PyObject = PyObject {
133+
pub const PyObject_HEAD_INIT: PyObject = PyObject(std::cell::UnsafeCell::new(_object {
123134
__bindgen_anon_1: _object__bindgen_ty_1 {
124135
ob_refcnt_full: _Py_STATIC_IMMORTAL_INITIAL_REFCNT as i64,
125136
},
126137
ob_type: std::ptr::null_mut(),
127-
};
138+
}));
128139

129140
pub const PyModuleDef_HEAD_INIT: PyModuleDef_Base = PyModuleDef_Base {
130141
ob_base: PyObject_HEAD_INIT,

0 commit comments

Comments
 (0)