Skip to content

Commit e6f8566

Browse files
committed
Fix urlsafe_b64decode bug and test more
1 parent 1eb4d3f commit e6f8566

File tree

2 files changed

+49
-18
lines changed

2 files changed

+49
-18
lines changed

mypyc/lib-rt/librt_base64.c

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ b64decode_handle_invalid_input(
1616
#define STACK_BUFFER_SIZE 1024
1717

1818
static void
19-
convert_encoded_to_urlsafe(char *buf, size_t actual_len) {
19+
convert_encoded_to_urlsafe(char *buf, size_t len) {
2020
// The loop is written to enable SIMD optimizations
21-
for (size_t i = 0; i < actual_len; i++) {
21+
for (size_t i = 0; i < len; i++) {
2222
char ch = buf[i];
2323
if (ch == '+') {
2424
ch = '-';
@@ -30,9 +30,9 @@ convert_encoded_to_urlsafe(char *buf, size_t actual_len) {
3030
}
3131

3232
static void
33-
convert_urlsafe_to_encoded(const char *src, size_t actual_len, char *buf) {
33+
convert_urlsafe_to_encoded(const char *src, size_t len, char *buf) {
3434
// The loop is written to enable SIMD optimizations
35-
for (size_t i = 0; i < actual_len; i++) {
35+
for (size_t i = 0; i < len; i++) {
3636
char ch = src[i];
3737
if (ch == '-') {
3838
ch = '+';
@@ -144,6 +144,15 @@ b64decode_internal(PyObject *arg, bool urlsafe) {
144144
return PyBytes_FromStringAndSize(NULL, 0);
145145
}
146146

147+
if (urlsafe) {
148+
char *new_src = PyMem_Malloc(srclen_ssz + 1);
149+
if (new_src == NULL) {
150+
return PyErr_NoMemory();
151+
}
152+
convert_urlsafe_to_encoded(src, srclen_ssz, new_src);
153+
src = new_src;
154+
}
155+
147156
// Quickly ignore invalid characters at the end. Other invalid characters
148157
// are also accepted, but they need a slow path.
149158
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1], true)) {
@@ -162,15 +171,6 @@ b64decode_internal(PyObject *arg, bool urlsafe) {
162171
return NULL;
163172
}
164173

165-
if (urlsafe) {
166-
char *new_src = PyMem_Malloc(srclen_ssz);
167-
if (new_src == NULL) {
168-
return PyErr_NoMemory();
169-
}
170-
convert_urlsafe_to_encoded(src, srclen_ssz, new_src);
171-
src = new_src;
172-
}
173-
174174
// Allocate output bytes (uninitialized) of the max capacity
175175
PyObject *out_bytes = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)max_out);
176176
if (out_bytes == NULL) {

mypyc/test-data/run-base64.test

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
from typing import Any
33
import base64
44
import binascii
5+
import random
56

67
from librt.base64 import b64encode, b64decode, urlsafe_b64encode, urlsafe_b64decode
78

89
from testutil import assertRaises
910

11+
rand_values = [random.randbytes(random.randint(1, 2000)) for _ in range(2000)]
12+
1013
def test_encode_basic() -> None:
1114
assert b64encode(b"x") == b"eA=="
1215

@@ -35,6 +38,9 @@ def test_encode_different_strings() -> None:
3538
for b in b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200:
3639
check_encode(b)
3740

41+
for b in rand_values:
42+
check_encode(b)
43+
3844
def test_encode_wrapper() -> None:
3945
enc: Any = b64encode
4046
assert enc(b"x") == b"eA=="
@@ -84,6 +90,9 @@ def test_decode_different_strings() -> None:
8490
for b in b"", b"ab", b"bac", b"1234", b"xyz88", b"abc" * 200:
8591
check_decode(b)
8692

93+
for b in rand_values:
94+
check_decode(b)
95+
8796
def is_base64_char(x: int) -> bool:
8897
c = chr(x)
8998
return ('a' <= c <= 'z') or ('A' <= c <= 'Z') or ('0' <= c <= '9') or c in '+/='
@@ -150,18 +159,40 @@ def test_decode_wrapper() -> None:
150159
with assertRaises(TypeError):
151160
dec(b"x", b"y")
152161

153-
def check_urlsafe_b64encode(b: bytes) -> None:
162+
def check_urlsafe_encode(b: bytes) -> None:
154163
assert urlsafe_b64encode(b) == getattr(base64, "urlsafe_b64encode")(b)
155164

156165
def test_urlsafe_b64encode() -> None:
157-
check_urlsafe_b64encode(bytes([x for x in range(256)]))
158-
159-
def check_urlsafe_b64decode(b: bytes) -> None:
166+
check_urlsafe_encode(b"")
167+
check_urlsafe_encode(b"a")
168+
check_urlsafe_encode(b"\xf8")
169+
check_urlsafe_encode(b"\xfc")
170+
check_urlsafe_encode(b"\xfcx")
171+
check_urlsafe_encode(b"\xfcxy")
172+
check_urlsafe_encode(b"\xfcxyz")
173+
check_urlsafe_encode(bytes([x for x in range(256)]))
174+
175+
for b in rand_values:
176+
check_urlsafe_encode(b)
177+
178+
def check_urlsafe_decode(b: bytes) -> None:
160179
enc = urlsafe_b64encode(b)
161180
assert urlsafe_b64decode(enc) == getattr(base64, "urlsafe_b64decode")(enc)
181+
enc2 = b64encode(b)
182+
assert urlsafe_b64decode(enc2) == getattr(base64, "urlsafe_b64decode")(enc2)
162183

163184
def test_urlsafe_b64decode() -> None:
164-
check_urlsafe_b64decode(bytes([x for x in range(256)]))
185+
check_urlsafe_decode(b"")
186+
check_urlsafe_decode(b"a")
187+
check_urlsafe_decode(b"\xf8")
188+
check_urlsafe_decode(b"\xfc")
189+
check_urlsafe_decode(b"\xfcx")
190+
check_urlsafe_decode(b"\xfcxy")
191+
check_urlsafe_decode(b"\xfcxyz")
192+
check_urlsafe_decode(bytes([x for x in range(256)]))
193+
194+
for b in rand_values:
195+
check_urlsafe_decode(b)
165196

166197
[case testBase64FeaturesNotAvailableInNonExperimentalBuild_librt_base64]
167198
# This also ensures librt.base64 can be built without experimental features

0 commit comments

Comments
 (0)