Skip to content

Commit b8d6ea3

Browse files
authored
Faster _base.b64encode with custom implementation (#13)
1 parent 3dc0283 commit b8d6ea3

File tree

3 files changed

+166
-30
lines changed

3 files changed

+166
-30
lines changed

Cargo.lock

Lines changed: 0 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/_base64/Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ version = "0.1.0"
44
edition = "2024"
55

66
[dependencies]
7-
base64 = "0.22.1"
87
cpython-sys ={ path = "../cpython-sys" }
98

109
[lib]
1110
name = "_base64"
12-
crate-type = ["staticlib"]
11+
crate-type = ["staticlib"]

Modules/_base64/src/lib.rs

Lines changed: 165 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,183 @@
11
use std::cell::UnsafeCell;
2-
3-
use std::ffi::CStr;
4-
use std::ffi::CString;
5-
use std::ffi::c_char;
6-
use std::ffi::c_int;
7-
use std::ffi::c_void;
2+
use std::ffi::{c_char, c_int, c_void};
3+
use std::mem::MaybeUninit;
4+
use std::ptr;
5+
use std::slice;
86

97
use cpython_sys::METH_FASTCALL;
10-
use cpython_sys::Py_ssize_t;
118
use cpython_sys::PyBytes_AsString;
12-
use cpython_sys::PyBytes_FromString;
9+
use cpython_sys::PyBytes_FromStringAndSize;
10+
use cpython_sys::PyBuffer_Release;
1311
use cpython_sys::PyMethodDef;
1412
use cpython_sys::PyMethodDefFuncPointer;
1513
use cpython_sys::PyModuleDef;
1614
use cpython_sys::PyModuleDef_HEAD_INIT;
1715
use cpython_sys::PyModuleDef_Init;
1816
use cpython_sys::PyObject;
17+
use cpython_sys::PyObject_GetBuffer;
18+
use cpython_sys::Py_DecRef;
19+
use cpython_sys::PyErr_NoMemory;
20+
use cpython_sys::PyErr_SetString;
21+
use cpython_sys::PyExc_TypeError;
22+
use cpython_sys::Py_buffer;
23+
use cpython_sys::Py_ssize_t;
24+
25+
const PYBUF_SIMPLE: c_int = 0;
26+
const PAD_BYTE: u8 = b'=';
27+
const ENCODE_TABLE: [u8; 64] =
28+
*b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
29+
30+
#[inline]
31+
fn encoded_output_len(input_len: usize) -> Option<usize> {
32+
input_len
33+
.checked_add(2)
34+
.map(|n| n / 3)
35+
.and_then(|blocks| blocks.checked_mul(4))
36+
}
37+
38+
#[inline]
39+
fn encode_into(input: &[u8], output: &mut [u8]) -> usize {
40+
let mut src_index = 0;
41+
let mut dst_index = 0;
42+
let len = input.len();
43+
44+
while src_index + 3 <= len {
45+
let chunk = (u32::from(input[src_index]) << 16)
46+
| (u32::from(input[src_index + 1]) << 8)
47+
| u32::from(input[src_index + 2]);
48+
output[dst_index] = ENCODE_TABLE[((chunk >> 18) & 0x3f) as usize];
49+
output[dst_index + 1] = ENCODE_TABLE[((chunk >> 12) & 0x3f) as usize];
50+
output[dst_index + 2] = ENCODE_TABLE[((chunk >> 6) & 0x3f) as usize];
51+
output[dst_index + 3] = ENCODE_TABLE[(chunk & 0x3f) as usize];
52+
src_index += 3;
53+
dst_index += 4;
54+
}
1955

20-
use base64::prelude::*;
56+
match len - src_index {
57+
0 => {}
58+
1 => {
59+
let chunk = u32::from(input[src_index]) << 16;
60+
output[dst_index] = ENCODE_TABLE[((chunk >> 18) & 0x3f) as usize];
61+
output[dst_index + 1] = ENCODE_TABLE[((chunk >> 12) & 0x3f) as usize];
62+
output[dst_index + 2] = PAD_BYTE;
63+
output[dst_index + 3] = PAD_BYTE;
64+
dst_index += 4;
65+
}
66+
2 => {
67+
let chunk = (u32::from(input[src_index]) << 16)
68+
| (u32::from(input[src_index + 1]) << 8);
69+
output[dst_index] = ENCODE_TABLE[((chunk >> 18) & 0x3f) as usize];
70+
output[dst_index + 1] = ENCODE_TABLE[((chunk >> 12) & 0x3f) as usize];
71+
output[dst_index + 2] = ENCODE_TABLE[((chunk >> 6) & 0x3f) as usize];
72+
output[dst_index + 3] = PAD_BYTE;
73+
dst_index += 4;
74+
}
75+
_ => unreachable!("len - src_index cannot exceed 2"),
76+
}
77+
78+
dst_index
79+
}
80+
81+
struct BorrowedBuffer {
82+
view: Py_buffer,
83+
}
84+
85+
impl BorrowedBuffer {
86+
unsafe fn from_object(obj: *mut PyObject) -> Result<Self, ()> {
87+
let mut view = MaybeUninit::<Py_buffer>::uninit();
88+
if unsafe { PyObject_GetBuffer(obj, view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 {
89+
return Err(());
90+
}
91+
Ok(Self {
92+
view: unsafe { view.assume_init() },
93+
})
94+
}
95+
96+
fn len(&self) -> Py_ssize_t {
97+
self.view.len
98+
}
99+
100+
fn as_ptr(&self) -> *const u8 {
101+
self.view.buf.cast::<u8>() as *const u8
102+
}
103+
}
104+
105+
impl Drop for BorrowedBuffer {
106+
fn drop(&mut self) {
107+
unsafe {
108+
PyBuffer_Release(&mut self.view);
109+
}
110+
}
111+
}
21112

22113
#[unsafe(no_mangle)]
23-
pub unsafe extern "C" fn standard_b64encode(
114+
pub unsafe extern "C" fn b64encode(
24115
_module: *mut PyObject,
25116
args: *mut *mut PyObject,
26-
_nargs: Py_ssize_t,
117+
nargs: Py_ssize_t,
27118
) -> *mut PyObject {
28-
let buff = unsafe { *args };
29-
let ptr = unsafe { PyBytes_AsString(buff) };
30-
if ptr.is_null() {
31-
// Error handling omitted for now
32-
unimplemented!("Error handling goes here...")
119+
if nargs != 1 {
120+
unsafe {
121+
PyErr_SetString(
122+
PyExc_TypeError,
123+
c"b64encode() takes exactly one argument".as_ptr(),
124+
);
125+
}
126+
return ptr::null_mut();
127+
}
128+
129+
let source = unsafe { *args };
130+
let buffer = match unsafe { BorrowedBuffer::from_object(source) } {
131+
Ok(buf) => buf,
132+
Err(_) => return ptr::null_mut(),
133+
};
134+
135+
let view_len = buffer.len();
136+
if view_len < 0 {
137+
unsafe {
138+
PyErr_SetString(
139+
PyExc_TypeError,
140+
c"b64encode() argument has negative length".as_ptr(),
141+
);
142+
}
143+
return ptr::null_mut();
33144
}
34-
let cdata = unsafe { CStr::from_ptr(ptr) };
35-
let res = BASE64_STANDARD.encode(cdata.to_bytes());
36-
unsafe { PyBytes_FromString(CString::new(res).unwrap().as_ptr()) }
145+
let input_len = view_len as usize;
146+
let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) };
147+
148+
let Some(output_len) = encoded_output_len(input_len) else {
149+
unsafe {
150+
PyErr_NoMemory();
151+
}
152+
return ptr::null_mut();
153+
};
154+
155+
if output_len > isize::MAX as usize {
156+
unsafe {
157+
PyErr_NoMemory();
158+
}
159+
return ptr::null_mut();
160+
}
161+
162+
let result = unsafe {
163+
PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t)
164+
};
165+
if result.is_null() {
166+
return ptr::null_mut();
167+
}
168+
169+
let dest_ptr = unsafe { PyBytes_AsString(result) };
170+
if dest_ptr.is_null() {
171+
unsafe {
172+
Py_DecRef(result);
173+
}
174+
return ptr::null_mut();
175+
}
176+
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };
177+
178+
let written = encode_into(input, dest);
179+
debug_assert_eq!(written, output_len);
180+
result
37181
}
38182

39183
#[unsafe(no_mangle)]
@@ -62,9 +206,9 @@ unsafe impl Sync for ModuleDef {}
62206
pub static _BASE64_MODULE_METHODS: [PyMethodDef; 2] = {
63207
[
64208
PyMethodDef {
65-
ml_name: c"standard_b64encode".as_ptr() as *mut c_char,
209+
ml_name: c"b64encode".as_ptr() as *mut c_char,
66210
ml_meth: PyMethodDefFuncPointer {
67-
PyCFunctionFast: standard_b64encode,
211+
PyCFunctionFast: b64encode,
68212
},
69213
ml_flags: METH_FASTCALL,
70214
ml_doc: c"Demo for the _base64 module".as_ptr() as *mut c_char,

0 commit comments

Comments
 (0)