|
19 | 19 | from pandas.core.arrays.timedeltas import TimedeltaArray |
20 | 20 | from pandas.core.indexes.base import Index |
21 | 21 | from pandas.core.indexes.category import CategoricalIndex |
| 22 | +from pandas.core.indexes.datetimes import DatetimeIndex |
22 | 23 | from typing_extensions import ( |
23 | 24 | Never, |
24 | 25 | assert_type, |
@@ -1541,3 +1542,39 @@ def test_multiindex_swaplevel() -> None: |
1541 | 1542 | """Test that MultiIndex.swaplevel returns MultiIndex""" |
1542 | 1543 | mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"]) |
1543 | 1544 | check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex) |
| 1545 | + |
| 1546 | + |
| 1547 | +def test_index_where() -> None: |
| 1548 | + """Test Index.where with multiple types of other GH1419.""" |
| 1549 | + idx = pd.Index(range(48)) |
| 1550 | + mask = np.ones(48, dtype=bool) |
| 1551 | + val_idx = idx.where(mask, idx) |
| 1552 | + check(assert_type(val_idx, "pd.Index[int]"), pd.Index, int) |
| 1553 | + |
| 1554 | + val_sr = idx.where(mask, (idx).to_series()) |
| 1555 | + check(assert_type(val_sr, "pd.Index[int]"), pd.Index, int) |
| 1556 | + |
| 1557 | + |
| 1558 | +def test_datetimeindex_where() -> None: |
| 1559 | + """Test DatetimeIndex.where with multiple types of other GH1419.""" |
| 1560 | + datetime_index = pd.date_range(start="2025-01-01", freq="h", periods=48) |
| 1561 | + mask = np.ones(48, dtype=bool) |
| 1562 | + val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1)) |
| 1563 | + check(assert_type(val_idx, DatetimeIndex), DatetimeIndex) |
| 1564 | + |
| 1565 | + val_sr = datetime_index.where( |
| 1566 | + mask, (datetime_index - pd.Timedelta(days=1)).to_series() |
| 1567 | + ) |
| 1568 | + check(assert_type(val_sr, DatetimeIndex), DatetimeIndex) |
| 1569 | + |
| 1570 | + val_idx_scalar = datetime_index.where(mask, pd.Index([0, 1])) |
| 1571 | + check(assert_type(val_idx_scalar, pd.Index), pd.Index) |
| 1572 | + |
| 1573 | + val_sr_scalar = datetime_index.where(mask, pd.Series([0, 1])) |
| 1574 | + check(assert_type(val_sr_scalar, pd.Index), pd.Index) |
| 1575 | + |
| 1576 | + val_scalar = datetime_index.where(mask, 1) |
| 1577 | + check(assert_type(val_scalar, pd.Index), pd.Index) |
| 1578 | + |
| 1579 | + val_range = pd.RangeIndex(2).where(pd.Series([True, False]), 3) |
| 1580 | + check(assert_type(val_range, pd.Index), pd.RangeIndex) |
0 commit comments