Skip to content

Commit bd17a57

Browse files
Require SupportsBool instead of bool for comparisons. (#14375)
1 parent 28f4bdf commit bd17a57

File tree

2 files changed

+101
-5
lines changed

2 files changed

+101
-5
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing_extensions import assert_type
2+
3+
4+
def test_min_builtin() -> None:
5+
# legal comparisons that succeed at runtime
6+
b1, b2 = bool(True), bool(False)
7+
i1, i2 = int(1), int(2)
8+
s1, s2 = str("a"), str("b")
9+
f1, f2 = float(0.5), float(2.3)
10+
l1, l2 = list[int]([1, 2]), list[int]([3, 4])
11+
t1, t2 = tuple[str, str](("A", "B")), tuple[str, str](("C", "D"))
12+
tN = tuple[str, ...](["A", "B", "C"])
13+
14+
assert_type(min(b1, b2), bool)
15+
assert_type(min(i1, i2), int)
16+
assert_type(min(s1, s2), str)
17+
assert_type(min(f1, f2), float)
18+
19+
# mixed numerical types (note: float = int or float)
20+
assert_type(min(b1, i1), int)
21+
assert_type(min(i1, b1), int)
22+
23+
assert_type(min(b1, f1), float)
24+
assert_type(min(f1, b1), float)
25+
26+
assert_type(min(i1, f1), float)
27+
assert_type(min(f1, i1), float)
28+
29+
# comparisons with lists and tuples
30+
assert_type(min(l1, l2), list[int])
31+
assert_type(min(t1, t2), tuple[str, str])
32+
assert_type(min(tN, t2), tuple[str, ...])
33+
34+
35+
def test_min_bad_builtin() -> None:
36+
# illegal comparisons that fail at runtime
37+
i1 = int(1)
38+
s1 = str("a")
39+
f1 = float(1.0)
40+
c1, c2 = complex(1.0, 2.0), complex(3.0, 4.0)
41+
list_str = list[str](["A", "B"])
42+
list_int = list[int]([2, 3])
43+
tup_str = tuple[str, str](("A", "B"))
44+
tup_int = tuple[int, int]((2, 3))
45+
46+
# True negatives.
47+
min(c1, c2) # type: ignore
48+
49+
# FIXME: False negatives.
50+
min(i1, s1)
51+
min(s1, f1)
52+
min(f1, list_str)
53+
min(list_str, list_int)
54+
min(tup_str, tup_int)
55+
56+
57+
def test_min_custom_comparison() -> None:
58+
class BoolScalar:
59+
def __bool__(self) -> bool: ...
60+
61+
class FloatScalar:
62+
def __float__(self) -> float: ...
63+
def __ge__(self, other: "FloatScalar") -> BoolScalar: ...
64+
def __gt__(self, other: "FloatScalar") -> BoolScalar: ...
65+
def __lt__(self, other: "FloatScalar") -> BoolScalar: ...
66+
def __le__(self, other: "FloatScalar") -> BoolScalar: ...
67+
68+
f1 = FloatScalar()
69+
f2 = FloatScalar()
70+
71+
assert_type(min(f1, f2), FloatScalar)
72+
73+
74+
def test_min_bad_custom_type() -> None:
75+
class FloatScalar:
76+
def __float__(self) -> float: ...
77+
def __ge__(self, other: "FloatScalar") -> object:
78+
return object()
79+
80+
def __gt__(self, other: "FloatScalar") -> object:
81+
return object()
82+
83+
def __lt__(self, other: "FloatScalar") -> object:
84+
return object()
85+
86+
def __le__(self, other: "FloatScalar") -> object:
87+
return object()
88+
89+
f1 = FloatScalar()
90+
f2 = FloatScalar()
91+
92+
# Note: min(f1, f2) works at runtime, but always returns the second argument.
93+
# therefore, we require returning a boolean-like type for comparisons.
94+
min(f1, f2) # type: ignore

stdlib/_typeshed/__init__.pyi

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,19 +82,21 @@ class SupportsNext(Protocol[_T_co]):
8282
class SupportsAnext(Protocol[_T_co]):
8383
def __anext__(self) -> Awaitable[_T_co]: ...
8484

85-
# Comparison protocols
85+
class SupportsBool(Protocol):
86+
def __bool__(self) -> bool: ...
8687

88+
# Comparison protocols
8789
class SupportsDunderLT(Protocol[_T_contra]):
88-
def __lt__(self, other: _T_contra, /) -> bool: ...
90+
def __lt__(self, other: _T_contra, /) -> SupportsBool: ...
8991

9092
class SupportsDunderGT(Protocol[_T_contra]):
91-
def __gt__(self, other: _T_contra, /) -> bool: ...
93+
def __gt__(self, other: _T_contra, /) -> SupportsBool: ...
9294

9395
class SupportsDunderLE(Protocol[_T_contra]):
94-
def __le__(self, other: _T_contra, /) -> bool: ...
96+
def __le__(self, other: _T_contra, /) -> SupportsBool: ...
9597

9698
class SupportsDunderGE(Protocol[_T_contra]):
97-
def __ge__(self, other: _T_contra, /) -> bool: ...
99+
def __ge__(self, other: _T_contra, /) -> SupportsBool: ...
98100

99101
class SupportsAllComparisons(
100102
SupportsDunderLT[Any], SupportsDunderGT[Any], SupportsDunderLE[Any], SupportsDunderGE[Any], Protocol

0 commit comments

Comments
 (0)