Skip to content

Commit 61cd522

Browse files
committed
Update arg logic and add more tests
1 parent e10dec5 commit 61cd522

File tree

2 files changed

+112
-35
lines changed

2 files changed

+112
-35
lines changed

mypyc/irbuild/specialize.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -694,33 +694,46 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
694694
if not isinstance(callee, MemberExpr):
695695
return None
696696

697-
# We can only specialize strict errors
698-
if (
699-
len(expr.arg_kinds) > 1
700-
and isinstance(expr.args[1], StrExpr)
701-
and expr.args[1].value != "strict"
702-
):
697+
# We can only specialize if we have string literals as args
698+
if len(expr.arg_kinds) > 0 and not isinstance(expr.args[0], StrExpr):
703699
return None
704-
705-
if (
706-
len(expr.args) > 0
707-
and expr.arg_kinds[0] == ARG_NAMED
708-
and expr.arg_names[0] == "errors"
709-
and isinstance(expr.args[0], StrExpr)
710-
and expr.args[0].value != "strict"
711-
):
700+
if len(expr.arg_kinds) > 1 and not isinstance(expr.args[1], StrExpr):
712701
return None
713702

714703
encoding = "utf8"
715-
if len(expr.args) > 0 and isinstance(expr.args[0], StrExpr):
716-
encoding = expr.args[0].value.lower().replace("-", "_")
704+
errors = "strict"
705+
if len(expr.arg_kinds) > 0 and isinstance(expr.args[0], StrExpr):
706+
if expr.arg_kinds[0] == ARG_NAMED:
707+
if expr.arg_names[0] == "encoding":
708+
encoding = expr.args[0].value
709+
elif expr.arg_names[0] == "errors":
710+
errors = expr.args[0].value
711+
elif expr.arg_kinds[0] == ARG_POS:
712+
encoding = expr.args[0].value
713+
else:
714+
return None
715+
if len(expr.arg_kinds) > 1 and isinstance(expr.args[1], StrExpr):
716+
if expr.arg_kinds[1] == ARG_NAMED:
717+
if expr.arg_names[1] == "encoding":
718+
encoding = expr.args[1].value
719+
elif expr.arg_names[1] == "errors":
720+
errors = expr.args[1].value
721+
elif expr.arg_kinds[1] == ARG_POS:
722+
errors = expr.args[1].value
723+
else:
724+
return None
725+
726+
if errors != "strict":
727+
# We can only specialize strict errors
728+
return None
717729

730+
encoding = encoding.lower().replace("-", "").replace("_", "") # normalize
718731
# Specialized encodings and their accepted aliases
719-
if encoding in ['u8', 'utf', 'utf8', 'utf_8', 'cp65001']:
732+
if encoding in ["u8", "utf", "utf8", "cp65001"]:
720733
return builder.call_c(str_encode_utf8_strict, [builder.accept(callee.expr)], expr.line)
721-
elif encoding in ["ascii", "646", "us_ascii"]:
734+
elif encoding in ["646", "ascii", "usascii"]:
722735
return builder.call_c(str_encode_ascii_strict, [builder.accept(callee.expr)], expr.line)
723-
elif encoding in ['iso_8859_1', 'iso8859_1', '8859', 'cp819', 'latin', 'latin1', 'latin_1', 'l1']:
736+
elif encoding in ["iso88591", "8859", "cp819", "latin", "latin1", "l1"]:
724737
return builder.call_c(str_encode_latin1_strict, [builder.accept(callee.expr)], expr.line)
725738

726739
return None

mypyc/test-data/irbuild-str.test

Lines changed: 80 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -293,30 +293,94 @@ L0:
293293
def f(s: str) -> None:
294294
s.encode()
295295
s.encode('utf-8')
296-
s.encode('utf-8', 'strict')
297-
s.encode('utf-8', errors='strict')
298-
s.encode('utf-8', 'backslashreplace')
296+
s.encode('utf8', 'strict')
297+
s.encode('latin1', errors='strict')
299298
s.encode(encoding='ascii')
299+
s.encode(errors='strict', encoding='latin-1')
300+
s.encode('utf-8', 'backslashreplace')
300301
s.encode('ascii', 'backslashreplace')
302+
encoding = 'utf8'
303+
s.encode(encoding)
304+
errors = 'strict'
305+
s.encode('utf8', errors)
306+
s.encode('utf8', errors=errors)
307+
s.encode(errors=errors)
308+
s.encode(encoding=encoding, errors=errors)
309+
s.encode('latin2')
310+
301311
[out]
302312
def f(s):
303313
s :: str
304-
r0, r1, r2, r3 :: bytes
305-
r4, r5 :: str
306-
r6, r7 :: bytes
307-
r8, r9 :: str
308-
r10 :: bytes
314+
r0, r1, r2, r3, r4, r5 :: bytes
315+
r6, r7 :: str
316+
r8 :: bytes
317+
r9, r10 :: str
318+
r11 :: bytes
319+
r12, encoding :: str
320+
r13 :: bytes
321+
r14, errors, r15 :: str
322+
r16 :: bytes
323+
r17, r18 :: str
324+
r19 :: object
325+
r20 :: str
326+
r21 :: tuple
327+
r22 :: dict
328+
r23 :: object
329+
r24 :: str
330+
r25 :: object
331+
r26 :: str
332+
r27 :: tuple
333+
r28 :: dict
334+
r29 :: object
335+
r30 :: str
336+
r31 :: object
337+
r32, r33 :: str
338+
r34 :: tuple
339+
r35 :: dict
340+
r36 :: object
341+
r37 :: str
342+
r38 :: bytes
309343
L0:
310344
r0 = PyUnicode_AsUTF8String(s)
311345
r1 = PyUnicode_AsUTF8String(s)
312346
r2 = PyUnicode_AsUTF8String(s)
313-
r3 = PyUnicode_AsUTF8String(s)
314-
r4 = 'utf-8'
315-
r5 = 'backslashreplace'
316-
r6 = CPy_Encode(s, r4, r5)
317-
r7 = PyUnicode_AsASCIIString(s)
318-
r8 = 'ascii'
319-
r9 = 'backslashreplace'
320-
r10 = CPy_Encode(s, r8, r9)
347+
r3 = PyUnicode_AsLatin1String(s)
348+
r4 = PyUnicode_AsASCIIString(s)
349+
r5 = PyUnicode_AsLatin1String(s)
350+
r6 = 'utf-8'
351+
r7 = 'backslashreplace'
352+
r8 = CPy_Encode(s, r6, r7)
353+
r9 = 'ascii'
354+
r10 = 'backslashreplace'
355+
r11 = CPy_Encode(s, r9, r10)
356+
r12 = 'utf8'
357+
encoding = r12
358+
r13 = CPy_Encode(s, encoding, 0)
359+
r14 = 'strict'
360+
errors = r14
361+
r15 = 'utf8'
362+
r16 = CPy_Encode(s, r15, errors)
363+
r17 = 'utf8'
364+
r18 = 'encode'
365+
r19 = CPyObject_GetAttr(s, r18)
366+
r20 = 'errors'
367+
r21 = PyTuple_Pack(1, r17)
368+
r22 = CPyDict_Build(1, r20, errors)
369+
r23 = PyObject_Call(r19, r21, r22)
370+
r24 = 'encode'
371+
r25 = CPyObject_GetAttr(s, r24)
372+
r26 = 'errors'
373+
r27 = PyTuple_Pack(0)
374+
r28 = CPyDict_Build(1, r26, errors)
375+
r29 = PyObject_Call(r25, r27, r28)
376+
r30 = 'encode'
377+
r31 = CPyObject_GetAttr(s, r30)
378+
r32 = 'encoding'
379+
r33 = 'errors'
380+
r34 = PyTuple_Pack(0)
381+
r35 = CPyDict_Build(2, r32, encoding, r33, errors)
382+
r36 = PyObject_Call(r31, r34, r35)
383+
r37 = 'latin2'
384+
r38 = CPy_Encode(s, r37, 0)
321385
return 1
322386

0 commit comments

Comments
 (0)