Skip to content

Commit b38dcc6

Browse files
Merge branch 'master' into solve_intermediate_variable_bug
2 parents 507c439 + 35d8c69 commit b38dcc6

File tree

16 files changed

+198
-141
lines changed

16 files changed

+198
-141
lines changed

mypyc/codegen/emit.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
RType,
2929
RUnion,
3030
int_rprimitive,
31-
is_bit_rprimitive,
32-
is_bool_rprimitive,
31+
is_bool_or_bit_rprimitive,
3332
is_bytes_rprimitive,
3433
is_dict_rprimitive,
3534
is_fixed_width_rtype,
@@ -615,8 +614,7 @@ def emit_cast(
615614
or is_range_rprimitive(typ)
616615
or is_float_rprimitive(typ)
617616
or is_int_rprimitive(typ)
618-
or is_bool_rprimitive(typ)
619-
or is_bit_rprimitive(typ)
617+
or is_bool_or_bit_rprimitive(typ)
620618
or is_fixed_width_rtype(typ)
621619
):
622620
if declare_dest:
@@ -638,7 +636,7 @@ def emit_cast(
638636
elif is_int_rprimitive(typ) or is_fixed_width_rtype(typ):
639637
# TODO: Range check for fixed-width types?
640638
prefix = "PyLong"
641-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
639+
elif is_bool_or_bit_rprimitive(typ):
642640
prefix = "PyBool"
643641
else:
644642
assert False, f"unexpected primitive type: {typ}"
@@ -889,7 +887,7 @@ def emit_unbox(
889887
self.emit_line("else {")
890888
self.emit_line(failure)
891889
self.emit_line("}")
892-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
890+
elif is_bool_or_bit_rprimitive(typ):
893891
# Whether we are borrowing or not makes no difference.
894892
if declare_dest:
895893
self.emit_line(f"char {dest};")
@@ -1015,7 +1013,7 @@ def emit_box(
10151013
if is_int_rprimitive(typ) or is_short_int_rprimitive(typ):
10161014
# Steal the existing reference if it exists.
10171015
self.emit_line(f"{declaration}{dest} = CPyTagged_StealAsObject({src});")
1018-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
1016+
elif is_bool_or_bit_rprimitive(typ):
10191017
# N.B: bool is special cased to produce a borrowed value
10201018
# after boxing, so we don't need to increment the refcount
10211019
# when this comes directly from a Box op.

mypyc/ir/ops.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ class to enable the new behavior. Sometimes adding a new abstract
4242
cstring_rprimitive,
4343
float_rprimitive,
4444
int_rprimitive,
45-
is_bit_rprimitive,
46-
is_bool_rprimitive,
45+
is_bool_or_bit_rprimitive,
4746
is_int_rprimitive,
4847
is_none_rprimitive,
4948
is_pointer_rprimitive,
@@ -1089,11 +1088,7 @@ def __init__(self, src: Value, line: int = -1) -> None:
10891088
self.src = src
10901089
self.type = object_rprimitive
10911090
# When we box None and bool values, we produce a borrowed result
1092-
if (
1093-
is_none_rprimitive(self.src.type)
1094-
or is_bool_rprimitive(self.src.type)
1095-
or is_bit_rprimitive(self.src.type)
1096-
):
1091+
if is_none_rprimitive(self.src.type) or is_bool_or_bit_rprimitive(self.src.type):
10971092
self.is_borrowed = True
10981093

10991094
def sources(self) -> list[Value]:

mypyc/ir/rtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,10 @@ def is_bit_rprimitive(rtype: RType) -> bool:
582582
return isinstance(rtype, RPrimitive) and rtype.name == "bit"
583583

584584

585+
def is_bool_or_bit_rprimitive(rtype: RType) -> bool:
586+
return is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype)
587+
588+
585589
def is_object_rprimitive(rtype: RType) -> bool:
586590
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.object"
587591

mypyc/irbuild/classdef.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
handle_non_ext_method,
6565
load_type,
6666
)
67+
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
6768
from mypyc.irbuild.util import dataclass_type, get_func_def, is_constant, is_dataclass_decorator
6869
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op
6970
from mypyc.primitives.generic_ops import (
@@ -135,6 +136,14 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
135136
cls_builder = NonExtClassBuilder(builder, cdef)
136137

137138
for stmt in cdef.defs.body:
139+
if (
140+
isinstance(stmt, (FuncDef, Decorator, OverloadedFuncDef))
141+
and stmt.name == GENERATOR_HELPER_NAME
142+
):
143+
builder.error(
144+
f'Method name "{stmt.name}" is reserved for mypyc internal use', stmt.line
145+
)
146+
138147
if isinstance(stmt, OverloadedFuncDef) and stmt.is_property:
139148
if isinstance(cls_builder, NonExtClassBuilder):
140149
# properties with both getters and setters in non_extension

mypyc/irbuild/generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
setup_func_for_recursive_call,
5151
)
5252
from mypyc.irbuild.nonlocalcontrol import ExceptNonlocalControl
53+
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
5354
from mypyc.primitives.exc_ops import (
5455
error_catch_op,
5556
exc_matches_op,
@@ -236,11 +237,11 @@ def add_helper_to_generator_class(
236237
builder: IRBuilder, arg_regs: list[Register], blocks: list[BasicBlock], fn_info: FuncInfo
237238
) -> FuncDecl:
238239
"""Generates a helper method for a generator class, called by '__next__' and 'throw'."""
239-
helper_fn_decl = fn_info.generator_class.ir.method_decls["__mypyc_generator_helper__"]
240+
helper_fn_decl = fn_info.generator_class.ir.method_decls[GENERATOR_HELPER_NAME]
240241
helper_fn_ir = FuncIR(
241242
helper_fn_decl, arg_regs, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name
242243
)
243-
fn_info.generator_class.ir.methods["__mypyc_generator_helper__"] = helper_fn_ir
244+
fn_info.generator_class.ir.methods[GENERATOR_HELPER_NAME] = helper_fn_ir
244245
builder.functions.append(helper_fn_ir)
245246
fn_info.env_class.env_user_function = helper_fn_ir
246247

mypyc/irbuild/ll_builder.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@
9393
dict_rprimitive,
9494
float_rprimitive,
9595
int_rprimitive,
96-
is_bit_rprimitive,
97-
is_bool_rprimitive,
96+
is_bool_or_bit_rprimitive,
9897
is_bytes_rprimitive,
9998
is_dict_rprimitive,
10099
is_fixed_width_rtype,
@@ -175,7 +174,12 @@
175174
unary_ops,
176175
)
177176
from mypyc.primitives.set_ops import new_set_op
178-
from mypyc.primitives.str_ops import str_check_if_true, str_ssize_t_size_op, unicode_compare
177+
from mypyc.primitives.str_ops import (
178+
str_check_if_true,
179+
str_eq,
180+
str_ssize_t_size_op,
181+
unicode_compare,
182+
)
179183
from mypyc.primitives.tuple_ops import list_tuple_op, new_tuple_op, new_tuple_with_length_op
180184
from mypyc.rt_subtype import is_runtime_subtype
181185
from mypyc.sametype import is_same_type
@@ -376,16 +380,12 @@ def coerce(
376380
):
377381
# Equivalent types
378382
return src
379-
elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged(
380-
target_type
381-
):
383+
elif is_bool_or_bit_rprimitive(src_type) and is_tagged(target_type):
382384
shifted = self.int_op(
383385
bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT
384386
)
385387
return self.add(Extend(shifted, target_type, signed=False))
386-
elif (
387-
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
388-
) and is_fixed_width_rtype(target_type):
388+
elif is_bool_or_bit_rprimitive(src_type) and is_fixed_width_rtype(target_type):
389389
return self.add(Extend(src, target_type, signed=False))
390390
elif isinstance(src, Integer) and is_float_rprimitive(target_type):
391391
if is_tagged(src_type):
@@ -1336,7 +1336,11 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13361336
return self.compare_strings(lreg, rreg, op, line)
13371337
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
13381338
return self.compare_bytes(lreg, rreg, op, line)
1339-
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
1339+
if (
1340+
is_bool_or_bit_rprimitive(ltype)
1341+
and is_bool_or_bit_rprimitive(rtype)
1342+
and op in BOOL_BINARY_OPS
1343+
):
13401344
if op in ComparisonOp.signed_ops:
13411345
return self.bool_comparison_op(lreg, rreg, op, line)
13421346
else:
@@ -1350,7 +1354,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13501354
op_id = int_op_to_id[op]
13511355
else:
13521356
op_id = IntOp.DIV
1353-
if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1357+
if is_bool_or_bit_rprimitive(rtype):
13541358
rreg = self.coerce(rreg, ltype, line)
13551359
rtype = ltype
13561360
if is_fixed_width_rtype(rtype) or is_tagged(rtype):
@@ -1362,7 +1366,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13621366
elif op in ComparisonOp.signed_ops:
13631367
if is_int_rprimitive(rtype):
13641368
rreg = self.coerce_int_to_fixed_width(rreg, ltype, line)
1365-
elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1369+
elif is_bool_or_bit_rprimitive(rtype):
13661370
rreg = self.coerce(rreg, ltype, line)
13671371
op_id = ComparisonOp.signed_ops[op]
13681372
if is_fixed_width_rtype(rreg.type):
@@ -1382,13 +1386,13 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13821386
)
13831387
if is_tagged(ltype):
13841388
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
1385-
if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1389+
if is_bool_or_bit_rprimitive(ltype):
13861390
lreg = self.coerce(lreg, rtype, line)
13871391
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
13881392
elif op in ComparisonOp.signed_ops:
13891393
if is_int_rprimitive(ltype):
13901394
lreg = self.coerce_int_to_fixed_width(lreg, rtype, line)
1391-
elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1395+
elif is_bool_or_bit_rprimitive(ltype):
13921396
lreg = self.coerce(lreg, rtype, line)
13931397
op_id = ComparisonOp.signed_ops[op]
13941398
if isinstance(lreg, Integer):
@@ -1471,6 +1475,11 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -
14711475

14721476
def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
14731477
"""Compare two strings"""
1478+
if op == "==":
1479+
return self.primitive_op(str_eq, [lhs, rhs], line)
1480+
elif op == "!=":
1481+
eq = self.primitive_op(str_eq, [lhs, rhs], line)
1482+
return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line))
14741483
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
14751484
error_constant = Integer(-1, c_int_rprimitive, line)
14761485
compare_error_check = self.add(
@@ -1534,7 +1543,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15341543
compare = self.binary_op(lhs_item, rhs_item, op, line)
15351544
# Cast to bool if necessary since most types uses comparison returning a object type
15361545
# See generic_ops.py for more information
1537-
if not (is_bool_rprimitive(compare.type) or is_bit_rprimitive(compare.type)):
1546+
if not is_bool_or_bit_rprimitive(compare.type):
15381547
compare = self.primitive_op(bool_op, [compare], line)
15391548
if i < len(lhs.type.types) - 1:
15401549
branch = Branch(compare, early_stop, check_blocks[i + 1], Branch.BOOL)
@@ -1553,7 +1562,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15531562

15541563
def translate_instance_contains(self, inst: Value, item: Value, op: str, line: int) -> Value:
15551564
res = self.gen_method_call(inst, "__contains__", [item], None, line)
1556-
if not is_bool_rprimitive(res.type):
1565+
if not is_bool_or_bit_rprimitive(res.type):
15571566
res = self.primitive_op(bool_op, [res], line)
15581567
if op == "not in":
15591568
res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), "^", line)
@@ -1580,7 +1589,7 @@ def unary_not(self, value: Value, line: int) -> Value:
15801589

15811590
def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15821591
typ = value.type
1583-
if is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
1592+
if is_bool_or_bit_rprimitive(typ):
15841593
if expr_op == "not":
15851594
return self.unary_not(value, line)
15861595
if expr_op == "+":
@@ -1738,7 +1747,7 @@ def bool_value(self, value: Value) -> Value:
17381747
17391748
The result type can be bit_rprimitive or bool_rprimitive.
17401749
"""
1741-
if is_bool_rprimitive(value.type) or is_bit_rprimitive(value.type):
1750+
if is_bool_or_bit_rprimitive(value.type):
17421751
result = value
17431752
elif is_runtime_subtype(value.type, int_rprimitive):
17441753
zero = Integer(0, short_int_rprimitive)

mypyc/irbuild/prepare.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@
7171
from mypyc.options import CompilerOptions
7272
from mypyc.sametype import is_same_type
7373

74+
GENERATOR_HELPER_NAME = "__mypyc_generator_helper__"
75+
7476

7577
def build_type_map(
7678
mapper: Mapper,
@@ -229,7 +231,7 @@ def create_generator_class_if_needed(
229231

230232
# The implementation of most generator functionality is behind this magic method.
231233
helper_fn_decl = FuncDecl(
232-
"__mypyc_generator_helper__", name, module_name, helper_sig, internal=True
234+
GENERATOR_HELPER_NAME, name, module_name, helper_sig, internal=True
233235
)
234236
cir.method_decls[helper_fn_decl.name] = helper_fn_decl
235237

mypyc/irbuild/statement.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
FinallyNonlocalControl,
9191
TryFinallyNonlocalControl,
9292
)
93+
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
9394
from mypyc.irbuild.targets import (
9495
AssignmentTarget,
9596
AssignmentTargetAttr,
@@ -933,7 +934,7 @@ def emit_yield_from_or_await(
933934
to_yield_reg = Register(object_rprimitive)
934935
received_reg = Register(object_rprimitive)
935936

936-
helper_method = "__mypyc_generator_helper__"
937+
helper_method = GENERATOR_HELPER_NAME
937938
if (
938939
isinstance(val, (Call, MethodCall))
939940
and isinstance(val.type, RInstance)

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
726726
#define RIGHTSTRIP 1
727727
#define BOTHSTRIP 2
728728

729+
char CPyStr_Equal(PyObject *str1, PyObject *str2);
729730
PyObject *CPyStr_Build(Py_ssize_t len, ...);
730731
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
731732
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);

mypyc/lib-rt/str_ops.c

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,22 @@ make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
6464
#undef BLOOM_UPDATE
6565
}
6666

67+
// Adapted from CPython 3.13.1 (_PyUnicode_Equal)
68+
char CPyStr_Equal(PyObject *str1, PyObject *str2) {
69+
if (str1 == str2) {
70+
return 1;
71+
}
72+
Py_ssize_t len = PyUnicode_GET_LENGTH(str1);
73+
if (PyUnicode_GET_LENGTH(str2) != len)
74+
return 0;
75+
int kind = PyUnicode_KIND(str1);
76+
if (PyUnicode_KIND(str2) != kind)
77+
return 0;
78+
const void *data1 = PyUnicode_DATA(str1);
79+
const void *data2 = PyUnicode_DATA(str2);
80+
return memcmp(data1, data2, len * kind) == 0;
81+
}
82+
6783
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
6884
if (PyUnicode_READY(str) != -1) {
6985
if (CPyTagged_CheckShort(index)) {

0 commit comments

Comments
 (0)