Skip to content

Commit 68809c0

Browse files
authored
[mypyc] Specialize bytes.decode calls with common encodings (#19688)
This is similar to #18232, which specialized `encode`. A micro-benchmark that calls `decode` repeatedly was up to 45% faster.
1 parent 6c5b13c commit 68809c0

File tree

7 files changed

+233
-14
lines changed

7 files changed

+233
-14
lines changed

mypyc/irbuild/specialize.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
RTuple,
5151
RType,
5252
bool_rprimitive,
53+
bytes_rprimitive,
5354
c_int_rprimitive,
5455
dict_rprimitive,
5556
int16_rprimitive,
@@ -98,6 +99,9 @@
9899
from mypyc.primitives.misc_ops import isinstance_bool
99100
from mypyc.primitives.set_ops import isinstance_frozenset, isinstance_set
100101
from mypyc.primitives.str_ops import (
102+
bytes_decode_ascii_strict,
103+
bytes_decode_latin1_strict,
104+
bytes_decode_utf8_strict,
101105
isinstance_str,
102106
str_encode_ascii_strict,
103107
str_encode_latin1_strict,
@@ -787,6 +791,67 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
787791
return None
788792

789793

794+
@specialize_function("decode", bytes_rprimitive)
def bytes_decode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Specialize obj.decode(...) for the most common encodings with strict errors.

    The call is only specialized when every argument is a string literal, the
    (possibly defaulted) error handler is "strict", and the encoding name is one
    of the recognized spellings of UTF-8, ASCII or Latin-1.

    Args:
        builder: IR builder used to emit the specialized C call.
        expr: The call expression being compiled.
        callee: The callee expression; must be a member expression (x.decode).

    Returns:
        The value produced by the specialized primitive, or None to indicate
        that the call should be compiled through the generic path.
    """
    if not isinstance(callee, MemberExpr):
        return None

    # decode() accepts at most (encoding, errors). Anything else is left to the
    # generic path instead of being silently ignored here.
    if len(expr.args) > 2:
        return None

    encoding: str | None = None
    errors: str | None = None
    for index, (arg, kind, name) in enumerate(zip(expr.args, expr.arg_kinds, expr.arg_names)):
        # We can only specialize if we have string literals as args
        if not isinstance(arg, StrExpr):
            return None

        if kind == ARG_POS:
            # Positional order is (encoding, errors).
            is_encoding = index == 0
        elif kind == ARG_NAMED and name in ("encoding", "errors"):
            is_encoding = name == "encoding"
        else:
            # Star args or an unexpected keyword: don't guess.
            return None

        if is_encoding:
            if encoding is not None:
                # Given both positionally and by keyword.
                return None
            encoding = arg.value
        else:
            if errors is not None:
                return None
            errors = arg.value

    if errors is not None and errors != "strict":
        # We can only specialize strict errors
        return None

    # Normalize case and separator; the default encoding is UTF-8.
    normalized = ("utf8" if encoding is None else encoding).lower().replace("_", "-")

    # Specialized encodings and their accepted aliases. Only spellings that
    # CPython itself resolves are listed; "usascii" is deliberately absent since
    # it does not appear to be a registered alias, and the generic path will
    # report whatever CPython reports for it.
    if normalized in ("u8", "utf", "utf8", "utf-8", "cp65001"):
        decode_op = bytes_decode_utf8_strict
    elif normalized in ("646", "ascii", "us-ascii"):
        decode_op = bytes_decode_ascii_strict
    elif normalized in (
        "iso8859-1",
        "iso-8859-1",
        "8859",
        "cp819",
        "latin",
        "latin1",
        "latin-1",
        "l1",
    ):
        decode_op = bytes_decode_latin1_strict
    else:
        return None

    return builder.call_c(decode_op, [builder.accept(callee.expr)], expr.line)
853+
854+
790855
@specialize_function("mypy_extensions.i64")
791856
def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
792857
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:

mypyc/lib-rt/CPy.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,9 @@ PyObject *CPyStr_Removesuffix(PyObject *self, PyObject *suffix);
752752
bool CPyStr_IsTrue(PyObject *obj);
753753
Py_ssize_t CPyStr_Size_size_t(PyObject *str);
754754
PyObject *CPy_Decode(PyObject *obj, PyObject *encoding, PyObject *errors);
755+
PyObject *CPy_DecodeUTF8(PyObject *bytes);
756+
PyObject *CPy_DecodeASCII(PyObject *bytes);
757+
PyObject *CPy_DecodeLatin1(PyObject *bytes);
755758
PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors);
756759
Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start);
757760
Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end);

mypyc/lib-rt/str_ops.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,45 @@ PyObject *CPy_Decode(PyObject *obj, PyObject *encoding, PyObject *errors) {
513513
}
514514
}
515515

516+
// Decode 'bytes' as UTF-8 with the "strict" error handler.
// Returns the result of the underlying decoder, or NULL if the buffer
// could not be obtained.
PyObject *CPy_DecodeUTF8(PyObject *bytes) {
    if (!PyBytes_CheckExact(bytes)) {
        // Not an exact bytes object: use the generic encoded-object API.
        return PyUnicode_FromEncodedObject(bytes, "utf-8", "strict");
    }

    // Pointer into the bytes object's own storage; nothing to free here.
    char *data = PyBytes_AsString(bytes);
    if (data == NULL) {
        return NULL;
    }
    Py_ssize_t length = PyBytes_Size(bytes);
    return PyUnicode_DecodeUTF8(data, length, "strict");
}
528+
529+
// Decode 'bytes' as ASCII with the "strict" error handler.
// Returns the result of the underlying decoder, or NULL if the buffer
// could not be obtained.
PyObject *CPy_DecodeASCII(PyObject *bytes) {
    if (PyBytes_CheckExact(bytes)) {
        // Pointer into the bytes object's own storage; nothing to free here.
        char *buffer = PyBytes_AsString(bytes);
        if (buffer == NULL) {
            return NULL;
        }
        Py_ssize_t size = PyBytes_Size(bytes);
        return PyUnicode_DecodeASCII(buffer, size, "strict");
    } else {
        // Not an exact bytes object: use the generic encoded-object API.
        return PyUnicode_FromEncodedObject(bytes, "ascii", "strict");
    }
}
541+
542+
// Decode 'bytes' as Latin-1 with the "strict" error handler.
// Returns the result of the underlying decoder, or NULL if the buffer
// could not be obtained.
PyObject *CPy_DecodeLatin1(PyObject *bytes) {
    if (!PyBytes_CheckExact(bytes)) {
        // Not an exact bytes object: use the generic encoded-object API.
        return PyUnicode_FromEncodedObject(bytes, "latin1", "strict");
    }

    // Pointer into the bytes object's own storage; nothing to free here.
    char *data = PyBytes_AsString(bytes);
    if (data == NULL) {
        return NULL;
    }
    Py_ssize_t length = PyBytes_Size(bytes);
    return PyUnicode_DecodeLatin1(data, length, "strict");
}
554+
516555
PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors) {
517556
const char *enc = NULL;
518557
const char *err = NULL;

mypyc/primitives/str_ops.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@
387387
extra_int_constants=[(0, pointer_rprimitive)],
388388
)
389389

390-
# obj.decode(encoding, errors)
390+
# bytes.decode(encoding, errors)
391391
method_op(
392392
name="decode",
393393
arg_types=[bytes_rprimitive, str_rprimitive, str_rprimitive],
@@ -396,6 +396,30 @@
396396
error_kind=ERR_MAGIC,
397397
)
398398

399+
# bytes.decode() specialized for UTF-8 with strict errors.
# Takes only the bytes object; implemented by the C helper CPy_DecodeUTF8.
bytes_decode_utf8_strict = custom_op(
    c_function_name="CPy_DecodeUTF8",
    arg_types=[bytes_rprimitive],
    return_type=str_rprimitive,
    error_kind=ERR_MAGIC,
)

# bytes.decode() specialized for ASCII with strict errors.
# Takes only the bytes object; implemented by the C helper CPy_DecodeASCII.
bytes_decode_ascii_strict = custom_op(
    c_function_name="CPy_DecodeASCII",
    arg_types=[bytes_rprimitive],
    return_type=str_rprimitive,
    error_kind=ERR_MAGIC,
)

# bytes.decode() specialized for Latin-1 with strict errors.
# Takes only the bytes object; implemented by the C helper CPy_DecodeLatin1.
bytes_decode_latin1_strict = custom_op(
    c_function_name="CPy_DecodeLatin1",
    arg_types=[bytes_rprimitive],
    return_type=str_rprimitive,
    error_kind=ERR_MAGIC,
)
422+
399423
# str.encode()
400424
method_op(
401425
name="encode",

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def __getitem__(self, i: int) -> int: ...
171171
@overload
172172
def __getitem__(self, i: slice) -> bytes: ...
173173
def join(self, x: Iterable[object]) -> bytes: ...
174-
def decode(self, x: str=..., y: str=...) -> str: ...
174+
def decode(self, encoding: str=..., errors: str=...) -> str: ...
175175
def __iter__(self) -> Iterator[int]: ...
176176

177177
class bytearray:

mypyc/test-data/irbuild-str.test

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,19 +325,43 @@ L0:
325325
[case testDecode]
326326
def f(b: bytes) -> None:
327327
b.decode()
328+
b.decode('Utf_8')
328329
b.decode('utf-8')
330+
b.decode('UTF8')
331+
b.decode('latin1')
332+
b.decode('Latin-1')
333+
b.decode('ascii')
334+
encoding = 'utf-8'
335+
b.decode(encoding)
329336
b.decode('utf-8', 'backslashreplace')
337+
def variants(b: bytes) -> None:
338+
b.decode(encoding="UTF_8")
339+
b.decode("ascii", errors="strict")
330340
[out]
331341
def f(b):
332342
b :: bytes
333-
r0, r1, r2, r3, r4, r5 :: str
343+
r0, r1, r2, r3, r4, r5, r6, r7, encoding, r8, r9, r10, r11 :: str
334344
L0:
335-
r0 = CPy_Decode(b, 0, 0)
336-
r1 = 'utf-8'
337-
r2 = CPy_Decode(b, r1, 0)
338-
r3 = 'utf-8'
339-
r4 = 'backslashreplace'
340-
r5 = CPy_Decode(b, r3, r4)
345+
r0 = CPy_DecodeUTF8(b)
346+
r1 = CPy_DecodeUTF8(b)
347+
r2 = CPy_DecodeUTF8(b)
348+
r3 = CPy_DecodeUTF8(b)
349+
r4 = CPy_DecodeLatin1(b)
350+
r5 = CPy_DecodeLatin1(b)
351+
r6 = CPy_DecodeASCII(b)
352+
r7 = 'utf-8'
353+
encoding = r7
354+
r8 = CPy_Decode(b, encoding, 0)
355+
r9 = 'utf-8'
356+
r10 = 'backslashreplace'
357+
r11 = CPy_Decode(b, r9, r10)
358+
return 1
359+
def variants(b):
360+
b :: bytes
361+
r0, r1 :: str
362+
L0:
363+
r0 = CPy_DecodeUTF8(b)
364+
r1 = CPy_DecodeASCII(b)
341365
return 1
342366

343367
[case testEncode_64bit]

mypyc/test-data/run-strings.test

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -792,14 +792,23 @@ def test_ord() -> None:
792792
ord('')
793793

794794
[case testDecode]
795+
from testutil import assertRaises
796+
795797
def test_decode() -> None:
796798
assert "\N{GREEK CAPITAL LETTER DELTA}" == '\u0394'
797799
assert "\u0394" == "\u0394"
798800
assert "\U00000394" == '\u0394'
799801
assert b'\x80abc'.decode('utf-8', 'replace') == '\ufffdabc'
800802
assert b'\x80abc'.decode('utf-8', 'backslashreplace') == '\\x80abc'
803+
assert b''.decode() == ''
804+
assert b'a'.decode() == 'a'
801805
assert b'abc'.decode() == 'abc'
802806
assert b'abc'.decode('utf-8') == 'abc'
807+
assert b'abc'.decode('utf-8' + str()) == 'abc'
808+
assert b'abc\x00\xce'.decode('latin-1') == 'abc\x00\xce'
809+
assert b'abc\x00\xce'.decode('latin-1' + str()) == 'abc\x00\xce'
810+
assert b'abc\x00\x7f'.decode('ascii') == 'abc\x00\x7f'
811+
assert b'abc\x00\x7f'.decode('ascii' + str()) == 'abc\x00\x7f'
803812
assert b'\x80abc'.decode('utf-8', 'ignore') == 'abc'
804813
assert b'\x80abc'.decode('UTF-8', 'ignore') == 'abc'
805814
assert b'\x80abc'.decode('Utf-8', 'ignore') == 'abc'
@@ -808,16 +817,71 @@ def test_decode() -> None:
808817
assert b'\xd2\xbb\xb6\xfe\xc8\xfd'.decode('gbk', 'ignore') == '一二三'
809818
assert b'\xd2\xbb\xb6\xfe\xc8\xfd'.decode('latin1', 'ignore') == 'Ò»¶þÈý'
810819
assert b'Z\xc3\xbcrich'.decode("utf-8") == 'Zürich'
811-
try:
812-
b'Z\xc3\xbcrich'.decode('ascii')
813-
assert False
814-
except UnicodeDecodeError:
815-
pass
820+
assert b'Z\xc3\xbcrich'.decode("utf-8" + str()) == 'Zürich'
821+
816822
assert bytearray(range(5)).decode() == '\x00\x01\x02\x03\x04'
817823
b = bytearray(b'\xe4\xbd\xa0\xe5\xa5\xbd')
818824
assert b.decode() == '你好'
819825
assert b.decode('gbk') == '浣犲ソ'
820826
assert b.decode('latin1') == 'ä½\xa0好'
827+
assert b.decode('latin1' + str()) == 'ä½\xa0好'
828+
829+
def test_decode_error() -> None:
    # Every invalid input is decoded twice: once with a literal encoding name
    # and once with a name built at run time via `+ str()`, which is not a
    # plain string literal (presumably so both the specialized and the generic
    # decode paths are covered).

    # Non-ASCII bytes must be rejected by the ascii codec.
    raised = False
    try:
        b'Z\xc3\xbcrich'.decode('ascii')
    except UnicodeDecodeError:
        raised = True
    assert raised

    raised = False
    try:
        b'Z\xc3\xbcrich'.decode('ascii' + str())
    except UnicodeDecodeError:
        raised = True
    assert raised

    # A truncated multi-byte sequence must be rejected by the utf8 codec.
    raised = False
    try:
        b'Z\xc3y'.decode('utf8')
    except UnicodeDecodeError:
        raised = True
    assert raised

    raised = False
    try:
        b'Z\xc3y'.decode('utf8' + str())
    except UnicodeDecodeError:
        raised = True
    assert raised
850+
851+
def test_decode_bytearray() -> None:
    # A bytearray stored in a bytes-typed variable must decode like bytes,
    # both with literal encoding names and with names built at run time.
    plain: bytes = bytearray(b'foo\x00bar')
    assert plain.decode() == 'foo\x00bar'
    assert plain.decode('utf-8') == 'foo\x00bar'
    assert plain.decode('latin-1') == 'foo\x00bar'
    assert plain.decode('ascii') == 'foo\x00bar'
    assert plain.decode('utf-8' + str()) == 'foo\x00bar'
    assert plain.decode('latin-1' + str()) == 'foo\x00bar'
    assert plain.decode('ascii' + str()) == 'foo\x00bar'

    # A trailing 0xbe byte is fine for latin-1 but invalid for ascii and utf-8.
    high_byte: bytes = bytearray(b'foo\x00bar\xbe')
    assert high_byte.decode('latin-1') == 'foo\x00bar\xbe'
    with assertRaises(UnicodeDecodeError):
        high_byte.decode('ascii')
    with assertRaises(UnicodeDecodeError):
        high_byte.decode('ascii' + str())
    with assertRaises(UnicodeDecodeError):
        high_byte.decode('utf-8')
    with assertRaises(UnicodeDecodeError):
        high_byte.decode('utf-8' + str())

    # Valid multi-byte UTF-8 content.
    multibyte: bytes = bytearray(b'Z\xc3\xbcrich')
    assert multibyte.decode("utf-8") == 'Zürich'
872+
873+
def test_invalid_encoding() -> None:
    # An unknown encoding name must surface as a LookupError, whether the name
    # is given as a literal or through a variable.
    outcome = ""
    try:
        b"foo".decode("ut-f-8")
        assert False
    except Exception as exc:
        outcome = repr(exc)
    assert outcome.startswith("LookupError")

    outcome = ""
    try:
        encoding = "ut-f-8"
        b"foo".decode(encoding)
        assert False
    except Exception as exc:
        outcome = repr(exc)
    assert outcome.startswith("LookupError")
821885

822886
[case testEncode]
823887
from testutil import assertRaises

0 commit comments

Comments
 (0)