|
49 | 49 | RTuple, |
50 | 50 | RType, |
51 | 51 | bool_rprimitive, |
| 52 | + bytes_rprimitive, |
52 | 53 | c_int_rprimitive, |
53 | 54 | dict_rprimitive, |
54 | 55 | int16_rprimitive, |
|
89 | 90 | dict_setdefault_spec_init_op, |
90 | 91 | dict_values_op, |
91 | 92 | ) |
| 93 | +from mypyc.primitives.bytes_ops import bytes_decode_utf8_strict, bytes_decode_latin1_strict, bytes_decode_ascii_strict |
92 | 94 | from mypyc.primitives.list_ops import new_list_set_item_op |
93 | 95 | from mypyc.primitives.str_ops import ( |
94 | 96 | str_encode_ascii_strict, |
@@ -740,6 +742,52 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> |
740 | 742 | return None |
741 | 743 |
|
742 | 744 |
|
| 745 | +@specialize_function("decode", bytes_rprimitive) |
| 746 | +def bytes_decode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: |
| 747 | + if not isinstance(callee, MemberExpr): |
| 748 | + return None |
| 749 | + |
| 750 | + encoding = "utf8" |
| 751 | + errors = "strict" |
| 752 | + |
| 753 | + # Handle up to 2 arguments: decode([encoding], [errors]) |
| 754 | + if len(expr.arg_kinds) > 0 and isinstance(expr.args[0], StrExpr): |
| 755 | + if expr.arg_kinds[0] == ARG_NAMED: |
| 756 | + if expr.arg_names[0] == "encoding": |
| 757 | + encoding = expr.args[0].value |
| 758 | + elif expr.arg_names[0] == "errors": |
| 759 | + errors = expr.args[0].value |
| 760 | + elif expr.arg_kinds[0] == ARG_POS: |
| 761 | + encoding = expr.args[0].value |
| 762 | + else: |
| 763 | + return None |
| 764 | + |
| 765 | + if len(expr.arg_kinds) > 1 and isinstance(expr.args[1], StrExpr): |
| 766 | + if expr.arg_kinds[1] == ARG_NAMED: |
| 767 | + if expr.arg_names[1] == "encoding": |
| 768 | + encoding = expr.args[1].value |
| 769 | + elif expr.arg_names[1] == "errors": |
| 770 | + errors = expr.args[1].value |
| 771 | + elif expr.arg_kinds[1] == ARG_POS: |
| 772 | + errors = expr.args[1].value |
| 773 | + else: |
| 774 | + return None |
| 775 | + |
| 776 | + if errors != "strict": |
| 777 | + return None |
| 778 | + |
| 779 | + normalized = encoding.lower().replace("-", "").replace("_", "") |
| 780 | + |
| 781 | + if normalized in ("utf8", "utf", "u8", "cp65001"): |
| 782 | + return builder.primitive_op(bytes_decode_utf8_strict, [builder.accept(callee.expr)], expr.line) |
| 783 | + elif normalized in ("ascii", "usascii", "646"): |
| 784 | + return builder.primitive_op(bytes_decode_ascii_strict, [builder.accept(callee.expr)], expr.line) |
| 785 | + elif normalized in ("latin1", "latin", "iso88591", "cp819", "8859", "l1"): |
| 786 | + return builder.primitive_op(bytes_decode_latin1_strict, [builder.accept(callee.expr)], expr.line) |
| 787 | + |
| 788 | + return None |
| 789 | + |
| 790 | + |
743 | 791 | @specialize_function("mypy_extensions.i64") |
744 | 792 | def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: |
745 | 793 | if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS: |
|
0 commit comments