Skip to content

Commit c086137

Browse files
add latin1 and ascii decode functions
1 parent 7f1f197 commit c086137

File tree

6 files changed

+126
-35
lines changed

6 files changed

+126
-35
lines changed

mypyc/irbuild/specialize.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
RTuple,
5050
RType,
5151
bool_rprimitive,
52+
bytes_rprimitive,
5253
c_int_rprimitive,
5354
dict_rprimitive,
5455
int16_rprimitive,
@@ -89,6 +90,7 @@
8990
dict_setdefault_spec_init_op,
9091
dict_values_op,
9192
)
93+
from mypyc.primitives.bytes_ops import bytes_decode_utf8_strict, bytes_decode_latin1_strict, bytes_decode_ascii_strict
9294
from mypyc.primitives.list_ops import new_list_set_item_op
9395
from mypyc.primitives.str_ops import (
9496
str_encode_ascii_strict,
@@ -740,6 +742,52 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
740742
return None
741743

742744

745+
@specialize_function("decode", bytes_rprimitive)
746+
def bytes_decode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
747+
if not isinstance(callee, MemberExpr):
748+
return None
749+
750+
encoding = "utf8"
751+
errors = "strict"
752+
753+
# Handle up to 2 arguments: decode([encoding], [errors])
754+
if len(expr.arg_kinds) > 0 and isinstance(expr.args[0], StrExpr):
755+
if expr.arg_kinds[0] == ARG_NAMED:
756+
if expr.arg_names[0] == "encoding":
757+
encoding = expr.args[0].value
758+
elif expr.arg_names[0] == "errors":
759+
errors = expr.args[0].value
760+
elif expr.arg_kinds[0] == ARG_POS:
761+
encoding = expr.args[0].value
762+
else:
763+
return None
764+
765+
if len(expr.arg_kinds) > 1 and isinstance(expr.args[1], StrExpr):
766+
if expr.arg_kinds[1] == ARG_NAMED:
767+
if expr.arg_names[1] == "encoding":
768+
encoding = expr.args[1].value
769+
elif expr.arg_names[1] == "errors":
770+
errors = expr.args[1].value
771+
elif expr.arg_kinds[1] == ARG_POS:
772+
errors = expr.args[1].value
773+
else:
774+
return None
775+
776+
if errors != "strict":
777+
return None
778+
779+
normalized = encoding.lower().replace("-", "").replace("_", "")
780+
781+
if normalized in ("utf8", "utf", "u8", "cp65001"):
782+
return builder.primitive_op(bytes_decode_utf8_strict, [builder.accept(callee.expr)], expr.line)
783+
elif normalized in ("ascii", "usascii", "646"):
784+
return builder.primitive_op(bytes_decode_ascii_strict, [builder.accept(callee.expr)], expr.line)
785+
elif normalized in ("latin1", "latin", "iso88591", "cp819", "8859", "l1"):
786+
return builder.primitive_op(bytes_decode_latin1_strict, [builder.accept(callee.expr)], expr.line)
787+
788+
return None
789+
790+
743791
@specialize_function("mypy_extensions.i64")
744792
def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
745793
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,8 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
765765
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
766766
CPyTagged CPyBytes_Ord(PyObject *obj);
767767
PyObject *CPy_DecodeUtf8(PyObject *bytes_obj, const char *errors);
768+
PyObject *CPy_DecodeLatin1(PyObject *bytes_obj, const char *errors);
769+
PyObject *CPy_DecodeAscii(PyObject *bytes_obj, const char *errors);
768770

769771

770772
int CPyBytes_Compare(PyObject *left, PyObject *right);

mypyc/lib-rt/bytes_ops.c

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,29 @@ PyObject *CPy_DecodeUtf8(PyObject *bytes_obj, const char *errors) {
175175

176176
return PyUnicode_DecodeUTF8(data, size, errors);
177177
}
178+
179+
180+
PyObject *CPy_DecodeLatin1(PyObject *bytes_obj, const char *errors) {
181+
if (!PyBytes_Check(bytes_obj)) {
182+
PyErr_SetString(PyExc_TypeError, "expected bytes object");
183+
return NULL;
184+
}
185+
186+
char *data = PyBytes_AS_STRING(bytes_obj);
187+
Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
188+
189+
return PyUnicode_DecodeLatin1(data, size, errors);
190+
}
191+
192+
193+
PyObject *CPy_DecodeAscii(PyObject *bytes_obj, const char *errors) {
194+
if (!PyBytes_Check(bytes_obj)) {
195+
PyErr_SetString(PyExc_TypeError, "expected bytes object");
196+
return NULL;
197+
}
198+
199+
char *data = PyBytes_AS_STRING(bytes_obj);
200+
Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
201+
202+
return PyUnicode_DecodeASCII(data, size, errors);
203+
}

mypyc/primitives/bytes_ops.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ERR_NEG_INT,
1919
binary_op,
2020
custom_op,
21+
custom_primitive_op,
2122
function_op,
2223
load_address_op,
2324
method_op,
@@ -108,10 +109,26 @@
108109
error_kind=ERR_MAGIC,
109110
)
110111

111-
method_op(
112+
bytes_decode_utf8_strict = custom_primitive_op(
112113
name="decode",
113-
arg_types=[bytes_rprimitive, bytes_rprimitive],
114+
arg_types=[bytes_rprimitive, str_rprimitive],
114115
return_type=str_rprimitive,
115116
c_function_name="CPy_DecodeUtf8",
116117
error_kind=ERR_MAGIC,
117118
)
119+
120+
bytes_decode_latin1_strict = custom_primitive_op(
121+
name="decode_latin1",
122+
arg_types=[bytes_rprimitive, str_rprimitive],
123+
return_type=str_rprimitive,
124+
c_function_name="CPy_DecodeLatin1",
125+
error_kind=ERR_MAGIC,
126+
)
127+
128+
bytes_decode_ascii_strict = custom_primitive_op(
129+
name="decode_ascii",
130+
arg_types=[bytes_rprimitive, str_rprimitive],
131+
return_type=str_rprimitive,
132+
c_function_name="CPy_DecodeAscii",
133+
error_kind=ERR_MAGIC,
134+
)

mypyc/test-data/irbuild-bytes.test

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,38 @@ L0:
186186
b4 = r10
187187
return 1
188188

189-
[case testDecodeUtf8]
190-
def f(b: bytes) -> str:
191-
return b.decode("utf-8")
189+
[case testDecodeBytes]
190+
def f(b: bytes) -> None:
191+
b.decode()
192+
b.decode('utf8')
193+
b.decode('utf-8', 'strict')
194+
b.decode('utf-8', 'strict')
195+
b.decode('latin1', 'strict')
196+
b.decode('ascii')
197+
b.decode('latin-1')
198+
b.decode('utf-8', 'ignore')
199+
b.decode('ascii', 'replace')
200+
b.decode('latin1', 'ignore')
192201
[out]
193202
def f(b):
194203
b :: bytes
195-
r0, r1 :: str
204+
r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15 :: str
196205
L0:
197-
r0 = 'utf-8'
198-
r1 = CPy_Decode(b, r0, 0)
199-
return r1
206+
r0 = CPy_DecodeUtf8(b)
207+
r1 = CPy_DecodeUtf8(b)
208+
r2 = CPy_DecodeUtf8(b)
209+
r3 = CPy_DecodeUtf8(b)
210+
r4 = CPy_DecodeLatin1(b)
211+
r5 = CPy_DecodeAscii(b)
212+
r6 = CPy_DecodeLatin1(b)
213+
r7 = 'utf-8'
214+
r8 = 'ignore'
215+
r9 = CPy_Decode(b, r7, r8)
216+
r10 = 'ascii'
217+
r11 = 'replace'
218+
r12 = CPy_Decode(b, r10, r11)
219+
r13 = 'latin1'
220+
r14 = 'ignore'
221+
r15 = CPy_Decode(b, r13, r14)
222+
return 1
223+

mypyc/test-data/run-bytes.test

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -323,29 +323,3 @@ class A:
323323
def test_bytes_dunder() -> None:
324324
assert b'%b' % A() == b'aaa'
325325
assert b'%s' % A() == b'aaa'
326-
327-
[case testDecodeUtf8]
328-
from typing import Any
329-
from testutil import assertRaises
330-
from a import bytes_subclass
331-
332-
def test_decode_utf8() -> None:
333-
assert b'hello'.decode('utf-8') == 'hello'
334-
assert b''.decode('utf-8') == ''
335-
336-
x: bytes = bytearray(b'hello')
337-
assert x.decode('utf-8') == 'hello'
338-
assert type(x.decode('utf-8')) == str
339-
340-
y: Any = bytes_subclass()
341-
assert y.decode('utf-8') == 'spook'
342-
343-
n: Any = 123
344-
with assertRaises(AttributeError):
345-
n.decode('utf-8')
346-
347-
348-
[file a.py]
349-
class bytes_subclass(bytes):
350-
def decode(self, encoding='utf-8'):
351-
return 'spook'

0 commit comments

Comments
 (0)