Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypy/typeshed/stubs/librt/librt/base64.pyi
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
def b64encode(s: bytes) -> bytes: ...
def b64decode(s: bytes | str) -> bytes: ...
def urlsafe_b64encode(s: bytes) -> bytes: ...
def urlsafe_b64decode(s: bytes | str) -> bytes: ...
96 changes: 88 additions & 8 deletions mypyc/lib-rt/librt_base64.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,42 @@

static PyObject *
b64decode_handle_invalid_input(
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen);
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen, bool freesrc);

#define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2)

#define STACK_BUFFER_SIZE 1024

static void
convert_encoded_to_urlsafe(char *buf, size_t len) {
// The loop is written to enable SIMD optimizations
for (size_t i = 0; i < len; i++) {
char ch = buf[i];
if (ch == '+') {
ch = '-';
} else if (ch == '/') {
ch = '_';
}
buf[i] = ch;
}
}

static void
convert_urlsafe_to_encoded(const char *src, size_t len, char *buf) {
// The loop is written to enable SIMD optimizations
for (size_t i = 0; i < len; i++) {
char ch = src[i];
if (ch == '-') {
ch = '+';
} else if (ch == '_') {
ch = '/';
}
buf[i] = ch;
}
}

static PyObject *
b64encode_internal(PyObject *obj) {
b64encode_internal(PyObject *obj, bool urlsafe) {
unsigned char *ascii_data;
char *bin_data;
int leftbits = 0;
Expand Down Expand Up @@ -53,6 +81,11 @@ b64encode_internal(PyObject *obj) {
}
size_t actual_len;
base64_encode(bin_data, bin_len, buf, &actual_len, 0);

if (urlsafe) {
convert_encoded_to_urlsafe(buf, actual_len);
}

PyObject *res = PyBytes_FromStringAndSize(buf, actual_len);
if (buflen > STACK_BUFFER_SIZE)
PyMem_Free(buf);
Expand All @@ -65,7 +98,16 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
PyErr_SetString(PyExc_TypeError, "b64encode() takes exactly one argument");
return 0;
}
return b64encode_internal(args[0]);
return b64encode_internal(args[0], false);
}

static PyObject*
urlsafe_b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
if (nargs != 1) {
PyErr_SetString(PyExc_TypeError, "urlsafe_b64encode() takes exactly one argument");
return 0;
}
return b64encode_internal(args[0], true);
}

static inline int
Expand All @@ -75,7 +117,7 @@ is_valid_base64_char(char c, bool allow_padding) {
}

static PyObject *
b64decode_internal(PyObject *arg) {
b64decode_internal(PyObject *arg, bool urlsafe) {
const char *src;
Py_ssize_t srclen_ssz;

Expand All @@ -102,6 +144,15 @@ b64decode_internal(PyObject *arg) {
return PyBytes_FromStringAndSize(NULL, 0);
}

if (urlsafe) {
char *new_src = PyMem_Malloc(srclen_ssz + 1);
if (new_src == NULL) {
return PyErr_NoMemory();
}
convert_urlsafe_to_encoded(src, srclen_ssz, new_src);
src = new_src;
}

// Quickly ignore invalid characters at the end. Other invalid characters
// are also accepted, but they need a slow path.
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1], true)) {
Expand All @@ -123,6 +174,9 @@ b64decode_internal(PyObject *arg) {
// Allocate output bytes (uninitialized) of the max capacity
PyObject *out_bytes = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)max_out);
if (out_bytes == NULL) {
if (urlsafe) {
PyMem_Free((void *)src);
}
return NULL; // Propagate memory error
}

Expand All @@ -134,9 +188,12 @@ b64decode_internal(PyObject *arg) {
if (ret != 1) {
if (ret == 0) {
// Slow path: handle non-base64 input
return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen);
return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen, urlsafe);
}
Py_DECREF(out_bytes);
if (urlsafe) {
PyMem_Free((void *)src);
}
if (ret == -1) {
PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
} else {
Expand All @@ -145,6 +202,10 @@ b64decode_internal(PyObject *arg) {
return NULL;
}

if (urlsafe) {
PyMem_Free((void *)src);
}

// Sanity-check contract (decoder must not overflow our buffer)
if (outlen > max_out) {
Py_DECREF(out_bytes);
Expand All @@ -164,14 +225,17 @@ b64decode_internal(PyObject *arg) {
// with stdlib b64decode.
static PyObject *
b64decode_handle_invalid_input(
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen)
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen, bool freesrc)
{
// Copy input to a temporary buffer, with non-base64 characters and extra suffix
// characters removed
size_t newbuf_len = 0;
char *newbuf = PyMem_Malloc(srclen);
if (newbuf == NULL) {
Py_DECREF(out_bytes);
if (freesrc) {
PyMem_Free((void *)src);
}
return PyErr_NoMemory();
}

Expand Down Expand Up @@ -208,6 +272,9 @@ b64decode_handle_invalid_input(

// Stdlib always performs a non-strict padding check
if (newbuf_len % 4 != 0) {
if (freesrc) {
PyMem_Free((void *)src);
}
Py_DECREF(out_bytes);
PyMem_Free(newbuf);
PyErr_SetString(PyExc_ValueError, "Incorrect padding");
Expand All @@ -217,6 +284,9 @@ b64decode_handle_invalid_input(
size_t outlen = max_out;
int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
PyMem_Free(newbuf);
if (freesrc) {
PyMem_Free((void *)src);
}

if (ret != 1) {
Py_DECREF(out_bytes);
Expand All @@ -239,14 +309,22 @@ b64decode_handle_invalid_input(
return out_bytes;
}


static PyObject*
b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
if (nargs != 1) {
PyErr_SetString(PyExc_TypeError, "b64decode() takes exactly one argument");
return 0;
}
return b64decode_internal(args[0]);
return b64decode_internal(args[0], false);
}

static PyObject*
urlsafe_b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
if (nargs != 1) {
PyErr_SetString(PyExc_TypeError, "urlsafe_b64decode() takes exactly one argument");
return 0;
}
return b64decode_internal(args[0], true);
}

#endif
Expand All @@ -255,6 +333,8 @@ static PyMethodDef librt_base64_module_methods[] = {
#ifdef MYPYC_EXPERIMENTAL
{"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using Base64.")},
{"b64decode", (PyCFunction)b64decode, METH_FASTCALL, PyDoc_STR("Decode a Base64 encoded bytes object or ASCII string.")},
{"urlsafe_b64encode", (PyCFunction)urlsafe_b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using URL and file system safe Base64 alphabet.")},
{"urlsafe_b64decode", (PyCFunction)urlsafe_b64decode, METH_FASTCALL, PyDoc_STR("Decode bytes or ASCII string using URL and file system safe Base64 alphabet.")},
#endif
{NULL, NULL, 0, NULL}
};
Expand Down
8 changes: 4 additions & 4 deletions mypyc/lib-rt/librt_base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import_librt_base64(void)

#else // MYPYC_EXPERIMENTAL

#define LIBRT_BASE64_ABI_VERSION 0
#define LIBRT_BASE64_API_VERSION 1
#define LIBRT_BASE64_ABI_VERSION 1
#define LIBRT_BASE64_API_VERSION 2
#define LIBRT_BASE64_API_LEN 4

static void *LibRTBase64_API[LIBRT_BASE64_API_LEN];

#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source)) LibRTBase64_API[2])
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source)) LibRTBase64_API[3])
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])

static int
import_librt_base64(void)
Expand Down
24 changes: 24 additions & 0 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,18 @@
return_type=bytes_rprimitive,
c_function_name="LibRTBase64_b64encode_internal",
error_kind=ERR_MAGIC,
extra_int_constants=[(0, bool_rprimitive)],
experimental=True,
capsule="librt.base64",
)

function_op(
name="librt.base64.urlsafe_b64encode",
arg_types=[bytes_rprimitive],
return_type=bytes_rprimitive,
c_function_name="LibRTBase64_b64encode_internal",
error_kind=ERR_MAGIC,
extra_int_constants=[(1, bool_rprimitive)],
experimental=True,
capsule="librt.base64",
)
Expand All @@ -483,6 +495,18 @@
return_type=bytes_rprimitive,
c_function_name="LibRTBase64_b64decode_internal",
error_kind=ERR_MAGIC,
extra_int_constants=[(0, bool_rprimitive)],
experimental=True,
capsule="librt.base64",
)

function_op(
name="librt.base64.urlsafe_b64decode",
arg_types=[RUnion([bytes_rprimitive, str_rprimitive])],
return_type=bytes_rprimitive,
c_function_name="LibRTBase64_b64decode_internal",
error_kind=ERR_MAGIC,
extra_int_constants=[(1, bool_rprimitive)],
experimental=True,
capsule="librt.base64",
)
33 changes: 29 additions & 4 deletions mypyc/test-data/irbuild-base64.test
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[case testBase64_experimental]
from librt.base64 import b64encode, b64decode
from librt.base64 import b64encode, b64decode, urlsafe_b64encode, urlsafe_b64decode

def enc(b: bytes) -> bytes:
return b64encode(b)
Expand All @@ -9,22 +9,47 @@ def dec_bytes(b: bytes) -> bytes:

def dec_str(b: str) -> bytes:
return b64decode(b)

def uenc(b: bytes) -> bytes:
return urlsafe_b64encode(b)

def udec_bytes(b: bytes) -> bytes:
return urlsafe_b64decode(b)

def udec_str(b: str) -> bytes:
return urlsafe_b64decode(b)
[out]
def enc(b):
b, r0 :: bytes
L0:
r0 = LibRTBase64_b64encode_internal(b)
r0 = LibRTBase64_b64encode_internal(b, 0)
return r0
def dec_bytes(b):
b, r0 :: bytes
L0:
r0 = LibRTBase64_b64decode_internal(b)
r0 = LibRTBase64_b64decode_internal(b, 0)
return r0
def dec_str(b):
b :: str
r0 :: bytes
L0:
r0 = LibRTBase64_b64decode_internal(b)
r0 = LibRTBase64_b64decode_internal(b, 0)
return r0
def uenc(b):
b, r0 :: bytes
L0:
r0 = LibRTBase64_b64encode_internal(b, 1)
return r0
def udec_bytes(b):
b, r0 :: bytes
L0:
r0 = LibRTBase64_b64decode_internal(b, 1)
return r0
def udec_str(b):
b :: str
r0 :: bytes
L0:
r0 = LibRTBase64_b64decode_internal(b, 1)
return r0

[case testBase64ExperimentalDisabled]
Expand Down
Loading