Skip to content

Commit 6a97dc6

Browse files
authored
[mypyc] Speed up equality with optional str/bytes types (#19758)
Specialize most equality (`==` and `!=`) operations when one of the operands is `str | None` or `bytes | None`. First check whether the optional operand is `None`, and then branch to a fast-path operation based on the result. Previously we used a generic C API primitive for such comparisons, which was quite slow.

This could be generalized to other optional types, but let's start with `str | None` since it's a very common type and it's often used in equality tests. `bytes | None` is also covered, since it's very similar to the `str` case.

Also add support for unchecked `Cast` operations in the IR. These don't perform a runtime type check -- they can be used to narrow the static type when it can be known statically that the cast is always safe.
1 parent d9c77cb commit 6a97dc6

File tree

9 files changed

+275
-5
lines changed

9 files changed

+275
-5
lines changed

mypyc/codegen/emitfunc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,9 @@ def visit_box(self, op: Box) -> None:
657657
self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)
658658

659659
def visit_cast(self, op: Cast) -> None:
660+
if op.is_unchecked and op.is_borrowed:
661+
self.emit_line(f"{self.reg(op)} = {self.reg(op.src)};")
662+
return
660663
branch = self.next_branch()
661664
handler = None
662665
if branch is not None:

mypyc/ir/ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1073,11 +1073,19 @@ class Cast(RegisterOp):
10731073

10741074
error_kind = ERR_MAGIC
10751075

1076-
def __init__(self, src: Value, typ: RType, line: int, *, borrow: bool = False) -> None:
1076+
def __init__(
1077+
self, src: Value, typ: RType, line: int, *, borrow: bool = False, unchecked: bool = False
1078+
) -> None:
10771079
super().__init__(line)
10781080
self.src = src
10791081
self.type = typ
1082+
# If true, don't incref the result.
10801083
self.is_borrowed = borrow
1084+
# If true, don't perform a runtime type check (only changes the static type of
1085+
# the operand). Used when we know that the cast will always succeed.
1086+
self.is_unchecked = unchecked
1087+
if unchecked:
1088+
self.error_kind = ERR_NEVER
10811089

10821090
def sources(self) -> list[Value]:
10831091
return [self.src]

mypyc/ir/pprint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,13 @@ def visit_method_call(self, op: MethodCall) -> str:
196196
return s
197197

198198
def visit_cast(self, op: Cast) -> str:
199-
return self.format("%r = %scast(%s, %r)", op, self.borrow_prefix(op), op.type, op.src)
199+
if op.is_unchecked:
200+
prefix = "unchecked "
201+
else:
202+
prefix = ""
203+
return self.format(
204+
"%r = %s%scast(%s, %r)", op, prefix, self.borrow_prefix(op), op.type, op.src
205+
)
200206

201207
def visit_box(self, op: Box) -> str:
202208
return self.format("%r = box(%s, %r)", op, op.src.type, op.src)

mypyc/ir/rtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ def flatten_nested_unions(types: list[RType]) -> list[RType]:
10271027
def optional_value_type(rtype: RType) -> RType | None:
10281028
"""If rtype is the union of none_rprimitive and another type X, return X.
10291029
1030-
Otherwise return None.
1030+
Otherwise, return None.
10311031
"""
10321032
if isinstance(rtype, RUnion) and len(rtype.items) == 2:
10331033
if rtype.items[0] == none_rprimitive:

mypyc/irbuild/ll_builder.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,20 @@ def box(self, src: Value) -> Value:
336336
return src
337337

338338
def unbox_or_cast(
339-
self, src: Value, target_type: RType, line: int, *, can_borrow: bool = False
339+
self,
340+
src: Value,
341+
target_type: RType,
342+
line: int,
343+
*,
344+
can_borrow: bool = False,
345+
unchecked: bool = False,
340346
) -> Value:
341347
if target_type.is_unboxed:
342348
return self.add(Unbox(src, target_type, line))
343349
else:
344350
if can_borrow:
345351
self.keep_alives.append(src)
346-
return self.add(Cast(src, target_type, line, borrow=can_borrow))
352+
return self.add(Cast(src, target_type, line, borrow=can_borrow, unchecked=unchecked))
347353

348354
def coerce(
349355
self,
@@ -2514,6 +2520,22 @@ def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str, line: int) ->
25142520
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype):
25152521
return self.compare_bytes(lreg, rreg, expr_op, line)
25162522

2523+
lopt = optional_value_type(ltype)
2524+
ropt = optional_value_type(rtype)
2525+
2526+
# Can we do a quick comparison of two optional types (special case None values)?
2527+
fast_opt_eq = False
2528+
if lopt is not None:
2529+
if ropt is not None and is_same_type(lopt, ropt) and self._never_equal_to_none(lopt):
2530+
fast_opt_eq = True
2531+
if is_same_type(lopt, rtype) and self._never_equal_to_none(lopt):
2532+
fast_opt_eq = True
2533+
elif ropt is not None:
2534+
if is_same_type(ropt, ltype) and self._never_equal_to_none(ropt):
2535+
fast_opt_eq = True
2536+
if fast_opt_eq:
2537+
return self._translate_fast_optional_eq_cmp(lreg, rreg, expr_op, line)
2538+
25172539
if not (isinstance(ltype, RInstance) and ltype == rtype):
25182540
return None
25192541

@@ -2540,6 +2562,76 @@ def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str, line: int) ->
25402562

25412563
return self.gen_method_call(lreg, op_methods[expr_op], [rreg], ltype, line)
25422564

2565+
def _never_equal_to_none(self, typ: RType) -> bool:
2566+
"""Are values of this type never equal to None?"""
2567+
# TODO: Support RInstance with no custom __eq__/__ne__ and other primitive types.
2568+
return is_str_rprimitive(typ) or is_bytes_rprimitive(typ)
2569+
2570+
def _translate_fast_optional_eq_cmp(
2571+
self, lreg: Value, rreg: Value, expr_op: str, line: int
2572+
) -> Value:
2573+
"""Generate eq/ne fast path between 'X | None' and ('X | None' or X).
2574+
2575+
Assume 'X' never compares equal to None.
2576+
"""
2577+
if not isinstance(lreg.type, RUnion):
2578+
lreg, rreg = rreg, lreg
2579+
value_typ = optional_value_type(lreg.type)
2580+
assert value_typ
2581+
res = Register(bool_rprimitive)
2582+
2583+
# Fast path: left value is None?
2584+
cmp = self.add(ComparisonOp(lreg, self.none_object(), ComparisonOp.EQ, line))
2585+
l_none = BasicBlock()
2586+
l_not_none = BasicBlock()
2587+
out = BasicBlock()
2588+
self.add(Branch(cmp, l_none, l_not_none, Branch.BOOL))
2589+
self.activate_block(l_none)
2590+
if not isinstance(rreg.type, RUnion):
2591+
val = self.false() if expr_op == "==" else self.true()
2592+
self.add(Assign(res, val))
2593+
else:
2594+
op = ComparisonOp.EQ if expr_op == "==" else ComparisonOp.NEQ
2595+
cmp = self.add(ComparisonOp(rreg, self.none_object(), op, line))
2596+
self.add(Assign(res, cmp))
2597+
self.goto(out)
2598+
2599+
self.activate_block(l_not_none)
2600+
if not isinstance(rreg.type, RUnion):
2601+
# Both operands are known to be not None, perform specialized comparison
2602+
eq = self.translate_eq_cmp(
2603+
self.unbox_or_cast(lreg, value_typ, line, can_borrow=True, unchecked=True),
2604+
rreg,
2605+
expr_op,
2606+
line,
2607+
)
2608+
assert eq is not None
2609+
self.add(Assign(res, eq))
2610+
else:
2611+
r_none = BasicBlock()
2612+
r_not_none = BasicBlock()
2613+
# Fast path: right value is None?
2614+
cmp = self.add(ComparisonOp(rreg, self.none_object(), ComparisonOp.EQ, line))
2615+
self.add(Branch(cmp, r_none, r_not_none, Branch.BOOL))
2616+
self.activate_block(r_none)
2617+
# None vs not-None
2618+
val = self.false() if expr_op == "==" else self.true()
2619+
self.add(Assign(res, val))
2620+
self.goto(out)
2621+
self.activate_block(r_not_none)
2622+
# Both operands are known to be not None, perform specialized comparison
2623+
eq = self.translate_eq_cmp(
2624+
self.unbox_or_cast(lreg, value_typ, line, can_borrow=True, unchecked=True),
2625+
self.unbox_or_cast(rreg, value_typ, line, can_borrow=True, unchecked=True),
2626+
expr_op,
2627+
line,
2628+
)
2629+
assert eq is not None
2630+
self.add(Assign(res, eq))
2631+
self.goto(out)
2632+
self.activate_block(out)
2633+
return res
2634+
25432635
def translate_is_op(self, lreg: Value, rreg: Value, expr_op: str, line: int) -> Value:
25442636
"""Create equality comparison operation between object identities
25452637

mypyc/test-data/irbuild-bytes.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,35 @@ L0:
185185
r10 = CPyBytes_Build(2, var, r9)
186186
b4 = r10
187187
return 1
188+
189+
[case testOptionalBytesEquality]
190+
from typing import Optional
191+
192+
def non_opt_opt(x: bytes, y: Optional[bytes]) -> bool:
193+
return x != y
194+
[out]
195+
def non_opt_opt(x, y):
196+
x :: bytes
197+
y :: union[bytes, None]
198+
r0 :: object
199+
r1 :: bit
200+
r2 :: bool
201+
r3 :: bytes
202+
r4 :: i32
203+
r5, r6 :: bit
204+
L0:
205+
r0 = load_address _Py_NoneStruct
206+
r1 = y == r0
207+
if r1 goto L1 else goto L2 :: bool
208+
L1:
209+
r2 = 1
210+
goto L3
211+
L2:
212+
r3 = unchecked borrow cast(bytes, y)
213+
r4 = CPyBytes_Compare(r3, x)
214+
r5 = r4 >= 0 :: signed
215+
r6 = r4 != 1
216+
r2 = r6
217+
L3:
218+
keep_alive y
219+
return r2

mypyc/test-data/irbuild-str.test

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,74 @@ L0:
669669
r2 = 'abc1233.14True'
670670
r3 = CPyStr_Build(7, x, r0, x, r1, x, r2, x)
671671
return r3
672+
673+
[case testOptionalStrEquality1]
674+
from typing import Optional
675+
676+
def opt_opt(x: Optional[str], y: Optional[str]) -> bool:
677+
return x == y
678+
[out]
679+
def opt_opt(x, y):
680+
x, y :: union[str, None]
681+
r0 :: object
682+
r1 :: bit
683+
r2 :: object
684+
r3 :: bit
685+
r4 :: bool
686+
r5 :: object
687+
r6 :: bit
688+
r7, r8 :: str
689+
r9 :: bool
690+
L0:
691+
r0 = load_address _Py_NoneStruct
692+
r1 = x == r0
693+
if r1 goto L1 else goto L2 :: bool
694+
L1:
695+
r2 = load_address _Py_NoneStruct
696+
r3 = y == r2
697+
r4 = r3
698+
goto L5
699+
L2:
700+
r5 = load_address _Py_NoneStruct
701+
r6 = y == r5
702+
if r6 goto L3 else goto L4 :: bool
703+
L3:
704+
r4 = 0
705+
goto L5
706+
L4:
707+
r7 = unchecked borrow cast(str, x)
708+
r8 = unchecked borrow cast(str, y)
709+
r9 = CPyStr_Equal(r7, r8)
710+
r4 = r9
711+
L5:
712+
keep_alive x, y
713+
return r4
714+
715+
[case testOptionalStrEquality2]
716+
from typing import Optional
717+
718+
def opt_non_opt(x: Optional[str], y: str) -> bool:
719+
return x == y
720+
[out]
721+
def opt_non_opt(x, y):
722+
x :: union[str, None]
723+
y :: str
724+
r0 :: object
725+
r1 :: bit
726+
r2 :: bool
727+
r3 :: str
728+
r4 :: bool
729+
L0:
730+
r0 = load_address _Py_NoneStruct
731+
r1 = x == r0
732+
if r1 goto L1 else goto L2 :: bool
733+
L1:
734+
r2 = 0
735+
goto L3
736+
L2:
737+
r3 = unchecked borrow cast(str, x)
738+
r4 = CPyStr_Equal(r3, y)
739+
r2 = r4
740+
L3:
741+
keep_alive x
742+
return r2

mypyc/test-data/run-bytes.test

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,30 @@ class subbytearray(bytearray):
374374
[file userdefinedbytes.py]
375375
class bytes:
376376
pass
377+
378+
[case testBytesOptionalEquality]
379+
from __future__ import annotations
380+
381+
def eq_b_opt_b(x: bytes | None, y: bytes) -> bool:
382+
return x == y
383+
384+
def ne_b_b_opt(x: bytes, y: bytes | None) -> bool:
385+
return x != y
386+
387+
def test_optional_eq() -> None:
388+
b = b'x'
389+
assert eq_b_opt_b(b, b)
390+
assert eq_b_opt_b(b + bytes([int()]), b + bytes([int()]))
391+
392+
assert not eq_b_opt_b(b'x', b'y')
393+
assert not eq_b_opt_b(b'y', b'x')
394+
assert not eq_b_opt_b(None, b'x')
395+
396+
def test_optional_ne() -> None:
397+
b = b'x'
398+
assert not ne_b_b_opt(b, b)
399+
assert not ne_b_b_opt(b + b'y', b + bytes() + b'y')
400+
401+
assert ne_b_b_opt(b'x', b'y')
402+
assert ne_b_b_opt(b'y', b'x')
403+
assert ne_b_b_opt(b'x', None)

mypyc/test-data/run-strings.test

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,3 +1063,34 @@ class subc(str):
10631063
[file userdefinedstr.py]
10641064
class str:
10651065
pass
1066+
1067+
[case testStrOptionalEquality]
1068+
from __future__ import annotations
1069+
1070+
def eq_s_opt_s_opt(x: str | None, y: str | None) -> bool:
1071+
return x == y
1072+
1073+
def ne_s_opt_s_opt(x: str | None, y: str | None) -> bool:
1074+
return x != y
1075+
1076+
def test_optional_eq() -> None:
1077+
s = 'x'
1078+
assert eq_s_opt_s_opt(s, s)
1079+
assert eq_s_opt_s_opt(s + str(int()), s + str(int()))
1080+
assert eq_s_opt_s_opt(None, None)
1081+
1082+
assert not eq_s_opt_s_opt('x', 'y')
1083+
assert not eq_s_opt_s_opt('y', 'x')
1084+
assert not eq_s_opt_s_opt(None, 'x')
1085+
assert not eq_s_opt_s_opt('x', None)
1086+
1087+
def test_optional_ne() -> None:
1088+
s = 'x'
1089+
assert not ne_s_opt_s_opt(s, s)
1090+
assert not ne_s_opt_s_opt(s + str(int()), s+ str(int()))
1091+
assert not ne_s_opt_s_opt(None, None)
1092+
1093+
assert ne_s_opt_s_opt('x', 'y')
1094+
assert ne_s_opt_s_opt('y', 'x')
1095+
assert ne_s_opt_s_opt(None, 'x')
1096+
assert ne_s_opt_s_opt('x', None)

0 commit comments

Comments
 (0)