Skip to content

Commit 3aa4044

Browse files
authored
Add numpy array shapes for return types (#1325)
* Add numpy array shapes for return types * Fix errors on py310 * Try separate check and assert_type * Revert "Try separate check and assert_type" This reverts commit 6be36cf. * Skip mypy failing tests on Python 3.10 * Add comment * See if renaming variables work * Type timedelta array shapes and refactor ts and td tests * Period also supports arrays * Fix pyright on Python 3.10 * Add dtype to index and series subclasses * Ignore pyrefly override
1 parent d6dc067 commit 3aa4044

33 files changed

+1378
-883
lines changed

pandas-stubs/_libs/interval.pyi

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ from pandas.core.series import (
2121
from pandas._typing import (
2222
IntervalClosedType,
2323
IntervalT,
24-
np_ndarray_bool,
24+
np_1darray,
2525
npt,
2626
)
2727

@@ -170,7 +170,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
170170
@overload
171171
def __gt__(self, other: Interval[_OrderableT]) -> bool: ...
172172
@overload
173-
def __gt__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
173+
def __gt__(
174+
self: IntervalT, other: IntervalIndex[IntervalT]
175+
) -> np_1darray[np.bool]: ...
174176
@overload
175177
def __gt__(
176178
self,
@@ -179,7 +181,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
179181
@overload
180182
def __lt__(self, other: Interval[_OrderableT]) -> bool: ...
181183
@overload
182-
def __lt__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
184+
def __lt__(
185+
self: IntervalT, other: IntervalIndex[IntervalT]
186+
) -> np_1darray[np.bool]: ...
183187
@overload
184188
def __lt__(
185189
self,
@@ -188,7 +192,9 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
188192
@overload
189193
def __ge__(self, other: Interval[_OrderableT]) -> bool: ...
190194
@overload
191-
def __ge__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
195+
def __ge__(
196+
self: IntervalT, other: IntervalIndex[IntervalT]
197+
) -> np_1darray[np.bool]: ...
192198
@overload
193199
def __ge__(
194200
self,
@@ -197,19 +203,25 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
197203
@overload
198204
def __le__(self, other: Interval[_OrderableT]) -> bool: ...
199205
@overload
200-
def __le__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
206+
def __le__(
207+
self: IntervalT, other: IntervalIndex[IntervalT]
208+
) -> np_1darray[np.bool]: ...
201209
@overload
202210
def __eq__(self, other: Interval[_OrderableT]) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
203211
@overload
204-
def __eq__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
212+
def __eq__(
213+
self: IntervalT, other: IntervalIndex[IntervalT]
214+
) -> np_1darray[np.bool]: ...
205215
@overload
206216
def __eq__(self, other: Series[_OrderableT]) -> Series[bool]: ... # type: ignore[overload-overlap]
207217
@overload
208218
def __eq__(self, other: object) -> Literal[False]: ...
209219
@overload
210220
def __ne__(self, other: Interval[_OrderableT]) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
211221
@overload
212-
def __ne__(self: IntervalT, other: IntervalIndex[IntervalT]) -> np_ndarray_bool: ...
222+
def __ne__(
223+
self: IntervalT, other: IntervalIndex[IntervalT]
224+
) -> np_1darray[np.bool]: ...
213225
@overload
214226
def __ne__(self, other: Series[_OrderableT]) -> Series[bool]: ... # type: ignore[overload-overlap]
215227
@overload

pandas-stubs/_libs/tslibs/period.pyi

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ from typing_extensions import TypeAlias
2222
from pandas._libs.tslibs import NaTType
2323
from pandas._libs.tslibs.offsets import BaseOffset
2424
from pandas._libs.tslibs.timestamps import Timestamp
25-
from pandas._typing import npt
25+
from pandas._typing import (
26+
ShapeT,
27+
np_1darray,
28+
np_ndarray,
29+
)
2630

2731
class IncompatibleFrequency(ValueError): ...
2832

@@ -98,44 +102,64 @@ class Period(PeriodMixin):
98102
@overload
99103
def __eq__(self, other: Period) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
100104
@overload
101-
def __eq__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ... # type: ignore[overload-overlap]
105+
def __eq__(self, other: PeriodIndex) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
102106
@overload
103107
def __eq__(self, other: PeriodSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
104108
@overload
109+
def __eq__(self, other: np_ndarray[ShapeT, np.object_]) -> np_ndarray[ShapeT, np.bool]: ... # type: ignore[overload-overlap]
110+
@overload
105111
def __eq__(self, other: object) -> Literal[False]: ...
106112
@overload
107113
def __ge__(self, other: Period) -> bool: ...
108114
@overload
109-
def __ge__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
115+
def __ge__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
110116
@overload
111117
def __ge__(self, other: PeriodSeries) -> Series[bool]: ...
112118
@overload
119+
def __ge__(
120+
self, other: np_ndarray[ShapeT, np.object_]
121+
) -> np_ndarray[ShapeT, np.bool]: ...
122+
@overload
113123
def __gt__(self, other: Period) -> bool: ...
114124
@overload
115-
def __gt__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
125+
def __gt__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
116126
@overload
117127
def __gt__(self, other: PeriodSeries) -> Series[bool]: ...
118128
@overload
129+
def __gt__(
130+
self, other: np_ndarray[ShapeT, np.object_]
131+
) -> np_ndarray[ShapeT, np.bool]: ...
132+
@overload
119133
def __le__(self, other: Period) -> bool: ...
120134
@overload
121-
def __le__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
135+
def __le__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
122136
@overload
123137
def __le__(self, other: PeriodSeries) -> Series[bool]: ...
124138
@overload
139+
def __le__(
140+
self, other: np_ndarray[ShapeT, np.object_]
141+
) -> np_ndarray[ShapeT, np.bool]: ...
142+
@overload
125143
def __lt__(self, other: Period) -> bool: ...
126144
@overload
127-
def __lt__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
145+
def __lt__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
128146
@overload
129147
def __lt__(self, other: PeriodSeries) -> Series[bool]: ...
148+
@overload
149+
def __lt__(
150+
self, other: np_ndarray[ShapeT, np.object_]
151+
) -> np_ndarray[ShapeT, np.bool]: ...
130152
# ignore[misc] here because we know all other comparisons
131153
# are False, so we use Literal[False]
132154
@overload
133155
def __ne__(self, other: Period) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
134156
@overload
135-
def __ne__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ... # type: ignore[overload-overlap]
157+
def __ne__(self, other: PeriodIndex) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
136158
@overload
137159
def __ne__(self, other: PeriodSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
138160
@overload
161+
def __ne__(self, other: np_ndarray[ShapeT, np.object_]) -> np_ndarray[ShapeT, np.bool]: ... # type: ignore[overload-overlap]
162+
@overload
139163
def __ne__(self, other: object) -> Literal[True]: ...
140164
# Ignored due to indecipherable error from mypy:
141165
# Forward operator "__add__" is not callable [misc]

pandas-stubs/_libs/tslibs/timedeltas.pyi

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ from pandas._libs.tslibs import (
3333
from pandas._libs.tslibs.period import Period
3434
from pandas._libs.tslibs.timestamps import Timestamp
3535
from pandas._typing import (
36+
ShapeT,
3637
TimeUnit,
38+
np_1darray,
39+
np_ndarray,
3740
npt,
3841
)
3942

@@ -153,12 +156,12 @@ class Timedelta(timedelta):
153156
def __add__(self, other: DatetimeIndex) -> DatetimeIndex: ...
154157
@overload
155158
def __add__(
156-
self, other: npt.NDArray[np.timedelta64]
157-
) -> npt.NDArray[np.timedelta64]: ...
159+
self, other: np_ndarray[ShapeT, np.timedelta64]
160+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
158161
@overload
159162
def __add__(
160-
self, other: npt.NDArray[np.datetime64]
161-
) -> npt.NDArray[np.datetime64]: ...
163+
self, other: np_ndarray[ShapeT, np.datetime64]
164+
) -> np_ndarray[ShapeT, np.datetime64]: ...
162165
@overload
163166
def __add__(self, other: pd.TimedeltaIndex) -> pd.TimedeltaIndex: ...
164167
@overload
@@ -176,12 +179,12 @@ class Timedelta(timedelta):
176179
def __radd__(self, other: NaTType) -> NaTType: ...
177180
@overload
178181
def __radd__(
179-
self, other: npt.NDArray[np.timedelta64]
180-
) -> npt.NDArray[np.timedelta64]: ...
182+
self, other: np_ndarray[ShapeT, np.timedelta64]
183+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
181184
@overload
182185
def __radd__(
183-
self, other: npt.NDArray[np.datetime64]
184-
) -> npt.NDArray[np.datetime64]: ...
186+
self, other: np_ndarray[ShapeT, np.datetime64]
187+
) -> np_ndarray[ShapeT, np.datetime64]: ...
185188
@overload
186189
def __radd__(self, other: pd.TimedeltaIndex) -> pd.TimedeltaIndex: ...
187190
@overload
@@ -193,8 +196,8 @@ class Timedelta(timedelta):
193196
def __sub__(self, other: NaTType) -> NaTType: ...
194197
@overload
195198
def __sub__(
196-
self, other: npt.NDArray[np.timedelta64]
197-
) -> npt.NDArray[np.timedelta64]: ...
199+
self, other: np_ndarray[ShapeT, np.timedelta64]
200+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
198201
@overload
199202
def __sub__(self, other: pd.TimedeltaIndex) -> TimedeltaIndex: ...
200203
@overload
@@ -215,12 +218,12 @@ class Timedelta(timedelta):
215218
def __rsub__(self, other: DatetimeIndex) -> DatetimeIndex: ...
216219
@overload
217220
def __rsub__(
218-
self, other: npt.NDArray[np.datetime64]
219-
) -> npt.NDArray[np.datetime64]: ...
221+
self, other: np_ndarray[ShapeT, np.datetime64]
222+
) -> np_ndarray[ShapeT, np.datetime64]: ...
220223
@overload
221224
def __rsub__(
222-
self, other: npt.NDArray[np.timedelta64]
223-
) -> npt.NDArray[np.timedelta64]: ...
225+
self, other: np_ndarray[ShapeT, np.timedelta64]
226+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
224227
@overload
225228
def __rsub__(self, other: pd.TimedeltaIndex) -> pd.TimedeltaIndex: ...
226229
def __neg__(self) -> Timedelta: ...
@@ -231,8 +234,8 @@ class Timedelta(timedelta):
231234
def __mul__(self, other: float) -> Timedelta: ...
232235
@overload
233236
def __mul__(
234-
self, other: npt.NDArray[np.integer] | npt.NDArray[np.floating]
235-
) -> npt.NDArray[np.timedelta64]: ...
237+
self, other: np_ndarray[ShapeT, np.integer] | np_ndarray[ShapeT, np.floating]
238+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
236239
@overload
237240
def __mul__(self, other: Series[int]) -> TimedeltaSeries: ...
238241
@overload
@@ -243,8 +246,8 @@ class Timedelta(timedelta):
243246
def __rmul__(self, other: float) -> Timedelta: ...
244247
@overload
245248
def __rmul__(
246-
self, other: npt.NDArray[np.floating] | npt.NDArray[np.integer]
247-
) -> npt.NDArray[np.timedelta64]: ...
249+
self, other: np_ndarray[ShapeT, np.floating] | np_ndarray[ShapeT, np.integer]
250+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
248251
@overload
249252
def __rmul__(self, other: Series[int]) -> TimedeltaSeries: ...
250253
@overload
@@ -260,12 +263,12 @@ class Timedelta(timedelta):
260263
def __floordiv__(self, other: float) -> Timedelta: ...
261264
@overload
262265
def __floordiv__(
263-
self, other: npt.NDArray[np.integer] | npt.NDArray[np.floating]
264-
) -> npt.NDArray[np.timedelta64]: ...
266+
self, other: np_ndarray[ShapeT, np.integer] | np_ndarray[ShapeT, np.floating]
267+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
265268
@overload
266269
def __floordiv__(
267-
self, other: npt.NDArray[np.timedelta64]
268-
) -> npt.NDArray[np.int_]: ...
270+
self, other: np_ndarray[ShapeT, np.timedelta64]
271+
) -> np_ndarray[ShapeT, np.int_]: ...
269272
@overload
270273
def __floordiv__(self, other: Index[int] | Index[float]) -> TimedeltaIndex: ...
271274
@overload
@@ -282,17 +285,17 @@ class Timedelta(timedelta):
282285
def __rfloordiv__(self, other: NaTType | None) -> float: ...
283286
@overload
284287
def __rfloordiv__(
285-
self, other: npt.NDArray[np.timedelta64]
286-
) -> npt.NDArray[np.int_]: ...
288+
self, other: np_ndarray[ShapeT, np.timedelta64]
289+
) -> np_ndarray[ShapeT, np.int_]: ...
287290
# Override due to more types supported than dt.timedelta
288291
@overload # type: ignore[override]
289292
def __truediv__(self, other: timedelta | Timedelta | NaTType) -> float: ...
290293
@overload
291294
def __truediv__(self, other: float) -> Timedelta: ...
292295
@overload
293296
def __truediv__(
294-
self, other: npt.NDArray[np.integer] | npt.NDArray[np.floating]
295-
) -> npt.NDArray[np.timedelta64]: ...
297+
self, other: np_ndarray[ShapeT, np.integer] | np_ndarray[ShapeT, np.floating]
298+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
296299
@overload
297300
def __truediv__(self, other: TimedeltaSeries) -> Series[float]: ...
298301
@overload
@@ -308,9 +311,11 @@ class Timedelta(timedelta):
308311
@overload
309312
def __eq__(self, other: TimedeltaSeries | Series[pd.Timedelta]) -> Series[bool]: ... # type: ignore[overload-overlap]
310313
@overload
314+
def __eq__(self, other: TimedeltaIndex) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
315+
@overload
311316
def __eq__( # type: ignore[overload-overlap]
312-
self, other: TimedeltaIndex | npt.NDArray[np.timedelta64]
313-
) -> npt.NDArray[np.bool_]: ...
317+
self, other: np_ndarray[ShapeT, np.timedelta64]
318+
) -> np_ndarray[ShapeT, np.bool_]: ...
314319
@overload
315320
def __eq__(self, other: object) -> Literal[False]: ...
316321
# Override due to more types supported than dt.timedelta
@@ -319,9 +324,11 @@ class Timedelta(timedelta):
319324
@overload
320325
def __ne__(self, other: TimedeltaSeries | Series[pd.Timedelta]) -> Series[bool]: ... # type: ignore[overload-overlap]
321326
@overload
327+
def __ne__(self, other: TimedeltaIndex) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
328+
@overload
322329
def __ne__( # type: ignore[overload-overlap]
323-
self, other: TimedeltaIndex | npt.NDArray[np.timedelta64]
324-
) -> npt.NDArray[np.bool_]: ...
330+
self, other: np_ndarray[ShapeT, np.timedelta64]
331+
) -> np_ndarray[ShapeT, np.bool_]: ...
325332
@overload
326333
def __ne__(self, other: object) -> Literal[True]: ...
327334
# Override due to more types supported than dt.timedelta
@@ -335,8 +342,8 @@ class Timedelta(timedelta):
335342
def __mod__(self, other: Index[int] | Index[float]) -> TimedeltaIndex: ...
336343
@overload
337344
def __mod__(
338-
self, other: npt.NDArray[np.integer] | npt.NDArray[np.floating]
339-
) -> npt.NDArray[np.timedelta64]: ...
345+
self, other: np_ndarray[ShapeT, np.integer] | np_ndarray[ShapeT, np.floating]
346+
) -> np_ndarray[ShapeT, np.timedelta64]: ...
340347
@overload
341348
def __mod__(
342349
self, other: Series[int] | Series[float] | TimedeltaSeries
@@ -348,36 +355,44 @@ class Timedelta(timedelta):
348355
@overload # type: ignore[override]
349356
def __le__(self, other: timedelta | Timedelta | np.timedelta64) -> bool: ... # type: ignore[misc]
350357
@overload
358+
def __le__(self, other: TimedeltaIndex) -> np_1darray[np.bool]: ...
359+
@overload
351360
def __le__(
352-
self, other: TimedeltaIndex | npt.NDArray[np.timedelta64]
353-
) -> npt.NDArray[np.bool_]: ...
361+
self, other: np_ndarray[ShapeT, np.timedelta64]
362+
) -> np_ndarray[ShapeT, np.bool_]: ...
354363
@overload
355364
def __le__(self, other: TimedeltaSeries | Series[pd.Timedelta]) -> Series[bool]: ...
356365
# Override due to more types supported than dt.timedelta
357366
@overload # type: ignore[override]
358367
def __lt__(self, other: timedelta | Timedelta | np.timedelta64) -> bool: ... # type: ignore[misc]
359368
@overload
369+
def __lt__(self, other: TimedeltaIndex) -> np_1darray[np.bool]: ...
370+
@overload
360371
def __lt__(
361-
self, other: TimedeltaIndex | npt.NDArray[np.timedelta64]
362-
) -> npt.NDArray[np.bool_]: ...
372+
self, other: np_ndarray[ShapeT, np.timedelta64]
373+
) -> np_ndarray[ShapeT, np.bool_]: ...
363374
@overload
364375
def __lt__(self, other: TimedeltaSeries | Series[pd.Timedelta]) -> Series[bool]: ...
365376
# Override due to more types supported than dt.timedelta
366377
@overload # type: ignore[override]
367378
def __ge__(self, other: timedelta | Timedelta | np.timedelta64) -> bool: ... # type: ignore[misc]
368379
@overload
380+
def __ge__(self, other: TimedeltaIndex) -> np_1darray[np.bool]: ...
381+
@overload
369382
def __ge__(
370-
self, other: TimedeltaIndex | npt.NDArray[np.timedelta64]
371-
) -> npt.NDArray[np.bool_]: ...
383+
self, other: np_ndarray[ShapeT, np.timedelta64]
384+
) -> np_ndarray[ShapeT, np.bool_]: ...
372385
@overload
373386
def __ge__(self, other: TimedeltaSeries | Series[pd.Timedelta]) -> Series[bool]: ...
374387
# Override due to more types supported than dt.timedelta
375388
@overload # type: ignore[override]
376389
def __gt__(self, other: timedelta | Timedelta | np.timedelta64) -> bool: ... # type: ignore[misc]
377390
@overload
391+
def __gt__(self, other: TimedeltaIndex) -> np_1darray[np.bool]: ...
392+
@overload
378393
def __gt__(
379-
self, other: TimedeltaIndex | npt.NDArray[np.timedelta64]
380-
) -> npt.NDArray[np.bool_]: ...
394+
self, other: np_ndarray[ShapeT, np.timedelta64]
395+
) -> np_ndarray[ShapeT, np.bool_]: ...
381396
@overload
382397
def __gt__(self, other: TimedeltaSeries | Series[pd.Timedelta]) -> Series[bool]: ...
383398
def __hash__(self) -> int: ...

0 commit comments

Comments
 (0)