Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,9 @@ def visit_box(self, op: Box) -> None:
self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)

def visit_cast(self, op: Cast) -> None:
if op.is_unchecked and op.is_borrowed:
self.emit_line(f"{self.reg(op)} = {self.reg(op.src)};")
return
branch = self.next_branch()
handler = None
if branch is not None:
Expand Down
10 changes: 9 additions & 1 deletion mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,11 +1073,19 @@ class Cast(RegisterOp):

error_kind = ERR_MAGIC

def __init__(self, src: Value, typ: RType, line: int, *, borrow: bool = False) -> None:
def __init__(
self, src: Value, typ: RType, line: int, *, borrow: bool = False, unchecked: bool = False
) -> None:
super().__init__(line)
self.src = src
self.type = typ
# If true, don't incref the result.
self.is_borrowed = borrow
# If true, don't perform a runtime type check (only changes the static type of
# the operand). Used when we know that the cast will always succeed.
self.is_unchecked = unchecked
if unchecked:
self.error_kind = ERR_NEVER

def sources(self) -> list[Value]:
return [self.src]
Expand Down
8 changes: 7 additions & 1 deletion mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,13 @@ def visit_method_call(self, op: MethodCall) -> str:
return s

def visit_cast(self, op: Cast) -> str:
return self.format("%r = %scast(%s, %r)", op, self.borrow_prefix(op), op.type, op.src)
if op.is_unchecked:
prefix = "unchecked "
else:
prefix = ""
return self.format(
"%r = %s%scast(%s, %r)", op, prefix, self.borrow_prefix(op), op.type, op.src
)

def visit_box(self, op: Box) -> str:
return self.format("%r = box(%s, %r)", op, op.src.type, op.src)
Expand Down
2 changes: 1 addition & 1 deletion mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ def flatten_nested_unions(types: list[RType]) -> list[RType]:
def optional_value_type(rtype: RType) -> RType | None:
"""If rtype is the union of none_rprimitive and another type X, return X.

Otherwise return None.
Otherwise, return None.
"""
if isinstance(rtype, RUnion) and len(rtype.items) == 2:
if rtype.items[0] == none_rprimitive:
Expand Down
96 changes: 94 additions & 2 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,20 @@ def box(self, src: Value) -> Value:
return src

def unbox_or_cast(
self, src: Value, target_type: RType, line: int, *, can_borrow: bool = False
self,
src: Value,
target_type: RType,
line: int,
*,
can_borrow: bool = False,
unchecked: bool = False,
) -> Value:
if target_type.is_unboxed:
return self.add(Unbox(src, target_type, line))
else:
if can_borrow:
self.keep_alives.append(src)
return self.add(Cast(src, target_type, line, borrow=can_borrow))
return self.add(Cast(src, target_type, line, borrow=can_borrow, unchecked=unchecked))

def coerce(
self,
Expand Down Expand Up @@ -2514,6 +2520,22 @@ def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str, line: int) ->
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype):
return self.compare_bytes(lreg, rreg, expr_op, line)

lopt = optional_value_type(ltype)
ropt = optional_value_type(rtype)

# Can we do a quick comparison of two optional types (special case None values)?
fast_opt_eq = False
if lopt is not None:
if ropt is not None and is_same_type(lopt, ropt) and self._never_equal_to_none(lopt):
fast_opt_eq = True
if is_same_type(lopt, rtype) and self._never_equal_to_none(lopt):
fast_opt_eq = True
elif ropt is not None:
if is_same_type(ropt, ltype) and self._never_equal_to_none(ropt):
fast_opt_eq = True
if fast_opt_eq:
return self._translate_fast_optional_eq_cmp(lreg, rreg, expr_op, line)

if not (isinstance(ltype, RInstance) and ltype == rtype):
return None

Expand All @@ -2540,6 +2562,76 @@ def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str, line: int) ->

return self.gen_method_call(lreg, op_methods[expr_op], [rreg], ltype, line)

def _never_equal_to_none(self, typ: RType) -> bool:
"""Are the values of type never equal to None?"""
# TODO: Support RInstance with no custom __eq__/__ne__ and other primitive types.
return is_str_rprimitive(typ) or is_bytes_rprimitive(typ)

def _translate_fast_optional_eq_cmp(
self, lreg: Value, rreg: Value, expr_op: str, line: int
) -> Value:
"""Generate eq/ne fast path between 'X | None' and ('X | None' or X).

Assume 'X' never compares equal to None.
"""
if not isinstance(lreg.type, RUnion):
lreg, rreg = rreg, lreg
value_typ = optional_value_type(lreg.type)
assert value_typ
res = Register(bool_rprimitive)

# Fast path: left value is None?
cmp = self.add(ComparisonOp(lreg, self.none_object(), ComparisonOp.EQ, line))
l_none = BasicBlock()
l_not_none = BasicBlock()
out = BasicBlock()
self.add(Branch(cmp, l_none, l_not_none, Branch.BOOL))
self.activate_block(l_none)
if not isinstance(rreg.type, RUnion):
val = self.false() if expr_op == "==" else self.true()
self.add(Assign(res, val))
else:
op = ComparisonOp.EQ if expr_op == "==" else ComparisonOp.NEQ
cmp = self.add(ComparisonOp(rreg, self.none_object(), op, line))
self.add(Assign(res, cmp))
self.goto(out)

self.activate_block(l_not_none)
if not isinstance(rreg.type, RUnion):
# Both operands are known to be not None, perform specialized comparison
eq = self.translate_eq_cmp(
self.unbox_or_cast(lreg, value_typ, line, can_borrow=True, unchecked=True),
rreg,
expr_op,
line,
)
assert eq is not None
self.add(Assign(res, eq))
else:
r_none = BasicBlock()
r_not_none = BasicBlock()
# Fast path: eight value is None?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Fast path: eight value is None?
# Fast path: right value is None?

cmp = self.add(ComparisonOp(rreg, self.none_object(), ComparisonOp.EQ, line))
self.add(Branch(cmp, r_none, r_not_none, Branch.BOOL))
self.activate_block(r_none)
# None vs not-None
val = self.false() if expr_op == "==" else self.true()
self.add(Assign(res, val))
self.goto(out)
self.activate_block(r_not_none)
# Both operands are known to be not None, perform specialized comparison
eq = self.translate_eq_cmp(
self.unbox_or_cast(lreg, value_typ, line, can_borrow=True, unchecked=True),
self.unbox_or_cast(rreg, value_typ, line, can_borrow=True, unchecked=True),
expr_op,
line,
)
assert eq is not None
self.add(Assign(res, eq))
self.goto(out)
self.activate_block(out)
return res

def translate_is_op(self, lreg: Value, rreg: Value, expr_op: str, line: int) -> Value:
"""Create equality comparison operation between object identities

Expand Down
32 changes: 32 additions & 0 deletions mypyc/test-data/irbuild-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,35 @@ L0:
r10 = CPyBytes_Build(2, var, r9)
b4 = r10
return 1

[case testOptionalBytesEquality]
from typing import Optional

def non_opt_opt(x: bytes, y: Optional[bytes]) -> bool:
return x != y
[out]
def non_opt_opt(x, y):
x :: bytes
y :: union[bytes, None]
r0 :: object
r1 :: bit
r2 :: bool
r3 :: bytes
r4 :: i32
r5, r6 :: bit
L0:
r0 = load_address _Py_NoneStruct
r1 = y == r0
if r1 goto L1 else goto L2 :: bool
L1:
r2 = 1
goto L3
L2:
r3 = unchecked borrow cast(bytes, y)
r4 = CPyBytes_Compare(r3, x)
r5 = r4 >= 0 :: signed
r6 = r4 != 1
r2 = r6
L3:
keep_alive y
return r2
71 changes: 71 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,74 @@ L0:
r2 = 'abc1233.14True'
r3 = CPyStr_Build(7, x, r0, x, r1, x, r2, x)
return r3

[case testOptionalStrEquality1]
from typing import Optional

def opt_opt(x: Optional[str], y: Optional[str]) -> bool:
return x == y
[out]
def opt_opt(x, y):
x, y :: union[str, None]
r0 :: object
r1 :: bit
r2 :: object
r3 :: bit
r4 :: bool
r5 :: object
r6 :: bit
r7, r8 :: str
r9 :: bool
L0:
r0 = load_address _Py_NoneStruct
r1 = x == r0
if r1 goto L1 else goto L2 :: bool
L1:
r2 = load_address _Py_NoneStruct
r3 = y == r2
r4 = r3
goto L5
L2:
r5 = load_address _Py_NoneStruct
r6 = y == r5
if r6 goto L3 else goto L4 :: bool
L3:
r4 = 0
goto L5
L4:
r7 = unchecked borrow cast(str, x)
r8 = unchecked borrow cast(str, y)
r9 = CPyStr_Equal(r7, r8)
r4 = r9
L5:
keep_alive x, y
return r4

[case testOptionalStrEquality2]
from typing import Optional

def opt_non_opt(x: Optional[str], y: str) -> bool:
return x == y
[out]
def opt_non_opt(x, y):
x :: union[str, None]
y :: str
r0 :: object
r1 :: bit
r2 :: bool
r3 :: str
r4 :: bool
L0:
r0 = load_address _Py_NoneStruct
r1 = x == r0
if r1 goto L1 else goto L2 :: bool
L1:
r2 = 0
goto L3
L2:
r3 = unchecked borrow cast(str, x)
r4 = CPyStr_Equal(r3, y)
r2 = r4
L3:
keep_alive x
return r2
27 changes: 27 additions & 0 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,30 @@ class subbytearray(bytearray):
[file userdefinedbytes.py]
class bytes:
pass

[case testBytesOptionalEquality]
from __future__ import annotations

def eq_b_opt_b(x: bytes | None, y: bytes) -> bool:
return x == y

def ne_b_b_opt(x: bytes, y: bytes | None) -> bool:
return x != y

def test_optional_eq() -> None:
b = b'x'
assert eq_b_opt_b(b, b)
assert eq_b_opt_b(b + bytes([int()]), b + bytes([int()]))

assert not eq_b_opt_b(b'x', b'y')
assert not eq_b_opt_b(b'y', b'x')
assert not eq_b_opt_b(None, b'x')

def test_optional_ne() -> None:
b = b'x'
assert not ne_b_b_opt(b, b)
assert not ne_b_b_opt(b + b'y', b + bytes() + b'y')

assert ne_b_b_opt(b'x', b'y')
assert ne_b_b_opt(b'y', b'x')
assert ne_b_b_opt(b'x', None)
31 changes: 31 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -1063,3 +1063,34 @@ class subc(str):
[file userdefinedstr.py]
class str:
pass

[case testStrOptionalEquality]
from __future__ import annotations

def eq_s_opt_s_opt(x: str | None, y: str | None) -> bool:
return x == y

def ne_s_opt_s_opt(x: str | None, y: str | None) -> bool:
return x != y

def test_optional_eq() -> None:
s = 'x'
assert eq_s_opt_s_opt(s, s)
assert eq_s_opt_s_opt(s + str(int()), s + str(int()))
assert eq_s_opt_s_opt(None, None)

assert not eq_s_opt_s_opt('x', 'y')
assert not eq_s_opt_s_opt('y', 'x')
assert not eq_s_opt_s_opt(None, 'x')
assert not eq_s_opt_s_opt('x', None)

def test_optional_ne() -> None:
s = 'x'
assert not ne_s_opt_s_opt(s, s)
assert not ne_s_opt_s_opt(s + str(int()), s+ str(int()))
assert not ne_s_opt_s_opt(None, None)

assert ne_s_opt_s_opt('x', 'y')
assert ne_s_opt_s_opt('y', 'x')
assert ne_s_opt_s_opt(None, 'x')
assert ne_s_opt_s_opt('x', None)
Loading