Skip to content

Commit 4ef638a

Browse files
authored
TYP: try out TypeGuard (#51309)
1 parent a29c206 commit 4ef638a

File tree

19 files changed

+69
-46
lines changed

19 files changed

+69
-46
lines changed

pandas/_libs/lib.pyi

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# TODO(npdtypes): Many types specified here can be made more specific/accurate;
22
# the more specific versions are specified in comments
3-
3+
from decimal import Decimal
44
from typing import (
55
Any,
66
Callable,
@@ -13,9 +13,12 @@ from typing import (
1313

1414
import numpy as np
1515

16+
from pandas._libs.interval import Interval
17+
from pandas._libs.tslibs import Period
1618
from pandas._typing import (
1719
ArrayLike,
1820
DtypeObj,
21+
TypeGuard,
1922
npt,
2023
)
2124

@@ -38,13 +41,13 @@ def infer_dtype(value: object, skipna: bool = ...) -> str: ...
3841
def is_iterator(obj: object) -> bool: ...
3942
def is_scalar(val: object) -> bool: ...
4043
def is_list_like(obj: object, allow_sets: bool = ...) -> bool: ...
41-
def is_period(val: object) -> bool: ...
42-
def is_interval(val: object) -> bool: ...
43-
def is_decimal(val: object) -> bool: ...
44-
def is_complex(val: object) -> bool: ...
45-
def is_bool(val: object) -> bool: ...
46-
def is_integer(val: object) -> bool: ...
47-
def is_float(val: object) -> bool: ...
44+
def is_period(val: object) -> TypeGuard[Period]: ...
45+
def is_interval(val: object) -> TypeGuard[Interval]: ...
46+
def is_decimal(val: object) -> TypeGuard[Decimal]: ...
47+
def is_complex(val: object) -> TypeGuard[complex]: ...
48+
def is_bool(val: object) -> TypeGuard[bool | np.bool_]: ...
49+
def is_integer(val: object) -> TypeGuard[int | np.integer]: ...
50+
def is_float(val: object) -> TypeGuard[float]: ...
4851
def is_interval_array(values: np.ndarray) -> bool: ...
4952
def is_datetime64_array(values: np.ndarray) -> bool: ...
5053
def is_timedelta_or_timedelta64_array(values: np.ndarray) -> bool: ...

pandas/_typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,19 @@
8484
# Name "npt._ArrayLikeInt_co" is not defined [name-defined]
8585
NumpySorter = Optional[npt._ArrayLikeInt_co] # type: ignore[name-defined]
8686

87+
if sys.version_info >= (3, 10):
88+
from typing import TypeGuard
89+
else:
90+
from typing_extensions import TypeGuard # pyright: reportUnusedImport = false
91+
8792
if sys.version_info >= (3, 11):
8893
from typing import Self
8994
else:
9095
from typing_extensions import Self # pyright: reportUnusedImport = false
9196
else:
9297
npt: Any = None
9398
Self: Any = None
99+
TypeGuard: Any = None
94100

95101
HashableT = TypeVar("HashableT", bound=Hashable)
96102

pandas/compat/numpy/function.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
overload,
2626
)
2727

28+
import numpy as np
2829
from numpy import ndarray
2930

3031
from pandas._libs.lib import (
@@ -215,7 +216,7 @@ def validate_clip_with_axis(
215216
)
216217

217218

218-
def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
219+
def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
219220
"""
220221
If this function is called via the 'numpy' library, the third parameter in
221222
its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
@@ -224,6 +225,8 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
224225
if not is_bool(skipna):
225226
args = (skipna,) + args
226227
skipna = True
228+
elif isinstance(skipna, np.bool_):
229+
skipna = bool(skipna)
227230

228231
validate_cum_func(args, kwargs, fname=name)
229232
return skipna

pandas/core/arrays/datetimelike.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2172,7 +2172,6 @@ def validate_periods(periods: int | float | None) -> int | None:
21722172
periods = int(periods)
21732173
elif not lib.is_integer(periods):
21742174
raise TypeError(f"periods must be a number, got {periods}")
2175-
periods = cast(int, periods)
21762175
return periods
21772176

21782177

pandas/core/dtypes/cast.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,9 @@ def maybe_box_native(value: Scalar | None | NAType) -> Scalar | None | NAType:
191191
scalar or Series
192192
"""
193193
if is_float(value):
194-
# error: Argument 1 to "float" has incompatible type
195-
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
196-
# expected "Union[SupportsFloat, _SupportsIndex, str]"
197-
value = float(value) # type: ignore[arg-type]
194+
value = float(value)
198195
elif is_integer(value):
199-
# error: Argument 1 to "int" has incompatible type
200-
# "Union[Union[str, int, float, bool], Union[Any, Timestamp, Timedelta, Any]]";
201-
# expected "Union[str, SupportsInt, _SupportsIndex, _SupportsTrunc]"
202-
value = int(value) # type: ignore[arg-type]
196+
value = int(value)
203197
elif is_bool(value):
204198
value = bool(value)
205199
elif isinstance(value, (np.datetime64, np.timedelta64)):

pandas/core/dtypes/inference.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@
55
from collections import abc
66
from numbers import Number
77
import re
8-
from typing import Pattern
8+
from typing import (
9+
TYPE_CHECKING,
10+
Hashable,
11+
Pattern,
12+
)
913

1014
import numpy as np
1115

1216
from pandas._libs import lib
1317

18+
if TYPE_CHECKING:
19+
from pandas._typing import TypeGuard
20+
1421
is_bool = lib.is_bool
1522

1623
is_integer = lib.is_integer
@@ -30,7 +37,7 @@
3037
is_iterator = lib.is_iterator
3138

3239

33-
def is_number(obj) -> bool:
40+
def is_number(obj) -> TypeGuard[Number | np.number]:
3441
"""
3542
Check if the object is a number.
3643
@@ -132,7 +139,7 @@ def is_file_like(obj) -> bool:
132139
return bool(hasattr(obj, "__iter__"))
133140

134141

135-
def is_re(obj) -> bool:
142+
def is_re(obj) -> TypeGuard[Pattern]:
136143
"""
137144
Check if the object is a regex pattern instance.
138145
@@ -325,7 +332,7 @@ def is_named_tuple(obj) -> bool:
325332
return isinstance(obj, abc.Sequence) and hasattr(obj, "_fields")
326333

327334

328-
def is_hashable(obj) -> bool:
335+
def is_hashable(obj) -> TypeGuard[Hashable]:
329336
"""
330337
Return True if hash(obj) will succeed, False otherwise.
331338

pandas/core/frame.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9516,11 +9516,7 @@ def melt(
95169516
)
95179517
def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame:
95189518
if not lib.is_integer(periods):
9519-
if not (
9520-
is_float(periods)
9521-
# error: "int" has no attribute "is_integer"
9522-
and periods.is_integer() # type: ignore[attr-defined]
9523-
):
9519+
if not (is_float(periods) and periods.is_integer()):
95249520
raise ValueError("periods must be an integer")
95259521
periods = int(periods)
95269522

@@ -10412,8 +10408,13 @@ def _series_round(ser: Series, decimals: int) -> Series:
1041210408
new_cols = list(_dict_round(self, decimals))
1041310409
elif is_integer(decimals):
1041410410
# Dispatch to Block.round
10411+
# Argument "decimals" to "round" of "BaseBlockManager" has incompatible
10412+
# type "Union[int, integer[Any]]"; expected "int"
1041510413
return self._constructor(
10416-
self._mgr.round(decimals=decimals, using_cow=using_copy_on_write()),
10414+
self._mgr.round(
10415+
decimals=decimals, # type: ignore[arg-type]
10416+
using_cow=using_copy_on_write(),
10417+
),
1041710418
).__finalize__(self, method="round")
1041810419
else:
1041910420
raise TypeError("decimals must be an integer, a dict-like or a Series")

pandas/core/generic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4096,7 +4096,8 @@ class animal locomotion
40964096
loc, new_index = index._get_loc_level(key, level=0)
40974097
if not drop_level:
40984098
if lib.is_integer(loc):
4099-
new_index = index[loc : loc + 1]
4099+
# Slice index must be an integer or None
4100+
new_index = index[loc : loc + 1] # type: ignore[misc]
41004101
else:
41014102
new_index = index[loc]
41024103
else:

pandas/core/indexes/api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@
7070

7171

7272
def get_objs_combined_axis(
73-
objs, intersect: bool = False, axis: Axis = 0, sort: bool = True, copy: bool = False
73+
objs,
74+
intersect: bool = False,
75+
axis: Axis = 0,
76+
sort: bool = True,
77+
copy: bool = False,
7478
) -> Index:
7579
"""
7680
Extract combined index: return intersection or union (depending on the

pandas/core/indexes/multi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2700,6 +2700,7 @@ def _partial_tup_index(self, tup: tuple, side: Literal["left", "right"] = "left"
27002700
for k, (lab, lev, level_codes) in enumerate(zipped):
27012701
section = level_codes[start:end]
27022702

2703+
loc: npt.NDArray[np.intp] | np.intp | int
27032704
if lab not in lev and not isna(lab):
27042705
# short circuit
27052706
try:
@@ -2931,7 +2932,8 @@ def get_loc_level(self, key, level: IndexLabel = 0, drop_level: bool = True):
29312932
loc, mi = self._get_loc_level(key, level=level)
29322933
if not drop_level:
29332934
if lib.is_integer(loc):
2934-
mi = self[loc : loc + 1]
2935+
# Slice index must be an integer or None
2936+
mi = self[loc : loc + 1] # type: ignore[misc]
29352937
else:
29362938
mi = self[loc]
29372939
return loc, mi

0 commit comments

Comments
 (0)