Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
RTuple,
RType,
bool_rprimitive,
bytes_rprimitive,
c_int_rprimitive,
dict_rprimitive,
int16_rprimitive,
Expand Down Expand Up @@ -98,6 +99,9 @@
from mypyc.primitives.misc_ops import isinstance_bool
from mypyc.primitives.set_ops import isinstance_frozenset, isinstance_set
from mypyc.primitives.str_ops import (
bytes_decode_ascii_strict,
bytes_decode_latin1_strict,
bytes_decode_utf8_strict,
isinstance_str,
str_encode_ascii_strict,
str_encode_latin1_strict,
Expand Down Expand Up @@ -787,6 +791,67 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
return None


@specialize_function("decode", bytes_rprimitive)
def bytes_decode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
"""Specialize common cases of obj.decode for most used encodings and strict errors."""

if not isinstance(callee, MemberExpr):
return None

# We can only specialize if we have string literals as args
if len(expr.arg_kinds) > 0 and not isinstance(expr.args[0], StrExpr):
return None
if len(expr.arg_kinds) > 1 and not isinstance(expr.args[1], StrExpr):
return None

encoding = "utf8"
errors = "strict"
if len(expr.arg_kinds) > 0 and isinstance(expr.args[0], StrExpr):
if expr.arg_kinds[0] == ARG_NAMED:
if expr.arg_names[0] == "encoding":
encoding = expr.args[0].value
elif expr.arg_names[0] == "errors":
errors = expr.args[0].value
elif expr.arg_kinds[0] == ARG_POS:
encoding = expr.args[0].value
else:
return None
if len(expr.arg_kinds) > 1 and isinstance(expr.args[1], StrExpr):
if expr.arg_kinds[1] == ARG_NAMED:
if expr.arg_names[1] == "encoding":
encoding = expr.args[1].value
elif expr.arg_names[1] == "errors":
errors = expr.args[1].value
elif expr.arg_kinds[1] == ARG_POS:
errors = expr.args[1].value
else:
return None

if errors != "strict":
# We can only specialize strict errors
return None

encoding = encoding.lower().replace("_", "-") # normalize
# Specialized encodings and their accepted aliases
if encoding in ["u8", "utf", "utf8", "utf-8", "cp65001"]:
return builder.call_c(bytes_decode_utf8_strict, [builder.accept(callee.expr)], expr.line)
elif encoding in ["646", "ascii", "usascii", "us-ascii"]:
return builder.call_c(bytes_decode_ascii_strict, [builder.accept(callee.expr)], expr.line)
elif encoding in [
"iso8859-1",
"iso-8859-1",
"8859",
"cp819",
"latin",
"latin1",
"latin-1",
"l1",
]:
return builder.call_c(bytes_decode_latin1_strict, [builder.accept(callee.expr)], expr.line)

return None


@specialize_function("mypy_extensions.i64")
def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:
Expand Down
3 changes: 3 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,9 @@ PyObject *CPyStr_Removesuffix(PyObject *self, PyObject *suffix);
bool CPyStr_IsTrue(PyObject *obj);
Py_ssize_t CPyStr_Size_size_t(PyObject *str);
PyObject *CPy_Decode(PyObject *obj, PyObject *encoding, PyObject *errors);
PyObject *CPy_DecodeUTF8(PyObject *bytes);
PyObject *CPy_DecodeASCII(PyObject *bytes);
PyObject *CPy_DecodeLatin1(PyObject *bytes);
PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors);
Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start);
Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end);
Expand Down
39 changes: 39 additions & 0 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,45 @@ PyObject *CPy_Decode(PyObject *obj, PyObject *encoding, PyObject *errors) {
}
}

PyObject *CPy_DecodeUTF8(PyObject *bytes) {
if (PyBytes_CheckExact(bytes)) {
char *buffer = PyBytes_AsString(bytes); // Borrowed reference
if (buffer == NULL) {
return NULL;
}
Py_ssize_t size = PyBytes_Size(bytes);
return PyUnicode_DecodeUTF8(buffer, size, "strict");
} else {
return PyUnicode_FromEncodedObject(bytes, "utf-8", "strict");
}
}

PyObject *CPy_DecodeASCII(PyObject *bytes) {
if (PyBytes_CheckExact(bytes)) {
char *buffer = PyBytes_AsString(bytes); // Borrowed reference
if (buffer == NULL) {
return NULL;
}
Py_ssize_t size = PyBytes_Size(bytes);
return PyUnicode_DecodeASCII(buffer, size, "strict");;
} else {
return PyUnicode_FromEncodedObject(bytes, "ascii", "strict");
}
}

PyObject *CPy_DecodeLatin1(PyObject *bytes) {
if (PyBytes_CheckExact(bytes)) {
char *buffer = PyBytes_AsString(bytes); // Borrowed reference
if (buffer == NULL) {
return NULL;
}
Py_ssize_t size = PyBytes_Size(bytes);
return PyUnicode_DecodeLatin1(buffer, size, "strict");
} else {
return PyUnicode_FromEncodedObject(bytes, "latin1", "strict");
}
}

PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors) {
const char *enc = NULL;
const char *err = NULL;
Expand Down
26 changes: 25 additions & 1 deletion mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@
extra_int_constants=[(0, pointer_rprimitive)],
)

# obj.decode(encoding, errors)
# bytes.decode(encoding, errors)
method_op(
name="decode",
arg_types=[bytes_rprimitive, str_rprimitive, str_rprimitive],
Expand All @@ -396,6 +396,30 @@
error_kind=ERR_MAGIC,
)

# bytes.decode(encoding) - utf8 strict specialization
bytes_decode_utf8_strict = custom_op(
arg_types=[bytes_rprimitive],
return_type=str_rprimitive,
c_function_name="CPy_DecodeUTF8",
error_kind=ERR_MAGIC,
)

# bytes.decode(encoding) - ascii strict specialization
bytes_decode_ascii_strict = custom_op(
arg_types=[bytes_rprimitive],
return_type=str_rprimitive,
c_function_name="CPy_DecodeASCII",
error_kind=ERR_MAGIC,
)

# bytes.decode(encoding) - latin1 strict specialization
bytes_decode_latin1_strict = custom_op(
arg_types=[bytes_rprimitive],
return_type=str_rprimitive,
c_function_name="CPy_DecodeLatin1",
error_kind=ERR_MAGIC,
)

# str.encode()
method_op(
name="encode",
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __getitem__(self, i: int) -> int: ...
@overload
def __getitem__(self, i: slice) -> bytes: ...
def join(self, x: Iterable[object]) -> bytes: ...
def decode(self, x: str=..., y: str=...) -> str: ...
def decode(self, encoding: str=..., errors: str=...) -> str: ...
def __iter__(self) -> Iterator[int]: ...

class bytearray:
Expand Down
38 changes: 31 additions & 7 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -325,19 +325,43 @@ L0:
[case testDecode]
def f(b: bytes) -> None:
b.decode()
b.decode('Utf_8')
b.decode('utf-8')
b.decode('UTF8')
b.decode('latin1')
b.decode('Latin-1')
b.decode('ascii')
encoding = 'utf-8'
b.decode(encoding)
b.decode('utf-8', 'backslashreplace')
def variants(b: bytes) -> None:
b.decode(encoding="UTF_8")
b.decode("ascii", errors="strict")
[out]
def f(b):
b :: bytes
r0, r1, r2, r3, r4, r5 :: str
r0, r1, r2, r3, r4, r5, r6, r7, encoding, r8, r9, r10, r11 :: str
L0:
r0 = CPy_Decode(b, 0, 0)
r1 = 'utf-8'
r2 = CPy_Decode(b, r1, 0)
r3 = 'utf-8'
r4 = 'backslashreplace'
r5 = CPy_Decode(b, r3, r4)
r0 = CPy_DecodeUTF8(b)
r1 = CPy_DecodeUTF8(b)
r2 = CPy_DecodeUTF8(b)
r3 = CPy_DecodeUTF8(b)
r4 = CPy_DecodeLatin1(b)
r5 = CPy_DecodeLatin1(b)
r6 = CPy_DecodeASCII(b)
r7 = 'utf-8'
encoding = r7
r8 = CPy_Decode(b, encoding, 0)
r9 = 'utf-8'
r10 = 'backslashreplace'
r11 = CPy_Decode(b, r9, r10)
return 1
def variants(b):
b :: bytes
r0, r1 :: str
L0:
r0 = CPy_DecodeUTF8(b)
r1 = CPy_DecodeASCII(b)
return 1

[case testEncode_64bit]
Expand Down
74 changes: 69 additions & 5 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -792,14 +792,23 @@ def test_ord() -> None:
ord('')

[case testDecode]
from testutil import assertRaises

def test_decode() -> None:
assert "\N{GREEK CAPITAL LETTER DELTA}" == '\u0394'
assert "\u0394" == "\u0394"
assert "\U00000394" == '\u0394'
assert b'\x80abc'.decode('utf-8', 'replace') == '\ufffdabc'
assert b'\x80abc'.decode('utf-8', 'backslashreplace') == '\\x80abc'
assert b''.decode() == ''
assert b'a'.decode() == 'a'
assert b'abc'.decode() == 'abc'
assert b'abc'.decode('utf-8') == 'abc'
assert b'abc'.decode('utf-8' + str()) == 'abc'
assert b'abc\x00\xce'.decode('latin-1') == 'abc\x00\xce'
assert b'abc\x00\xce'.decode('latin-1' + str()) == 'abc\x00\xce'
assert b'abc\x00\x7f'.decode('ascii') == 'abc\x00\x7f'
assert b'abc\x00\x7f'.decode('ascii' + str()) == 'abc\x00\x7f'
assert b'\x80abc'.decode('utf-8', 'ignore') == 'abc'
assert b'\x80abc'.decode('UTF-8', 'ignore') == 'abc'
assert b'\x80abc'.decode('Utf-8', 'ignore') == 'abc'
Expand All @@ -808,16 +817,71 @@ def test_decode() -> None:
assert b'\xd2\xbb\xb6\xfe\xc8\xfd'.decode('gbk', 'ignore') == '一二三'
assert b'\xd2\xbb\xb6\xfe\xc8\xfd'.decode('latin1', 'ignore') == 'Ò»¶þÈý'
assert b'Z\xc3\xbcrich'.decode("utf-8") == 'Zürich'
try:
b'Z\xc3\xbcrich'.decode('ascii')
assert False
except UnicodeDecodeError:
pass
assert b'Z\xc3\xbcrich'.decode("utf-8" + str()) == 'Zürich'

assert bytearray(range(5)).decode() == '\x00\x01\x02\x03\x04'
b = bytearray(b'\xe4\xbd\xa0\xe5\xa5\xbd')
assert b.decode() == '你好'
assert b.decode('gbk') == '浣犲ソ'
assert b.decode('latin1') == 'ä½\xa0好'
assert b.decode('latin1' + str()) == 'ä½\xa0好'

def test_decode_error() -> None:
try:
b'Z\xc3\xbcrich'.decode('ascii')
assert False
except UnicodeDecodeError:
pass
try:
b'Z\xc3\xbcrich'.decode('ascii' + str())
assert False
except UnicodeDecodeError:
pass
try:
b'Z\xc3y'.decode('utf8')
assert False
except UnicodeDecodeError:
pass
try:
b'Z\xc3y'.decode('utf8' + str())
assert False
except UnicodeDecodeError:
pass

def test_decode_bytearray() -> None:
b: bytes = bytearray(b'foo\x00bar')
assert b.decode() == 'foo\x00bar'
assert b.decode('utf-8') == 'foo\x00bar'
assert b.decode('latin-1') == 'foo\x00bar'
assert b.decode('ascii') == 'foo\x00bar'
assert b.decode('utf-8' + str()) == 'foo\x00bar'
assert b.decode('latin-1' + str()) == 'foo\x00bar'
assert b.decode('ascii' + str()) == 'foo\x00bar'
b2: bytes = bytearray(b'foo\x00bar\xbe')
assert b2.decode('latin-1') == 'foo\x00bar\xbe'
with assertRaises(UnicodeDecodeError):
b2.decode('ascii')
with assertRaises(UnicodeDecodeError):
b2.decode('ascii' + str())
with assertRaises(UnicodeDecodeError):
b2.decode('utf-8')
with assertRaises(UnicodeDecodeError):
b2.decode('utf-8' + str())
b3: bytes = bytearray(b'Z\xc3\xbcrich')
assert b3.decode("utf-8") == 'Zürich'

def test_invalid_encoding() -> None:
try:
b"foo".decode("ut-f-8")
assert False
except Exception as e:
assert repr(e).startswith("LookupError")
try:
encoding = "ut-f-8"
b"foo".decode(encoding)
assert False
except Exception as e:
assert repr(e).startswith("LookupError")

[case testEncode]
from testutil import assertRaises
Expand Down
Loading