Skip to content

Commit 7bdfb49

Browse files
committed
Add numpy array shapes for return types
1 parent d6dc067 commit 7bdfb49

26 files changed

+415
-280
lines changed

pandas-stubs/_libs/interval.pyi

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ from pandas.core.series import (
2121
from pandas._typing import (
2222
IntervalClosedType,
2323
IntervalT,
24-
np_ndarray_bool,
24+
np_1darray,
2525
npt,
2626
)
2727

@@ -170,7 +170,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
170170
@overload
171171
def __gt__(self, other: Interval[_OrderableT]) -> bool: ...
172172
@overload
173-
def __gt__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
173+
def __gt__(
174+
self: IntervalT, other: IntervalIndex[IntervalT]
175+
) -> np_1darray[np.bool]: ...
174176
@overload
175177
def __gt__(
176178
self,
@@ -179,7 +181,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
179181
@overload
180182
def __lt__(self, other: Interval[_OrderableT]) -> bool: ...
181183
@overload
182-
def __lt__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
184+
def __lt__(
185+
self: IntervalT, other: IntervalIndex[IntervalT]
186+
) -> np_1darray[np.bool]: ...
183187
@overload
184188
def __lt__(
185189
self,
@@ -188,7 +192,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
188192
@overload
189193
def __ge__(self, other: Interval[_OrderableT]) -> bool: ...
190194
@overload
191-
def __ge__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
195+
def __ge__(
196+
self: IntervalT, other: IntervalIndex[IntervalT]
197+
) -> np_1darray[np.bool]: ...
192198
@overload
193199
def __ge__(
194200
self,
@@ -197,19 +203,25 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
197203
@overload
198204
def __le__(self, other: Interval[_OrderableT]) -> bool: ...
199205
@overload
200-
def __le__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
206+
def __le__(
207+
self: IntervalT, other: IntervalIndex[IntervalT]
208+
) -> np_1darray[np.bool]: ...
201209
@overload
202210
def __eq__(self, other: Interval[_OrderableT]) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
203211
@overload
204-
def __eq__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
212+
def __eq__(
213+
self: IntervalT, other: IntervalIndex[IntervalT]
214+
) -> np_1darray[np.bool]: ...
205215
@overload
206216
def __eq__(self, other: Series[_OrderableT]) -> Series[bool]: ... # type: ignore[overload-overlap]
207217
@overload
208218
def __eq__(self, other: object) -> Literal[False]: ...
209219
@overload
210220
def __ne__(self, other: Interval[_OrderableT]) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
211221
@overload
212-
def __ne__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
222+
def __ne__(
223+
self: IntervalT, other: IntervalIndex[IntervalT]
224+
) -> np_1darray[np.bool]: ...
213225
@overload
214226
def __ne__(self, other: Series[_OrderableT]) -> Series[bool]: ... # type: ignore[overload-overlap]
215227
@overload

pandas-stubs/_libs/tslibs/timestamps.pyi

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ from pandas._libs.tslibs import (
4040
Timedelta,
4141
)
4242
from pandas._typing import (
43+
ShapeT,
4344
TimestampNonexistent,
4445
TimeUnit,
45-
np_ndarray_bool,
46+
np_1darray,
47+
np_ndarray,
4648
npt,
4749
)
4850

@@ -180,40 +182,48 @@ class Timestamp(datetime, SupportsIndex):
180182
@overload # type: ignore[override]
181183
def __le__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc]
182184
@overload
185+
def __le__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ...
186+
@overload
183187
def __le__(
184-
self, other: DatetimeIndex | npt.NDArray[np.datetime64]
185-
) -> np_ndarray_bool: ...
188+
self, other: np_ndarray[ShapeT, np.datetime64]
189+
) -> np_ndarray[ShapeT, np.bool]: ...
186190
@overload
187191
def __le__(self, other: TimestampSeries) -> Series[bool]: ...
188192
@overload # type: ignore[override]
189193
def __lt__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc]
190194
@overload
195+
def __lt__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ...
196+
@overload
191197
def __lt__(
192-
self, other: DatetimeIndex | npt.NDArray[np.datetime64]
193-
) -> np_ndarray_bool: ...
198+
self, other: np_ndarray[ShapeT, np.datetime64]
199+
) -> np_ndarray[ShapeT, np.bool]: ...
194200
@overload
195201
def __lt__(self, other: TimestampSeries) -> Series[bool]: ...
196202
@overload # type: ignore[override]
197203
def __ge__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc]
198204
@overload
205+
def __ge__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ...
206+
@overload
199207
def __ge__(
200-
self, other: DatetimeIndex | npt.NDArray[np.datetime64]
201-
) -> np_ndarray_bool: ...
208+
self, other: np_ndarray[ShapeT, np.datetime64]
209+
) -> np_ndarray[ShapeT, np.bool]: ...
202210
@overload
203211
def __ge__(self, other: TimestampSeries) -> Series[bool]: ...
204212
@overload # type: ignore[override]
205213
def __gt__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[misc]
206214
@overload
215+
def __gt__(self, other: DatetimeIndex) -> np_1darray[np.bool]: ...
216+
@overload
207217
def __gt__(
208-
self, other: DatetimeIndex | npt.NDArray[np.datetime64]
209-
) -> np_ndarray_bool: ...
218+
self, other: np_ndarray[ShapeT, np.datetime64]
219+
) -> np_ndarray[ShapeT, np.bool]: ...
210220
@overload
211221
def __gt__(self, other: TimestampSeries) -> Series[bool]: ...
212222
# error: Signature of "__add__" incompatible with supertype "date"/"datetime"
213223
@overload # type: ignore[override]
214224
def __add__(
215-
self, other: npt.NDArray[np.timedelta64]
216-
) -> npt.NDArray[np.datetime64]: ...
225+
self, other: np_ndarray[ShapeT, np.timedelta64]
226+
) -> np_ndarray[ShapeT, np.datetime64]: ...
217227
@overload
218228
def __add__(self, other: timedelta | np.timedelta64 | Tick) -> Self: ...
219229
@overload
@@ -226,8 +236,8 @@ class Timestamp(datetime, SupportsIndex):
226236
def __radd__(self, other: TimedeltaIndex) -> DatetimeIndex: ...
227237
@overload
228238
def __radd__(
229-
self, other: npt.NDArray[np.timedelta64]
230-
) -> npt.NDArray[np.datetime64]: ...
239+
self, other: np_ndarray[ShapeT, np.timedelta64]
240+
) -> np_ndarray[ShapeT, np.datetime64]: ...
231241
# TODO: test dt64
232242
@overload # type: ignore[override]
233243
def __sub__(self, other: Timestamp | datetime | np.datetime64) -> Timedelta: ...
@@ -241,22 +251,26 @@ class Timestamp(datetime, SupportsIndex):
241251
def __sub__(self, other: TimestampSeries) -> TimedeltaSeries: ...
242252
@overload
243253
def __sub__(
244-
self, other: npt.NDArray[np.timedelta64]
245-
) -> npt.NDArray[np.datetime64]: ...
254+
self, other: np_ndarray[ShapeT, np.timedelta64]
255+
) -> np_ndarray[ShapeT, np.datetime64]: ...
246256
@overload
247257
def __eq__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
248258
@overload
249259
def __eq__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
250260
@overload
251-
def __eq__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
261+
def __eq__(self, other: Index) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
262+
@overload # TODO: using shape-aware arrays similar to other methods doesn't work in mypy
263+
def __eq__(self, other: npt.NDArray[np.datetime64]) -> npt.NDArray[np.bool]: ... # type: ignore[overload-overlap]
252264
@overload
253265
def __eq__(self, other: object) -> Literal[False]: ...
254266
@overload
255267
def __ne__(self, other: Timestamp | datetime | np.datetime64) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
256268
@overload
257269
def __ne__(self, other: TimestampSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
258270
@overload
259-
def __ne__(self, other: npt.NDArray[np.datetime64] | Index) -> np_ndarray_bool: ... # type: ignore[overload-overlap]
271+
def __ne__(self, other: Index) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
272+
@overload # TODO: using shape-aware arrays similar to other methods doesn't work in mypy
273+
def __ne__(self, other: npt.NDArray[np.datetime64]) -> npt.NDArray[np.bool]: ... # type: ignore[overload-overlap]
260274
@overload
261275
def __ne__(self, other: object) -> Literal[True]: ...
262276
def __hash__(self) -> int: ...

pandas-stubs/_typing.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,15 @@ np_ndarray_complex: TypeAlias = npt.NDArray[np.complexfloating]
819819
np_ndarray_bool: TypeAlias = npt.NDArray[np.bool_]
820820
np_ndarray_str: TypeAlias = npt.NDArray[np.str_]
821821

822+
# Define shape and generic type variables with defaults similar to numpy
823+
GenericT = TypeVar("GenericT", bound=np.generic, default=Any)
824+
ShapeT = TypeVar("ShapeT", bound=tuple[int, ...], default=tuple[Any, ...])
825+
# Numpy ndarray with more ergonomic typevar
826+
np_ndarray: TypeAlias = np.ndarray[ShapeT, np.dtype[GenericT]]
827+
# Numpy arrays with known shape (Do not use as argument types, only as return types)
828+
np_1darray: TypeAlias = np.ndarray[tuple[int], np.dtype[GenericT]]
829+
np_2darray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[GenericT]]
830+
822831
IndexType: TypeAlias = slice | np_ndarray_anyint | Index | list[int] | Series[int]
823832
MaskType: TypeAlias = Series[bool] | np_ndarray_bool | list[bool]
824833

pandas-stubs/core/algorithms.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ from pandas._typing import (
1818
AnyArrayLike,
1919
IntervalT,
2020
TakeIndexer,
21+
np_1darray,
2122
)
2223

2324
# These are type: ignored because the Index types overlap due to inheritance but indices
@@ -54,14 +55,14 @@ def factorize(
5455
sort: bool = ...,
5556
use_na_sentinel: bool = ...,
5657
size_hint: int | None = ...,
57-
) -> tuple[np.ndarray, Index]: ...
58+
) -> tuple[np_1darray, Index]: ...
5859
@overload
5960
def factorize(
6061
values: Categorical,
6162
sort: bool = ...,
6263
use_na_sentinel: bool = ...,
6364
size_hint: int | None = ...,
64-
) -> tuple[np.ndarray, Categorical]: ...
65+
) -> tuple[np_1darray, Categorical]: ...
6566
def value_counts(
6667
values: AnyArrayLike | list | tuple,
6768
sort: bool = True,

pandas-stubs/core/arrays/base.pyi

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from pandas._typing import (
1212
ScalarIndexer,
1313
SequenceIndexer,
1414
TakeIndexer,
15+
np_1darray,
1516
npt,
1617
)
1718

@@ -31,7 +32,7 @@ class ExtensionArray:
3132
dtype: npt.DTypeLike | None = ...,
3233
copy: bool = False,
3334
na_value: Scalar = ...,
34-
) -> np.ndarray: ...
35+
) -> np_1darray: ...
3536
@property
3637
def dtype(self) -> ExtensionDtype: ...
3738
@property
@@ -44,13 +45,13 @@ class ExtensionArray:
4445
def isna(self) -> ArrayLike: ...
4546
def argsort(
4647
self, *, ascending: bool = ..., kind: str = ..., **kwargs
47-
) -> np.ndarray: ...
48+
) -> np_1darray: ...
4849
def fillna(self, value=..., method=None, limit=None): ...
4950
def dropna(self): ...
5051
def shift(self, periods: int = 1, fill_value: object = ...) -> Self: ...
5152
def unique(self): ...
5253
def searchsorted(self, value, side: str = ..., sorter=...): ...
53-
def factorize(self, use_na_sentinel: bool = True) -> tuple[np.ndarray, Self]: ...
54+
def factorize(self, use_na_sentinel: bool = True) -> tuple[np_1darray, Self]: ...
5455
def repeat(self, repeats, axis=...): ...
5556
def take(
5657
self,
@@ -60,7 +61,7 @@ class ExtensionArray:
6061
fill_value=...,
6162
) -> Self: ...
6263
def copy(self) -> Self: ...
63-
def view(self, dtype=...) -> Self | np.ndarray: ...
64+
def view(self, dtype=...) -> Self | np_1darray: ...
6465
def ravel(self, order="C") -> Self: ...
6566
def tolist(self) -> list: ...
6667
def _reduce(

pandas-stubs/core/arrays/categorical.pyi

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ from pandas._typing import (
2525
ScalarIndexer,
2626
SequenceIndexer,
2727
TakeIndexer,
28-
np_ndarray_bool,
29-
np_ndarray_int,
28+
np_1darray,
3029
)
3130

3231
from pandas.core.dtypes.dtypes import CategoricalDtype as CategoricalDtype
@@ -63,7 +62,7 @@ class Categorical(ExtensionArray):
6362
fastpath: bool = ...,
6463
) -> Categorical: ...
6564
@property
66-
def codes(self) -> np_ndarray_int: ...
65+
def codes(self) -> np_1darray[np.signedinteger]: ...
6766
def set_ordered(self, value) -> Categorical: ...
6867
def as_ordered(self) -> Categorical: ...
6968
def as_unordered(self) -> Categorical: ...
@@ -90,18 +89,18 @@ class Categorical(ExtensionArray):
9089
@property
9190
def shape(self): ...
9291
def shift(self, periods=1, fill_value=...): ...
93-
def __array__(self, dtype=...) -> np.ndarray: ...
92+
def __array__(self, dtype=...) -> np_1darray: ...
9493
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ...
9594
@property
9695
def T(self): ...
9796
@property
9897
def nbytes(self) -> int: ...
9998
def memory_usage(self, deep: bool = ...): ...
10099
def searchsorted(self, value, side: str = ..., sorter=...): ...
101-
def isna(self) -> np_ndarray_bool: ...
102-
def isnull(self) -> np_ndarray_bool: ...
103-
def notna(self) -> np_ndarray_bool: ...
104-
def notnull(self) -> np_ndarray_bool: ...
100+
def isna(self) -> np_1darray[np.bool]: ...
101+
def isnull(self) -> np_1darray[np.bool]: ...
102+
def notna(self) -> np_1darray[np.bool]: ...
103+
def notnull(self) -> np_1darray[np.bool]: ...
105104
def dropna(self): ...
106105
def value_counts(self, dropna: bool = True): ...
107106
def check_for_ordered(self, op) -> None: ...

pandas-stubs/core/arrays/interval.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ from pandas._typing import (
2121
ScalarIndexer,
2222
SequenceIndexer,
2323
TakeIndexer,
24-
np_ndarray_bool,
24+
np_1darray,
2525
)
2626

2727
IntervalOrNA: TypeAlias = Interval | float
@@ -99,7 +99,7 @@ class IntervalArray(IntervalMixin, ExtensionArray):
9999
def mid(self) -> Index: ...
100100
@property
101101
def is_non_overlapping_monotonic(self) -> bool: ...
102-
def __array__(self, dtype=...) -> np.ndarray: ...
102+
def __array__(self, dtype=...) -> np_1darray: ...
103103
def __arrow_array__(self, type=...): ...
104104
def to_tuples(self, na_tuple: bool = True): ...
105105
def repeat(self, repeats, axis: Axis | None = ...): ...
@@ -108,5 +108,5 @@ class IntervalArray(IntervalMixin, ExtensionArray):
108108
@overload
109109
def contains(
110110
self, other: Scalar | ExtensionArray | Index | np.ndarray
111-
) -> np_ndarray_bool: ...
111+
) -> np_1darray[np.bool]: ...
112112
def overlaps(self, other: Interval) -> bool: ...

pandas-stubs/core/base.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ from pandas._typing import (
2626
DropKeep,
2727
NDFrameT,
2828
Scalar,
29+
np_1darray,
2930
npt,
3031
)
3132
from pandas.util._decorators import cache_readonly
@@ -63,7 +64,7 @@ class IndexOpsMixin(OpsMixin, Generic[S1]):
6364
copy: bool = False,
6465
na_value: Scalar = ...,
6566
**kwargs,
66-
) -> np.ndarray: ...
67+
) -> np_1darray: ...
6768
@property
6869
def empty(self) -> bool: ...
6970
def max(self, axis=..., skipna: bool = ..., **kwargs): ...
@@ -114,7 +115,7 @@ class IndexOpsMixin(OpsMixin, Generic[S1]):
114115
def is_monotonic_increasing(self) -> bool: ...
115116
def factorize(
116117
self, sort: bool = False, use_na_sentinel: bool = True
117-
) -> tuple[np.ndarray, np.ndarray | Index | Categorical]: ...
118+
) -> tuple[np_1darray, np_1darray | Index | Categorical]: ...
118119
def searchsorted(
119120
self, value, side: Literal["left", "right"] = ..., sorter=...
120121
) -> int | list[int]: ...

0 commit comments

Comments
 (0)