Skip to content

Commit ac4d313

Browse files
committed
Fix errors on py310
1 parent 7bdfb49 commit ac4d313

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

pandas-stubs/core/indexes/interval.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class IntervalIndex(ExtensionIndex[IntervalT], IntervalMixin):
222222
def memory_usage(self, deep: bool = False) -> int: ...
223223
@property
224224
def is_overlapping(self) -> bool: ...
225-
def get_loc(self, key: Label) -> int | slice | npt.NDArray[np.bool_]: ...
225+
def get_loc(self, key: Label) -> int | slice | np_1darray[np.bool]: ...
226226
@final
227227
def get_indexer(
228228
self,

tests/test_indexes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,22 @@ def test_get_loc() -> None:
13131313
np_1darray[np.bool],
13141314
)
13151315

1316+
i1, i2, i3 = pd.Interval(0, 1), pd.Interval(1, 2), pd.Interval(0, 2)
1317+
unique_interval_index = pd.IntervalIndex([i1, i2])
1318+
check(
1319+
assert_type(
1320+
unique_interval_index.get_loc(i1), Union[int, slice, np_1darray[np.bool]]
1321+
),
1322+
np.int64,
1323+
)
1324+
overlap_interval_index = pd.IntervalIndex([i1, i2, i3])
1325+
check(
1326+
assert_type(
1327+
overlap_interval_index.get_loc(1), Union[int, slice, np_1darray[np.bool]]
1328+
),
1329+
np_1darray[np.bool],
1330+
)
1331+
13161332

13171333
def test_value_counts() -> None:
13181334
nmi = pd.Index(list("abcb"))

tests/test_scalars.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,9 @@ def test_timestamp_cmp() -> None:
12391239
np_dt64_arr: npt.NDArray[np.datetime64] = np.array(
12401240
[1, 2, 3], dtype="datetime64[ns]"
12411241
)
1242-
np_dt64_arr2d = np.arange(6).astype(dtype=np.datetime64).reshape(3, 2)
1242+
np_dt64_arr2d: np.ndarray[tuple[int, int], np.dtype[np.datetime64]] = (
1243+
np.arange(6).astype(dtype=np.datetime64).reshape(3, 2)
1244+
)
12431245

12441246
c_timestamp = ts
12451247
c_np_dt64 = np.datetime64(1, "ns")

0 commit comments

Comments
 (0)