Skip to content

Commit 836c5f3

Browse files
committed
Filter out invalid base64 characters
1 parent aeca40a commit 836c5f3

File tree

2 files changed

+69
-4
lines changed

2 files changed

+69
-4
lines changed

mypyc/lib-rt/librt_base64.c

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
#ifdef MYPYC_EXPERIMENTAL
88

9+
static PyObject *
10+
b64decode_handle_invalid(PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen);
11+
912
#define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2)
1013

1114
#define STACK_BUFFER_SIZE 1024
@@ -63,6 +66,12 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
6366
return b64encode_internal(args[0]);
6467
}
6568

69+
// Return 1 if c belongs to the standard base64 alphabet (A-Z, a-z, 0-9,
// '+', '/') or is the '=' padding character, and 0 for any other byte.
static inline int
is_valid_base64_char(char c) {
    const int is_upper = (c >= 'A') && (c <= 'Z');
    const int is_lower = (c >= 'a') && (c <= 'z');
    const int is_digit = (c >= '0') && (c <= '9');
    const int is_symbol = (c == '+') || (c == '/') || (c == '=');
    return is_upper || is_lower || is_digit || is_symbol;
}
74+
6675
static PyObject *
6776
b64decode_internal(PyObject *arg) {
6877
const char *src;
@@ -91,6 +100,12 @@ b64decode_internal(PyObject *arg) {
91100
return PyBytes_FromStringAndSize(NULL, 0);
92101
}
93102

103+
// Quickly ignore invalid characters at the end. Other invalid characters
104+
// are also accepted, but they need a slow path.
105+
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1])) {
106+
srclen_ssz--;
107+
}
108+
94109
// Compute an output capacity that's at least 3/4 of input, without overflow:
95110
// ceil(3/4 * N) == N - floor(N/4)
96111
size_t srclen = (size_t)srclen_ssz;
@@ -112,14 +127,14 @@ b64decode_internal(PyObject *arg) {
112127
char *outbuf = PyBytes_AS_STRING(out_bytes);
113128
size_t outlen = max_out;
114129

115-
// Decode (flags = 0 for plain input)
116130
int ret = base64_decode(src, srclen, outbuf, &outlen, 0);
117131

118132
if (ret != 1) {
119-
Py_DECREF(out_bytes);
120133
if (ret == 0) {
121-
PyErr_SetString(PyExc_ValueError, "Only base64 data is allowed");
122-
} else if (ret == -1) {
134+
return b64decode_handle_invalid(out_bytes, outbuf, max_out, src, srclen);
135+
}
136+
Py_DECREF(out_bytes);
137+
if (ret == -1) {
123138
PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
124139
} else {
125140
PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
@@ -149,6 +164,45 @@ b64decode_internal(PyObject *arg) {
149164
#endif
150165
}
151166

167+
// Slow path for decoding input that base64_decode() rejected.
//
// Copies the bytes of src[0..srclen) that pass is_valid_base64_char() into a
// temporary buffer, discarding everything else, and decodes that filtered
// copy into outbuf.
//
// out_bytes: bytes object whose storage is outbuf. The reference is consumed:
//            it is released on every failure path, and on success it is
//            shrunk to the decoded length and returned.
// outbuf:    destination buffer for the decoded data.
// max_out:   passed to base64_decode() as the initial output length
//            (presumably the capacity of outbuf -- confirm against caller).
// src:       the original, unfiltered input.
// srclen:    number of bytes in src.
//
// Returns the resized bytes object, or NULL with an exception set:
//   MemoryError          if the temporary buffer cannot be allocated
//   ValueError           if base64_decode() returns 0 for the filtered input
//   NotImplementedError  if base64_decode() returns -1
//   RuntimeError         for any other return value that is not 1
static PyObject *
b64decode_handle_invalid(PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen)
{
    size_t i;
    char *newbuf = PyMem_Malloc(srclen);
    if (newbuf == NULL) {
        // Allocation failed: give up the output object and report MemoryError
        // instead of writing through a NULL pointer below.
        Py_DECREF(out_bytes);
        return PyErr_NoMemory();
    }

    // Keep only characters accepted by is_valid_base64_char. The filtered
    // copy is never longer than the input, so srclen bytes always suffice.
    size_t newbuf_len = 0;
    for (i = 0; i < srclen; i++) {
        char c = src[i];
        if (is_valid_base64_char(c)) {
            newbuf[newbuf_len++] = c;
        }
    }

    size_t outlen = max_out;
    int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
    PyMem_Free(newbuf);

    if (ret != 1) {
        Py_DECREF(out_bytes);
        // Exactly one exception must be set. This has to be a single
        // if / else-if / else chain, otherwise the ValueError for ret == 0
        // would be overwritten by the generic RuntimeError.
        if (ret == 0) {
            PyErr_SetString(PyExc_ValueError, "Only base64 data is allowed");
        } else if (ret == -1) {
            PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
        } else {
            PyErr_SetString(PyExc_RuntimeError, "base64_decode failed");
        }
        return NULL;
    }

    // Shrink in place to the actual decoded length
    if (_PyBytes_Resize(&out_bytes, (Py_ssize_t)outlen) < 0) {
        // _PyBytes_Resize sets an exception and may free the old object
        return NULL;
    }
    return out_bytes;
}
204+
205+
152206
static PyObject*
153207
b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
154208
if (nargs != 1) {

mypyc/test-data/run-base64.test

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ def test_decode_different_strings() -> None:
8282
for b in b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200:
8383
check_decode(b)
8484

85+
def test_decode_with_non_base64_chars() -> None:
    # The stdlib decoder ignores non-base64 characters; the same inputs are
    # checked here for compatibility.

    # Inputs whose invalid characters appear only as a suffix (fast path).
    for data in (b"eA== ", b"eA==\n", b"eA== \t\n", b"\n"):
        check_decode(data, encoded=True)

    # Invalid characters interleaved with the valid ones.
    check_decode(b" e A = = ", encoded=True)
95+
8596
def test_decode_wrapper() -> None:
8697
dec: Any = b64decode
8798
assert dec(b"eA==") == b"x"

0 commit comments

Comments
 (0)