Skip to content

Commit e10dec5

Browse files
committed
Optimize str.encode with specializations for common used encodings
1 parent 499adae commit e10dec5

File tree

4 files changed

+90
-12
lines changed

4 files changed

+90
-12
lines changed

mypyc/irbuild/specialize.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@
9090
)
9191
from mypyc.primitives.list_ops import new_list_set_item_op
9292
from mypyc.primitives.tuple_ops import new_tuple_set_item_op
93+
from mypyc.primitives.str_ops import (
94+
str_encode_utf8_strict,
95+
str_encode_ascii_strict,
96+
str_encode_latin1_strict,
97+
)
9398

9499
# Specializers are attempted before compiling the arguments to the
95100
# function. Specializers can return None to indicate that they failed
@@ -682,6 +687,45 @@ def translate_fstring(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Va
682687
return None
683688

684689

690+
@specialize_function("encode", str_rprimitive)
691+
def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
692+
"""Specialize common cases of str.encode for most used encodings and strict errors."""
693+
694+
if not isinstance(callee, MemberExpr):
695+
return None
696+
697+
# We can only specialize strict errors
698+
if (
699+
len(expr.arg_kinds) > 1
700+
and isinstance(expr.args[1], StrExpr)
701+
and expr.args[1].value != "strict"
702+
):
703+
return None
704+
705+
if (
706+
len(expr.args) > 0
707+
and expr.arg_kinds[0] == ARG_NAMED
708+
and expr.arg_names[0] == "errors"
709+
and isinstance(expr.args[0], StrExpr)
710+
and expr.args[0].value != "strict"
711+
):
712+
return None
713+
714+
encoding = "utf8"
715+
if len(expr.args) > 0 and isinstance(expr.args[0], StrExpr):
716+
encoding = expr.args[0].value.lower().replace("-", "_")
717+
718+
# Specialized encodings and their accepted aliases
719+
if encoding in ['u8', 'utf', 'utf8', 'utf_8', 'cp65001']:
720+
return builder.call_c(str_encode_utf8_strict, [builder.accept(callee.expr)], expr.line)
721+
elif encoding in ["ascii", "646", "us_ascii"]:
722+
return builder.call_c(str_encode_ascii_strict, [builder.accept(callee.expr)], expr.line)
723+
elif encoding in ['iso_8859_1', 'iso8859_1', '8859', 'cp819', 'latin', 'latin1', 'latin_1', 'l1']:
724+
return builder.call_c(str_encode_latin1_strict, [builder.accept(callee.expr)], expr.line)
725+
726+
return None
727+
728+
685729
@specialize_function("mypy_extensions.i64")
686730
def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
687731
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:

mypyc/primitives/str_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,30 @@
219219
extra_int_constants=[(0, pointer_rprimitive)],
220220
)
221221

222+
# str.encode(encoding) - utf8 strict specialization
223+
str_encode_utf8_strict = custom_op(
224+
arg_types=[str_rprimitive],
225+
return_type=bytes_rprimitive,
226+
c_function_name="PyUnicode_AsUTF8String",
227+
error_kind=ERR_MAGIC,
228+
)
229+
230+
# str.encode(encoding) - ascii strict specialization
231+
str_encode_ascii_strict = custom_op(
232+
arg_types=[str_rprimitive],
233+
return_type=bytes_rprimitive,
234+
c_function_name="PyUnicode_AsASCIIString",
235+
error_kind=ERR_MAGIC,
236+
)
237+
238+
# str.encode(encoding) - latin1 strict specialization
239+
str_encode_latin1_strict = custom_op(
240+
arg_types=[str_rprimitive],
241+
return_type=bytes_rprimitive,
242+
c_function_name="PyUnicode_AsLatin1String",
243+
error_kind=ERR_MAGIC,
244+
)
245+
222246
# str.encode(encoding, errors)
223247
method_op(
224248
name="encode",

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def upper(self) -> str: ...
110110
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
111111
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
112112
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
113-
def encode(self, x: str=..., y: str=...) -> bytes: ...
113+
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...
114114

115115
class float:
116116
def __init__(self, x: object) -> None: pass

mypyc/test-data/irbuild-str.test

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,20 +293,30 @@ L0:
293293
def f(s: str) -> None:
294294
s.encode()
295295
s.encode('utf-8')
296+
s.encode('utf-8', 'strict')
297+
s.encode('utf-8', errors='strict')
298+
s.encode('utf-8', 'backslashreplace')
299+
s.encode(encoding='ascii')
296300
s.encode('ascii', 'backslashreplace')
297301
[out]
298302
def f(s):
299303
s :: str
300-
r0 :: bytes
301-
r1 :: str
302-
r2 :: bytes
303-
r3, r4 :: str
304-
r5 :: bytes
304+
r0, r1, r2, r3 :: bytes
305+
r4, r5 :: str
306+
r6, r7 :: bytes
307+
r8, r9 :: str
308+
r10 :: bytes
305309
L0:
306-
r0 = CPy_Encode(s, 0, 0)
307-
r1 = 'utf-8'
308-
r2 = CPy_Encode(s, r1, 0)
309-
r3 = 'ascii'
310-
r4 = 'backslashreplace'
311-
r5 = CPy_Encode(s, r3, r4)
310+
r0 = PyUnicode_AsUTF8String(s)
311+
r1 = PyUnicode_AsUTF8String(s)
312+
r2 = PyUnicode_AsUTF8String(s)
313+
r3 = PyUnicode_AsUTF8String(s)
314+
r4 = 'utf-8'
315+
r5 = 'backslashreplace'
316+
r6 = CPy_Encode(s, r4, r5)
317+
r7 = PyUnicode_AsASCIIString(s)
318+
r8 = 'ascii'
319+
r9 = 'backslashreplace'
320+
r10 = CPy_Encode(s, r8, r9)
312321
return 1
322+

0 commit comments

Comments
 (0)