Skip to content

Commit a554d1f

Browse files
[mypyc] Add primitive for bytes decode() method
1 parent a3ce6d5 commit a554d1f

File tree

5 files changed

+60
-1
lines changed

5 files changed

+60
-1
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,6 @@ int CPyList_Insert(PyObject *list, CPyTagged index, PyObject *value);
662662
PyObject *CPyList_Extend(PyObject *o1, PyObject *o2);
663663
int CPyList_Remove(PyObject *list, PyObject *obj);
664664
CPyTagged CPyList_Index(PyObject *list, PyObject *obj);
665-
PyObject *CPySequence_Sort(PyObject *seq);
666665
PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size);
667666
PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq);
668667
PyObject *CPySequence_InPlaceMultiply(PyObject *seq, CPyTagged t_size);
@@ -764,6 +763,7 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index);
764763
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
765764
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
766765
CPyTagged CPyBytes_Ord(PyObject *obj);
766+
PyObject *CPy_DecodeUtf8(PyObject *bytes_obj, const char *errors);
767767

768768

769769
int CPyBytes_Compare(PyObject *left, PyObject *right);

mypyc/lib-rt/bytes_ops.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,16 @@ CPyTagged CPyBytes_Ord(PyObject *obj) {
162162
PyErr_SetString(PyExc_TypeError, "ord() expects a character");
163163
return CPY_INT_TAG;
164164
}
165+
166+
167+
PyObject *CPy_DecodeUtf8(PyObject *bytes_obj, const char *errors) {
168+
if (!PyBytes_Check(bytes_obj)) {
169+
PyErr_SetString(PyExc_TypeError, "expected bytes object");
170+
return NULL;
171+
}
172+
173+
char *data = PyBytes_AS_STRING(bytes_obj);
174+
Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
175+
176+
return PyUnicode_DecodeUTF8(data, size, errors);
177+
}

mypyc/primitives/bytes_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,11 @@
107107
c_function_name="CPyBytes_Ord",
108108
error_kind=ERR_MAGIC,
109109
)
110+
111+
method_op(
112+
name="decode",
113+
arg_types=[bytes_rprimitive, bytes_rprimitive],
114+
return_type=str_rprimitive,
115+
c_function_name="CPy_DecodeUtf8",
116+
error_kind=ERR_MAGIC,
117+
)

mypyc/test-data/irbuild-bytes.test

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,15 @@ L0:
185185
r10 = CPyBytes_Build(2, var, r9)
186186
b4 = r10
187187
return 1
188+
189+
[case testDecodeUtf8]
190+
def f(b: bytes) -> str:
191+
return b.decode("utf-8")
192+
[out]
193+
def f(b):
194+
b :: bytes
195+
r0, r1 :: str
196+
L0:
197+
r0 = 'utf-8'
198+
r1 = CPy_Decode(b, r0, 0)
199+
return r1

mypyc/test-data/run-bytes.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,29 @@ class A:
323323
def test_bytes_dunder() -> None:
324324
assert b'%b' % A() == b'aaa'
325325
assert b'%s' % A() == b'aaa'
326+
327+
[case testDecodeUtf8]
328+
from typing import Any
329+
from testutil import assertRaises
330+
from a import bytes_subclass
331+
332+
def test_decode_utf8() -> None:
333+
assert b'hello'.decode('utf-8') == 'hello'
334+
assert b''.decode('utf-8') == ''
335+
336+
x: bytes = bytearray(b'hello')
337+
assert x.decode('utf-8') == 'hello'
338+
assert type(x.decode('utf-8')) == str
339+
340+
y: Any = bytes_subclass()
341+
assert y.decode('utf-8') == 'spook'
342+
343+
n: Any = 123
344+
with assertRaises(AttributeError):
345+
n.decode('utf-8')
346+
347+
348+
[file a.py]
349+
class bytes_subclass(bytes):
350+
def decode(self, encoding='utf-8'):
351+
return 'spook'

0 commit comments

Comments
 (0)