Skip to content

Commit cc45fde

Browse files
committed
Check for invalid padding
1 parent 836c5f3 commit cc45fde

File tree

2 files changed

+66
-4
lines changed

2 files changed

+66
-4
lines changed

mypyc/lib-rt/librt_base64.c

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#define PY_SSIZE_T_CLEAN
22
#include <Python.h>
3+
#include <stdbool.h>
34
#include "librt_base64.h"
45
#include "libbase64.h"
56
#include "pythoncapi_compat.h"
@@ -67,9 +68,9 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
6768
}
6869

6970
static inline int
70-
is_valid_base64_char(char c) {
71+
is_valid_base64_char(char c, bool allow_padding) {
7172
return ((c >= 'A' && c <= 'Z') | (c >= 'a' && c <= 'z') |
72-
(c >= '0' && c <= '9') | (c == '+') | (c == '/') | (c == '='));
73+
(c >= '0' && c <= '9') | (c == '+') | (c == '/') | (allow_padding && c == '='));
7374
}
7475

7576
static PyObject *
@@ -102,7 +103,7 @@ b64decode_internal(PyObject *arg) {
102103

103104
// Quickly ignore invalid characters at the end. Other invalid characters
104105
// are also accepted, but they need a slow path.
105-
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1])) {
106+
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1], true)) {
106107
srclen_ssz--;
107108
}
108109

@@ -172,11 +173,38 @@ b64decode_handle_invalid(PyObject *out_bytes, char *outbuf, size_t max_out, cons
172173
size_t newbuf_len = 0;
173174
for (i = 0; i < srclen; i++) {
174175
char c = src[i];
175-
if (is_valid_base64_char(c)) {
176+
if (is_valid_base64_char(c, false)) {
176177
newbuf[newbuf_len++] = c;
178+
} else if (c == '=') {
179+
// Copy necessary amount of padding
180+
int remainder = newbuf_len % 4;
181+
if (remainder == 0) {
182+
// No padding needed -- ignore padding
183+
break;
184+
}
185+
int numpad = 4 - remainder;
186+
// Check that there is at least the required amount of padding (CPython ignores
187+
// extra padding)
188+
while (numpad > 0) {
189+
if (i == srclen || src[i] != '=') {
190+
break;
191+
}
192+
newbuf[newbuf_len++] = '=';
193+
i++;
194+
numpad--;
195+
while (i < srclen && !is_valid_base64_char(src[i], true)) {
196+
i++;
197+
}
198+
}
199+
break;
177200
}
178201
}
179202

203+
if (newbuf_len % 4 != 0) {
204+
PyErr_SetString(PyExc_ValueError, "Incorrect padding");
205+
return NULL;
206+
}
207+
180208
size_t outlen = max_out;
181209
int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
182210
PyMem_Free(newbuf);

mypyc/test-data/run-base64.test

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[case testAllBase64Features_librt_experimental]
22
from typing import Any
33
import base64
4+
import binascii
45

56
from librt.base64 import b64encode, b64decode
67

@@ -93,6 +94,39 @@ def test_decode_with_non_base64_chars() -> None:
9394

9495
check_decode(b" e A = = ", encoded=True)
9596

97+
# Special case: Two different encodings of the same data
98+
check_decode(b"eAa=", encoded=True)
99+
check_decode(b"eAY=", encoded=True)
100+
101+
def check_decode_error(b: bytes, ignore_stdlib: bool = False) -> None:
102+
if not ignore_stdlib:
103+
with assertRaises(binascii.Error):
104+
getattr(base64, "b64decode")(b)
105+
106+
# The raised error is different, since librt shouldn't depend on binascii
107+
with assertRaises(ValueError):
108+
b64decode(b)
109+
110+
def test_decode_with_invalid_padding() -> None:
111+
check_decode_error(b"eA")
112+
check_decode_error(b"eA=")
113+
check_decode_error(b"eHk")
114+
check_decode_error(b"eA = ")
115+
116+
# Here stdlib behavior seems nonsensical, so we don't try to duplicate it
117+
check_decode_error(b"eA=a=", ignore_stdlib=True)
118+
119+
def test_decode_with_extra_data_after_padding() -> None:
120+
check_decode(b"=", encoded=True)
121+
check_decode(b"==", encoded=True)
122+
check_decode(b"===", encoded=True)
123+
check_decode(b"====", encoded=True)
124+
check_decode(b"eA===", encoded=True)
125+
check_decode(b"eHk==", encoded=True)
126+
check_decode(b"eA==x", encoded=True)
127+
check_decode(b"eHk=x", encoded=True)
128+
check_decode(b"eA==abc=======efg", encoded=True)
129+
96130
def test_decode_wrapper() -> None:
97131
dec: Any = b64decode
98132
assert dec(b"eA==") == b"x"

0 commit comments

Comments
 (0)