Skip to content

Commit fe7e8b4

Browse files
GH1419 Allow Series and Index for other in Index.where(..., other) (pandas-dev#1420)
* GH1419 Allow Series and Index for other in Index.where(..., other) * GH1419 PR Feedback * GH1419 Soften typing in RangeIndex.where
1 parent 6957ad1 commit fe7e8b4

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

pandas-stubs/core/indexes/base.pyi

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ from pandas import (
4040
Series,
4141
TimedeltaIndex,
4242
)
43+
from pandas.core.arrays.boolean import BooleanArray
4344
from pandas.core.base import (
4445
ElementOpsMixin,
4546
IndexOpsMixin,
@@ -457,7 +458,18 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
457458
@property
458459
def values(self) -> np_1darray: ...
459460
def memory_usage(self, deep: bool = False): ...
460-
def where(self, cond, other: Scalar | ArrayLike | None = None): ...
461+
@overload
462+
def where(
463+
self,
464+
cond: Sequence[bool] | np_ndarray_bool | BooleanArray | IndexOpsMixin[bool],
465+
other: S1 | Series[S1] | Self,
466+
) -> Self: ...
467+
@overload
468+
def where(
469+
self,
470+
cond: Sequence[bool] | np_ndarray_bool | BooleanArray | IndexOpsMixin[bool],
471+
other: Scalar | AnyArrayLike | None = None,
472+
) -> Index: ...
461473
def __contains__(self, key) -> bool: ...
462474
@final
463475
def __setitem__(self, key, value) -> None: ...

pandas-stubs/core/indexes/range.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@ from typing import (
99
)
1010

1111
import numpy as np
12+
from pandas.core.arrays.boolean import BooleanArray
13+
from pandas.core.base import IndexOpsMixin
1214
from pandas.core.indexes.base import (
1315
Index,
1416
_IndexSubclassBase,
1517
)
1618
from typing_extensions import Self
1719

1820
from pandas._typing import (
21+
AnyArrayLike,
1922
Dtype,
2023
HashableT,
2124
MaskType,
25+
Scalar,
2226
np_1darray,
2327
np_ndarray_anyint,
28+
np_ndarray_bool,
2429
)
2530

2631
class RangeIndex(_IndexSubclassBase[int, np.int64]):
@@ -82,3 +87,8 @@ class RangeIndex(_IndexSubclassBase[int, np.int64]):
8287
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride]
8388
self, idx: int
8489
) -> int: ...
90+
def where( # type: ignore[override]
91+
self,
92+
cond: Sequence[bool] | np_ndarray_bool | BooleanArray | IndexOpsMixin[bool],
93+
other: Scalar | AnyArrayLike | None = None,
94+
) -> Index: ...

tests/indexes/test_indexes.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pandas.core.arrays.timedeltas import TimedeltaArray
2020
from pandas.core.indexes.base import Index
2121
from pandas.core.indexes.category import CategoricalIndex
22+
from pandas.core.indexes.datetimes import DatetimeIndex
2223
from typing_extensions import (
2324
Never,
2425
assert_type,
@@ -1541,3 +1542,39 @@ def test_multiindex_swaplevel() -> None:
15411542
"""Test that MultiIndex.swaplevel returns MultiIndex"""
15421543
mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"])
15431544
check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex)
1545+
1546+
1547+
def test_index_where() -> None:
1548+
"""Test Index.where with multiple types of other GH1419."""
1549+
idx = pd.Index(range(48))
1550+
mask = np.ones(48, dtype=bool)
1551+
val_idx = idx.where(mask, idx)
1552+
check(assert_type(val_idx, "pd.Index[int]"), pd.Index, int)
1553+
1554+
val_sr = idx.where(mask, (idx).to_series())
1555+
check(assert_type(val_sr, "pd.Index[int]"), pd.Index, int)
1556+
1557+
1558+
def test_datetimeindex_where() -> None:
1559+
"""Test DatetimeIndex.where with multiple types of other GH1419."""
1560+
datetime_index = pd.date_range(start="2025-01-01", freq="h", periods=48)
1561+
mask = np.ones(48, dtype=bool)
1562+
val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1))
1563+
check(assert_type(val_idx, DatetimeIndex), DatetimeIndex)
1564+
1565+
val_sr = datetime_index.where(
1566+
mask, (datetime_index - pd.Timedelta(days=1)).to_series()
1567+
)
1568+
check(assert_type(val_sr, DatetimeIndex), DatetimeIndex)
1569+
1570+
val_idx_scalar = datetime_index.where(mask, pd.Index([0, 1]))
1571+
check(assert_type(val_idx_scalar, pd.Index), pd.Index)
1572+
1573+
val_sr_scalar = datetime_index.where(mask, pd.Series([0, 1]))
1574+
check(assert_type(val_sr_scalar, pd.Index), pd.Index)
1575+
1576+
val_scalar = datetime_index.where(mask, 1)
1577+
check(assert_type(val_scalar, pd.Index), pd.Index)
1578+
1579+
val_range = pd.RangeIndex(2).where(pd.Series([True, False]), 3)
1580+
check(assert_type(val_range, pd.Index), pd.RangeIndex)

0 commit comments

Comments
 (0)