Skip to content

Commit 0c52510

Browse files
serhiy-storchakamcepl
authored andcommitted
Fix use-after-free in the unicode-escape decoder with error handler
If the error handler is used, a new bytes object is created to set as the object attribute of UnicodeDecodeError, and that bytes object then replaces the original data. A pointer to the decoded data will became invalid after destroying that temporary bytes object. So we need other way to return the first invalid escape from _PyUnicode_DecodeUnicodeEscapeInternal(). _PyBytes_DecodeEscape() does not have such issue, because it does not use the error handlers registry, but it should be changed for compatibility with _PyUnicode_DecodeUnicodeEscapeInternal().
1 parent 310cd89 commit 0c52510

File tree

11 files changed

+182
-67
lines changed

11 files changed

+182
-67
lines changed

Include/internal/pycore_bytesobject.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ extern "C" {
88
# error "this header requires Py_BUILD_CORE define"
99
#endif
1010

11+
// Helper for PyBytes_DecodeEscape that detects invalid escape chars.
12+
// Export for test_peg_generator.
13+
PyAPI_FUNC(PyObject*) _PyBytes_DecodeEscape2(const char *, Py_ssize_t,
14+
const char *,
15+
int *, const char **);
1116

1217
/* Substring Search.
1318

Include/internal/pycore_unicodeobject.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ extern void _PyUnicode_ClearInterned(PyInterpreterState *interp);
7979
// Like PyUnicode_AsUTF8(), but check for embedded null characters.
8080
extern const char* _PyUnicode_AsUTF8NoNUL(PyObject *);
8181

82+
// Helper for PyUnicode_DecodeUnicodeEscape that detects invalid escape
83+
// chars.
84+
// Export for test_peg_generator.
85+
PyAPI_FUNC(PyObject*) _PyUnicode_DecodeUnicodeEscapeInternal2(
86+
const char *string, /* Unicode-Escape encoded string */
87+
Py_ssize_t length, /* size of string */
88+
const char *errors, /* error handling */
89+
Py_ssize_t *consumed, /* bytes consumed */
90+
int *first_invalid_escape_char, /* on return, if not -1, contain the first
91+
invalid escaped char (<= 0xff) or invalid
92+
octal escape (> 0xff) in string. */
93+
const char **first_invalid_escape_ptr); /* on return, if not NULL, may
94+
point to the first invalid escaped
95+
char in string.
96+
May be NULL if errors is not NULL. */
97+
8298

8399
#ifdef __cplusplus
84100
}

Lib/test/test_codeccallbacks.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import codecs
22
import html.entities
33
import itertools
4+
import re
45
import sys
56
import unicodedata
67
import unittest
@@ -1124,7 +1125,7 @@ def test_bug828737(self):
11241125
text = 'abc<def>ghi'*n
11251126
text.translate(charmap)
11261127

1127-
def test_mutatingdecodehandler(self):
1128+
def test_mutating_decode_handler(self):
11281129
baddata = [
11291130
("ascii", b"\xff"),
11301131
("utf-7", b"++"),
@@ -1159,6 +1160,42 @@ def mutating(exc):
11591160
for (encoding, data) in baddata:
11601161
self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242")
11611162

1163+
def test_mutating_decode_handler_unicode_escape(self):
1164+
decode = codecs.unicode_escape_decode
1165+
def mutating(exc):
1166+
if isinstance(exc, UnicodeDecodeError):
1167+
r = data.get(exc.object[:exc.end])
1168+
if r is not None:
1169+
exc.object = r[0] + exc.object[exc.end:]
1170+
return ('\u0404', r[1])
1171+
raise AssertionError("don't know how to handle %r" % exc)
1172+
1173+
codecs.register_error('test.mutating2', mutating)
1174+
data = {
1175+
br'\x0': (b'\\', 0),
1176+
br'\x3': (b'xxx\\', 3),
1177+
br'\x5': (b'x\\', 1),
1178+
}
1179+
def check(input, expected, msg):
1180+
with self.assertWarns(DeprecationWarning) as cm:
1181+
self.assertEqual(decode(input, 'test.mutating2'), (expected, len(input)))
1182+
self.assertIn(msg, str(cm.warning))
1183+
1184+
check(br'\x0n\z', '\u0404\n\\z', r"invalid escape sequence")
1185+
check(br'\x0n\501', '\u0404\n\u0141', r'invalid octal escape sequence')
1186+
check(br'\x0z', '\u0404\\z', r'invalid escape sequence')
1187+
1188+
check(br'\x3n\zr', '\u0404\n\\zr', r'invalid escape sequence')
1189+
check(br'\x3zr', '\u0404\\zr', r'invalid escape sequence')
1190+
check(br'\x3z5', '\u0404\\z5', r'invalid escape sequence')
1191+
check(memoryview(br'\x3z5x')[:-1], '\u0404\\z5', r'invalid escape sequence')
1192+
check(memoryview(br'\x3z5xy')[:-2], '\u0404\\z5', r'invalid escape sequence')
1193+
1194+
check(br'\x5n\z', '\u0404\n\\z', r'invalid escape sequence')
1195+
check(br'\x5n\501', '\u0404\n\u0141', r'invalid octal escape sequence')
1196+
check(br'\x5z', '\u0404\\z', r'invalid escape sequence')
1197+
check(memoryview(br'\x5zy')[:-1], '\u0404\\z', r'invalid escape sequence')
1198+
11621199
# issue32583
11631200
def test_crashing_decode_handler(self):
11641201
# better generating one more character to fill the extra space slot

Lib/test/test_codecs.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,23 +1196,39 @@ def test_escape(self):
11961196
check(br"[\1010]", b"[A0]")
11971197
check(br"[\x41]", b"[A]")
11981198
check(br"[\x410]", b"[A0]")
1199+
1200+
def test_warnings(self):
1201+
decode = codecs.escape_decode
1202+
check = coding_checker(self, decode)
11991203
for i in range(97, 123):
12001204
b = bytes([i])
12011205
if b not in b'abfnrtvx':
1202-
with self.assertWarns(DeprecationWarning):
1206+
with self.assertWarnsRegex(DeprecationWarning,
1207+
r"'\\%c' is an invalid escape sequence" % i):
12031208
check(b"\\" + b, b"\\" + b)
1204-
with self.assertWarns(DeprecationWarning):
1209+
with self.assertWarnsRegex(DeprecationWarning,
1210+
r"invalid escape sequence"):
12051211
check(b"\\" + b.upper(), b"\\" + b.upper())
1206-
with self.assertWarns(DeprecationWarning):
1212+
with self.assertWarnsRegex(DeprecationWarning,
1213+
r"'\\8' is an invalid escape sequence"):
12071214
check(br"\8", b"\\8")
12081215
with self.assertWarns(DeprecationWarning):
12091216
check(br"\9", b"\\9")
1210-
with self.assertWarns(DeprecationWarning):
1217+
with self.assertWarnsRegex(DeprecationWarning,
1218+
r'invalid escape sequence') as cm:
12111219
check(b"\\\xfa", b"\\\xfa")
12121220
for i in range(0o400, 0o1000):
1213-
with self.assertWarns(DeprecationWarning):
1221+
with self.assertWarnsRegex(DeprecationWarning,
1222+
r'invalid octal escape sequence'):
12141223
check(rb'\%o' % i, bytes([i & 0o377]))
12151224

1225+
with self.assertWarnsRegex(DeprecationWarning,
1226+
r'invalid escape sequence'):
1227+
self.assertEqual(decode(br'\x\z', 'ignore'), (b'\\z', 4))
1228+
with self.assertWarnsRegex(DeprecationWarning,
1229+
r'invalid octal escape sequence'):
1230+
self.assertEqual(decode(br'\x\501', 'ignore'), (b'A', 6))
1231+
12161232
def test_errors(self):
12171233
decode = codecs.escape_decode
12181234
self.assertRaises(ValueError, decode, br"\x")
@@ -2479,24 +2495,40 @@ def test_escape_decode(self):
24792495
check(br"[\x410]", "[A0]")
24802496
check(br"\u20ac", "\u20ac")
24812497
check(br"\U0001d120", "\U0001d120")
2498+
2499+
def test_decode_warnings(self):
2500+
decode = codecs.unicode_escape_decode
2501+
check = coding_checker(self, decode)
24822502
for i in range(97, 123):
24832503
b = bytes([i])
24842504
if b not in b'abfnrtuvx':
2485-
with self.assertWarns(DeprecationWarning):
2505+
with self.assertWarnsRegex(DeprecationWarning,
2506+
r'invalid escape sequence'):
24862507
check(b"\\" + b, "\\" + chr(i))
24872508
if b.upper() not in b'UN':
2488-
with self.assertWarns(DeprecationWarning):
2509+
with self.assertWarnsRegex(DeprecationWarning,
2510+
'invalid escape sequence'):
24892511
check(b"\\" + b.upper(), "\\" + chr(i-32))
2490-
with self.assertWarns(DeprecationWarning):
2512+
with self.assertWarnsRegex(DeprecationWarning,
2513+
r'invalid escape sequence'):
24912514
check(br"\8", "\\8")
24922515
with self.assertWarns(DeprecationWarning):
24932516
check(br"\9", "\\9")
2494-
with self.assertWarns(DeprecationWarning):
2517+
with self.assertWarnsRegex(DeprecationWarning,
2518+
r'invalid escape sequence') as cm:
24952519
check(b"\\\xfa", "\\\xfa")
24962520
for i in range(0o400, 0o1000):
2497-
with self.assertWarns(DeprecationWarning):
2521+
with self.assertWarnsRegex(DeprecationWarning,
2522+
r'invalid octal escape sequence'):
24982523
check(rb'\%o' % i, chr(i))
24992524

2525+
with self.assertWarnsRegex(DeprecationWarning,
2526+
r'invalid escape sequence'):
2527+
self.assertEqual(decode(br'\x\z', 'ignore'), ('\\z', 4))
2528+
with self.assertWarnsRegex(DeprecationWarning,
2529+
r'invalid octal escape sequence'):
2530+
self.assertEqual(decode(br'\x\501', 'ignore'), ('\u0141', 6))
2531+
25002532
def test_decode_errors(self):
25012533
decode = codecs.unicode_escape_decode
25022534
for c, d in (b'x', 2), (b'u', 4), (b'U', 4):

Lib/test/test_codeop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ def test_filename(self):
281281
def test_warning(self):
282282
# Test that the warning is only returned once.
283283
with warnings_helper.check_warnings(
284-
('"is" with \'str\' literal', SyntaxWarning),
285-
("invalid escape sequence", SyntaxWarning),
284+
(r'"is" with.*literal', SyntaxWarning),
285+
(r'invalid escape sequence', SyntaxWarning),
286286
) as w:
287287
compile_command(r"'\e' is 0")
288288
self.assertEqual(len(w.warnings), 2)

Lib/test/test_string_literals.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def test_eval_str_invalid_escape(self):
116116
warnings.simplefilter('always', category=SyntaxWarning)
117117
eval("'''\n\\z'''")
118118
self.assertEqual(len(w), 1)
119-
self.assertEqual(str(w[0].message), r"invalid escape sequence '\z'")
119+
self.assertEqual(str(w[0].message), r"'\z' is an invalid escape sequence. ")
120120
self.assertEqual(w[0].filename, '<string>')
121121
self.assertEqual(w[0].lineno, 2)
122122

@@ -153,7 +153,7 @@ def test_eval_str_invalid_octal_escape(self):
153153
eval("'''\n\\407'''")
154154
self.assertEqual(len(w), 1)
155155
self.assertEqual(str(w[0].message),
156-
r"invalid octal escape sequence '\407'")
156+
r"'\407' is an invalid octal escape sequence. ")
157157
self.assertEqual(w[0].filename, '<string>')
158158
self.assertEqual(w[0].lineno, 2)
159159

@@ -228,7 +228,7 @@ def test_eval_bytes_invalid_escape(self):
228228
warnings.simplefilter('always', category=SyntaxWarning)
229229
eval("b'''\n\\z'''")
230230
self.assertEqual(len(w), 1)
231-
self.assertEqual(str(w[0].message), r"invalid escape sequence '\z'")
231+
self.assertEqual(str(w[0].message), r"'\z' is an invalid escape sequence. ")
232232
self.assertEqual(w[0].filename, '<string>')
233233
self.assertEqual(w[0].lineno, 2)
234234

@@ -252,7 +252,7 @@ def test_eval_bytes_invalid_octal_escape(self):
252252
eval("b'''\n\\407'''")
253253
self.assertEqual(len(w), 1)
254254
self.assertEqual(str(w[0].message),
255-
r"invalid octal escape sequence '\407'")
255+
r"'\407' is an invalid octal escape sequence. ")
256256
self.assertEqual(w[0].filename, '<string>')
257257
self.assertEqual(w[0].lineno, 2)
258258

Lib/test/test_unparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def test_multiquote_joined_string(self):
653653

654654
def test_backslash_in_format_spec(self):
655655
import re
656-
msg = re.escape("invalid escape sequence '\\ '")
656+
msg = re.escape("invalid escape sequence")
657657
with self.assertWarnsRegex(SyntaxWarning, msg):
658658
self.check_ast_roundtrip("""f"{x:\\ }" """)
659659
self.check_ast_roundtrip("""f"{x:\\n}" """)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix use-after-free in the "unicode-escape" decoder with a non-"strict" error
2+
handler.

Objects/bytesobject.c

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,10 +1048,11 @@ _PyBytes_FormatEx(const char *format, Py_ssize_t format_len,
10481048
}
10491049

10501050
/* Unescape a backslash-escaped string. */
1051-
PyObject *_PyBytes_DecodeEscape(const char *s,
1051+
PyObject *_PyBytes_DecodeEscape2(const char *s,
10521052
Py_ssize_t len,
10531053
const char *errors,
1054-
const char **first_invalid_escape)
1054+
int *first_invalid_escape_char,
1055+
const char **first_invalid_escape_ptr)
10551056
{
10561057
int c;
10571058
char *p;
@@ -1065,7 +1066,8 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
10651066
return NULL;
10661067
writer.overallocate = 1;
10671068

1068-
*first_invalid_escape = NULL;
1069+
*first_invalid_escape_char = -1;
1070+
*first_invalid_escape_ptr = NULL;
10691071

10701072
end = s + len;
10711073
while (s < end) {
@@ -1103,9 +1105,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
11031105
c = (c<<3) + *s++ - '0';
11041106
}
11051107
if (c > 0377) {
1106-
if (*first_invalid_escape == NULL) {
1107-
*first_invalid_escape = s-3; /* Back up 3 chars, since we've
1108-
already incremented s. */
1108+
if (*first_invalid_escape_char == -1) {
1109+
*first_invalid_escape_char = c;
1110+
/* Back up 3 chars, since we've already incremented s. */
1111+
*first_invalid_escape_ptr = s - 3;
11091112
}
11101113
}
11111114
*p++ = c;
@@ -1146,9 +1149,10 @@ PyObject *_PyBytes_DecodeEscape(const char *s,
11461149
break;
11471150

11481151
default:
1149-
if (*first_invalid_escape == NULL) {
1150-
*first_invalid_escape = s-1; /* Back up one char, since we've
1151-
already incremented s. */
1152+
if (*first_invalid_escape_char == -1) {
1153+
*first_invalid_escape_char = (unsigned char)s[-1];
1154+
/* Back up one char, since we've already incremented s. */
1155+
*first_invalid_escape_ptr = s - 1;
11521156
}
11531157
*p++ = '\\';
11541158
s--;
@@ -1168,26 +1172,27 @@ PyObject *PyBytes_DecodeEscape(const char *s,
11681172
Py_ssize_t Py_UNUSED(unicode),
11691173
const char *Py_UNUSED(recode_encoding))
11701174
{
1171-
const char* first_invalid_escape;
1172-
PyObject *result = _PyBytes_DecodeEscape(s, len, errors,
1173-
&first_invalid_escape);
1175+
int first_invalid_escape_char;
1176+
const char *first_invalid_escape_ptr;
1177+
PyObject *result = _PyBytes_DecodeEscape2(s, len, errors,
1178+
&first_invalid_escape_char,
1179+
&first_invalid_escape_ptr);
11741180
if (result == NULL)
11751181
return NULL;
1176-
if (first_invalid_escape != NULL) {
1177-
unsigned char c = *first_invalid_escape;
1178-
if ('4' <= c && c <= '7') {
1182+
if (first_invalid_escape_char != -1) {
1183+
if (first_invalid_escape_char > 0xff) {
11791184
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
1180-
"invalid octal escape sequence '\\%.3s'",
1181-
first_invalid_escape) < 0)
1185+
"'\\%o' is an invalid octal escape sequence. ",
1186+
first_invalid_escape_char) < 0)
11821187
{
11831188
Py_DECREF(result);
11841189
return NULL;
11851190
}
11861191
}
11871192
else {
11881193
if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1,
1189-
"invalid escape sequence '\\%c'",
1190-
c) < 0)
1194+
"'\\%c' is an invalid escape sequence. ",
1195+
first_invalid_escape_char) < 0)
11911196
{
11921197
Py_DECREF(result);
11931198
return NULL;

0 commit comments

Comments
 (0)