diff --git a/mypy/typeshed/stubs/librt/librt/base64.pyi b/mypy/typeshed/stubs/librt/librt/base64.pyi index 1cea838505d6..275258a5c8ff 100644 --- a/mypy/typeshed/stubs/librt/librt/base64.pyi +++ b/mypy/typeshed/stubs/librt/librt/base64.pyi @@ -1,2 +1,4 @@ def b64encode(s: bytes) -> bytes: ... def b64decode(s: bytes | str) -> bytes: ... +def urlsafe_b64encode(s: bytes) -> bytes: ... +def urlsafe_b64decode(s: bytes | str) -> bytes: ... diff --git a/mypyc/lib-rt/librt_base64.c b/mypyc/lib-rt/librt_base64.c index b23be59959c5..002db0c50d56 100644 --- a/mypyc/lib-rt/librt_base64.c +++ b/mypyc/lib-rt/librt_base64.c @@ -9,14 +9,42 @@ static PyObject * b64decode_handle_invalid_input( - PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen); + PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen, bool freesrc); #define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2) #define STACK_BUFFER_SIZE 1024 +static void +convert_encoded_to_urlsafe(char *buf, size_t len) { + // The loop is written to enable SIMD optimizations + for (size_t i = 0; i < len; i++) { + char ch = buf[i]; + if (ch == '+') { + ch = '-'; + } else if (ch == '/') { + ch = '_'; + } + buf[i] = ch; + } +} + +static void +convert_urlsafe_to_encoded(const char *src, size_t len, char *buf) { + // The loop is written to enable SIMD optimizations + for (size_t i = 0; i < len; i++) { + char ch = src[i]; + if (ch == '-') { + ch = '+'; + } else if (ch == '_') { + ch = '/'; + } + buf[i] = ch; + } +} + static PyObject * -b64encode_internal(PyObject *obj) { +b64encode_internal(PyObject *obj, bool urlsafe) { unsigned char *ascii_data; char *bin_data; int leftbits = 0; @@ -53,6 +81,11 @@ b64encode_internal(PyObject *obj) { } size_t actual_len; base64_encode(bin_data, bin_len, buf, &actual_len, 0); + + if (urlsafe) { + convert_encoded_to_urlsafe(buf, actual_len); + } + PyObject *res = PyBytes_FromStringAndSize(buf, actual_len); if (buflen > STACK_BUFFER_SIZE) PyMem_Free(buf); @@ 
-65,7 +98,16 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) { PyErr_SetString(PyExc_TypeError, "b64encode() takes exactly one argument"); return 0; } - return b64encode_internal(args[0]); + return b64encode_internal(args[0], false); +} + +static PyObject* +urlsafe_b64encode(PyObject *self, PyObject *const *args, size_t nargs) { + if (nargs != 1) { + PyErr_SetString(PyExc_TypeError, "urlsafe_b64encode() takes exactly one argument"); + return 0; + } + return b64encode_internal(args[0], true); } static inline int @@ -75,7 +117,7 @@ is_valid_base64_char(char c, bool allow_padding) { } static PyObject * -b64decode_internal(PyObject *arg) { +b64decode_internal(PyObject *arg, bool urlsafe) { const char *src; Py_ssize_t srclen_ssz; @@ -102,6 +144,15 @@ b64decode_internal(PyObject *arg) { return PyBytes_FromStringAndSize(NULL, 0); } + if (urlsafe) { + char *new_src = PyMem_Malloc(srclen_ssz + 1); + if (new_src == NULL) { + return PyErr_NoMemory(); + } + convert_urlsafe_to_encoded(src, srclen_ssz, new_src); + src = new_src; + } + // Quickly ignore invalid characters at the end. Other invalid characters // are also accepted, but they need a slow path. 
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1], true)) { @@ -123,6 +174,9 @@ b64decode_internal(PyObject *arg) { // Allocate output bytes (uninitialized) of the max capacity PyObject *out_bytes = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)max_out); if (out_bytes == NULL) { + if (urlsafe) { + PyMem_Free((void *)src); + } return NULL; // Propagate memory error } @@ -134,9 +188,12 @@ b64decode_internal(PyObject *arg) { if (ret != 1) { if (ret == 0) { // Slow path: handle non-base64 input - return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen); + return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen, urlsafe); } Py_DECREF(out_bytes); + if (urlsafe) { + PyMem_Free((void *)src); + } if (ret == -1) { PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build"); } else { @@ -145,6 +202,10 @@ b64decode_internal(PyObject *arg) { return NULL; } + if (urlsafe) { + PyMem_Free((void *)src); + } + // Sanity-check contract (decoder must not overflow our buffer) if (outlen > max_out) { Py_DECREF(out_bytes); @@ -164,7 +225,7 @@ b64decode_internal(PyObject *arg) { // with stdlib b64decode. 
static PyObject * b64decode_handle_invalid_input( - PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen) + PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen, bool freesrc) { // Copy input to a temporary buffer, with non-base64 characters and extra suffix // characters removed @@ -172,6 +233,9 @@ b64decode_handle_invalid_input( char *newbuf = PyMem_Malloc(srclen); if (newbuf == NULL) { Py_DECREF(out_bytes); + if (freesrc) { + PyMem_Free((void *)src); + } return PyErr_NoMemory(); } @@ -208,6 +272,9 @@ b64decode_handle_invalid_input( // Stdlib always performs a non-strict padding check if (newbuf_len % 4 != 0) { + if (freesrc) { + PyMem_Free((void *)src); + } Py_DECREF(out_bytes); PyMem_Free(newbuf); PyErr_SetString(PyExc_ValueError, "Incorrect padding"); @@ -217,6 +284,9 @@ b64decode_handle_invalid_input( size_t outlen = max_out; int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0); PyMem_Free(newbuf); + if (freesrc) { + PyMem_Free((void *)src); + } if (ret != 1) { Py_DECREF(out_bytes); @@ -239,14 +309,22 @@ b64decode_handle_invalid_input( return out_bytes; } - static PyObject* b64decode(PyObject *self, PyObject *const *args, size_t nargs) { if (nargs != 1) { PyErr_SetString(PyExc_TypeError, "b64decode() takes exactly one argument"); return 0; } - return b64decode_internal(args[0]); + return b64decode_internal(args[0], false); +} + +static PyObject* +urlsafe_b64decode(PyObject *self, PyObject *const *args, size_t nargs) { + if (nargs != 1) { + PyErr_SetString(PyExc_TypeError, "urlsafe_b64decode() takes exactly one argument"); + return 0; + } + return b64decode_internal(args[0], true); } #endif @@ -255,6 +333,8 @@ static PyMethodDef librt_base64_module_methods[] = { #ifdef MYPYC_EXPERIMENTAL {"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using Base64.")}, {"b64decode", (PyCFunction)b64decode, METH_FASTCALL, PyDoc_STR("Decode a Base64 encoded bytes object or 
ASCII string.")}, + {"urlsafe_b64encode", (PyCFunction)urlsafe_b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using URL and file system safe Base64 alphabet.")}, + {"urlsafe_b64decode", (PyCFunction)urlsafe_b64decode, METH_FASTCALL, PyDoc_STR("Decode bytes or ASCII string using URL and file system safe Base64 alphabet.")}, #endif {NULL, NULL, 0, NULL} }; diff --git a/mypyc/lib-rt/librt_base64.h b/mypyc/lib-rt/librt_base64.h index 177cd0c1cfef..fedfefd9a38a 100644 --- a/mypyc/lib-rt/librt_base64.h +++ b/mypyc/lib-rt/librt_base64.h @@ -12,16 +12,16 @@ import_librt_base64(void) #else // MYPYC_EXPERIMENTAL -#define LIBRT_BASE64_ABI_VERSION 0 -#define LIBRT_BASE64_API_VERSION 1 +#define LIBRT_BASE64_ABI_VERSION 1 +#define LIBRT_BASE64_API_VERSION 2 #define LIBRT_BASE64_API_LEN 4 static void *LibRTBase64_API[LIBRT_BASE64_API_LEN]; #define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0]) #define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1]) -#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source)) LibRTBase64_API[2]) -#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source)) LibRTBase64_API[3]) +#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2]) +#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3]) static int import_librt_base64(void) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 1b4c438b58fe..31c053af17e3 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -473,6 +473,18 @@ return_type=bytes_rprimitive, c_function_name="LibRTBase64_b64encode_internal", error_kind=ERR_MAGIC, + extra_int_constants=[(0, bool_rprimitive)], + experimental=True, + capsule="librt.base64", +) + +function_op( + name="librt.base64.urlsafe_b64encode", + arg_types=[bytes_rprimitive], + return_type=bytes_rprimitive, + 
c_function_name="LibRTBase64_b64encode_internal", + error_kind=ERR_MAGIC, + extra_int_constants=[(1, bool_rprimitive)], experimental=True, capsule="librt.base64", ) @@ -483,6 +495,18 @@ return_type=bytes_rprimitive, c_function_name="LibRTBase64_b64decode_internal", error_kind=ERR_MAGIC, + extra_int_constants=[(0, bool_rprimitive)], + experimental=True, + capsule="librt.base64", +) + +function_op( + name="librt.base64.urlsafe_b64decode", + arg_types=[RUnion([bytes_rprimitive, str_rprimitive])], + return_type=bytes_rprimitive, + c_function_name="LibRTBase64_b64decode_internal", + error_kind=ERR_MAGIC, + extra_int_constants=[(1, bool_rprimitive)], experimental=True, capsule="librt.base64", ) diff --git a/mypyc/test-data/irbuild-base64.test b/mypyc/test-data/irbuild-base64.test index 4d41f2912700..f81bd5a79b17 100644 --- a/mypyc/test-data/irbuild-base64.test +++ b/mypyc/test-data/irbuild-base64.test @@ -1,5 +1,5 @@ [case testBase64_experimental] -from librt.base64 import b64encode, b64decode +from librt.base64 import b64encode, b64decode, urlsafe_b64encode, urlsafe_b64decode def enc(b: bytes) -> bytes: return b64encode(b) @@ -9,22 +9,47 @@ def dec_bytes(b: bytes) -> bytes: def dec_str(b: str) -> bytes: return b64decode(b) + +def uenc(b: bytes) -> bytes: + return urlsafe_b64encode(b) + +def udec_bytes(b: bytes) -> bytes: + return urlsafe_b64decode(b) + +def udec_str(b: str) -> bytes: + return urlsafe_b64decode(b) [out] def enc(b): b, r0 :: bytes L0: - r0 = LibRTBase64_b64encode_internal(b) + r0 = LibRTBase64_b64encode_internal(b, 0) return r0 def dec_bytes(b): b, r0 :: bytes L0: - r0 = LibRTBase64_b64decode_internal(b) + r0 = LibRTBase64_b64decode_internal(b, 0) return r0 def dec_str(b): b :: str r0 :: bytes L0: - r0 = LibRTBase64_b64decode_internal(b) + r0 = LibRTBase64_b64decode_internal(b, 0) + return r0 +def uenc(b): + b, r0 :: bytes +L0: + r0 = LibRTBase64_b64encode_internal(b, 1) + return r0 +def udec_bytes(b): + b, r0 :: bytes +L0: + r0 = 
LibRTBase64_b64decode_internal(b, 1) + return r0 +def udec_str(b): + b :: str + r0 :: bytes +L0: + r0 = LibRTBase64_b64decode_internal(b, 1) return r0 [case testBase64ExperimentalDisabled] diff --git a/mypyc/test-data/run-base64.test b/mypyc/test-data/run-base64.test index 8d7eb7c13482..bf8ea4590e5e 100644 --- a/mypyc/test-data/run-base64.test +++ b/mypyc/test-data/run-base64.test @@ -2,11 +2,14 @@ from typing import Any import base64 import binascii +import random -from librt.base64 import b64encode, b64decode +from librt.base64 import b64encode, b64decode, urlsafe_b64encode, urlsafe_b64decode from testutil import assertRaises +rand_values = [random.randbytes(random.randint(1, 2000)) for _ in range(2000)] + def test_encode_basic() -> None: assert b64encode(b"x") == b"eA==" @@ -35,15 +38,19 @@ def test_encode_different_strings() -> None: for b in b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200: check_encode(b) -def test_encode_wrapper() -> None: - enc: Any = b64encode - assert enc(b"x") == b"eA==" + for b in rand_values: + check_encode(b) - with assertRaises(TypeError): - enc() +def test_encode_wrappers() -> None: + funcs: list[Any] = [b64encode, urlsafe_b64encode] + for enc in funcs: + assert enc(b"x") == b"eA==" - with assertRaises(TypeError): - enc(b"x", b"y") + with assertRaises(TypeError): + enc() + + with assertRaises(TypeError): + enc(b"x", b"y") def test_decode_basic() -> None: assert b64decode(b"eA==") == b"x" @@ -84,6 +91,9 @@ def test_decode_different_strings() -> None: for b in b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200: check_decode(b) + for b in rand_values: + check_decode(b) + def is_base64_char(x: int) -> bool: c = chr(x) return ('a' <= c <= 'z') or ('A' <= c <= 'Z') or ('0' <= c <= '9') or c in '+/=' @@ -140,15 +150,59 @@ def test_decode_with_extra_data_after_padding() -> None: check_decode(b"eHk=x", encoded=True) check_decode(b"eA==abc=======efg", encoded=True) -def test_decode_wrapper() -> None: - dec: Any = b64decode - assert 
dec(b"eA==") == b"x" - - with assertRaises(TypeError): - dec() - - with assertRaises(TypeError): - dec(b"x", b"y") +def test_decode_wrappers() -> None: + funcs: list[Any] = [b64decode, urlsafe_b64decode] + for dec in funcs: + assert dec(b"eA==") == b"x" + + with assertRaises(TypeError): + dec() + + with assertRaises(TypeError): + dec(b"x", b"y") + +def check_urlsafe_encode(b: bytes) -> None: + assert urlsafe_b64encode(b) == getattr(base64, "urlsafe_b64encode")(b) + +def test_urlsafe_b64encode() -> None: + check_urlsafe_encode(b"") + check_urlsafe_encode(b"a") + check_urlsafe_encode(b"\xf8") + check_urlsafe_encode(b"\xfc") + check_urlsafe_encode(b"\xfcx") + check_urlsafe_encode(b"\xfcxy") + check_urlsafe_encode(b"\xfcxyz") + check_urlsafe_encode(bytes([x for x in range(256)])) + + for b in rand_values: + check_urlsafe_encode(b) + +def check_urlsafe_decode(b: bytes) -> None: + enc = urlsafe_b64encode(b) + assert urlsafe_b64decode(enc) == getattr(base64, "urlsafe_b64decode")(enc) + enc2 = b64encode(b) + assert urlsafe_b64decode(enc2) == getattr(base64, "urlsafe_b64decode")(enc2) + +def test_urlsafe_b64decode() -> None: + # Don't test everything, since the implementation is mostly shared with b64decode. + check_urlsafe_decode(b"") + check_urlsafe_decode(b"a") + check_urlsafe_decode(b"\xf8") + check_urlsafe_decode(b"\xfc") + check_urlsafe_decode(b"\xfcx") + check_urlsafe_decode(b"\xfcxy") + check_urlsafe_decode(b"\xfcxyz") + check_urlsafe_decode(bytes([x for x in range(256)])) + + for b in rand_values: + check_urlsafe_decode(b) + + assert urlsafe_b64decode(b" e A = == !") == b"x" + +def test_urlsafe_b64decode_errors() -> None: + for b in b"eA", b"eA=", b"eHk": + with assertRaises(ValueError): + urlsafe_b64decode(b) [case testBase64FeaturesNotAvailableInNonExperimentalBuild_librt_base64] # This also ensures librt.base64 can be built without experimental features