Skip to content

Commit 252c7e9

Browse files
committed
Period also supports arrays
1 parent 519be94 commit 252c7e9

File tree

2 files changed

+140
-109
lines changed

2 files changed

+140
-109
lines changed

pandas-stubs/_libs/tslibs/period.pyi

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ from typing_extensions import TypeAlias
2222
from pandas._libs.tslibs import NaTType
2323
from pandas._libs.tslibs.offsets import BaseOffset
2424
from pandas._libs.tslibs.timestamps import Timestamp
25-
from pandas._typing import npt
25+
from pandas._typing import (
26+
ShapeT,
27+
np_1darray,
28+
np_ndarray,
29+
)
2630

2731
class IncompatibleFrequency(ValueError): ...
2832

@@ -98,44 +102,56 @@ class Period(PeriodMixin):
98102
@overload
99103
def __eq__(self, other: Period) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
100104
@overload
101-
def __eq__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ... # type: ignore[overload-overlap]
105+
def __eq__(self, other: PeriodIndex) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
102106
@overload
103107
def __eq__(self, other: PeriodSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
104108
@overload
109+
def __eq__(self, other: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ... # type: ignore[overload-overlap]
110+
@overload
105111
def __eq__(self, other: object) -> Literal[False]: ...
106112
@overload
107113
def __ge__(self, other: Period) -> bool: ...
108114
@overload
109-
def __ge__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
115+
def __ge__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
110116
@overload
111117
def __ge__(self, other: PeriodSeries) -> Series[bool]: ...
112118
@overload
119+
def __ge__(self, other: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ...
120+
@overload
113121
def __gt__(self, other: Period) -> bool: ...
114122
@overload
115-
def __gt__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
123+
def __gt__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
116124
@overload
117125
def __gt__(self, other: PeriodSeries) -> Series[bool]: ...
118126
@overload
127+
def __gt__(self, other: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ...
128+
@overload
119129
def __le__(self, other: Period) -> bool: ...
120130
@overload
121-
def __le__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
131+
def __le__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
122132
@overload
123133
def __le__(self, other: PeriodSeries) -> Series[bool]: ...
124134
@overload
135+
def __le__(self, other: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ...
136+
@overload
125137
def __lt__(self, other: Period) -> bool: ...
126138
@overload
127-
def __lt__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ...
139+
def __lt__(self, other: PeriodIndex) -> np_1darray[np.bool]: ...
128140
@overload
129141
def __lt__(self, other: PeriodSeries) -> Series[bool]: ...
142+
@overload
143+
def __lt__(self, other: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ...
130144
# ignore[misc] here because we know all other comparisons
131145
# are False, so we use Literal[False]
132146
@overload
133147
def __ne__(self, other: Period) -> bool: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
134148
@overload
135-
def __ne__(self, other: PeriodIndex) -> npt.NDArray[np.bool_]: ... # type: ignore[overload-overlap]
149+
def __ne__(self, other: PeriodIndex) -> np_1darray[np.bool]: ... # type: ignore[overload-overlap]
136150
@overload
137151
def __ne__(self, other: PeriodSeries) -> Series[bool]: ... # type: ignore[overload-overlap]
138152
@overload
153+
def __ne__(self, other: np_ndarray[ShapeT]) -> np_ndarray[ShapeT, np.bool]: ... # type: ignore[overload-overlap]
154+
@overload
139155
def __ne__(self, other: object) -> Literal[True]: ...
140156
# Ignored due to indecipherable error from mypy:
141157
# Forward operator "__add__" is not callable [misc]

tests/test_scalars.py

Lines changed: 117 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,128 +1824,143 @@ def test_period_add_subtract() -> None:
18241824
check(assert_type(as_period_index - p, pd.Index), pd.Index)
18251825

18261826

1827-
def test_period_cmp() -> None:
1827+
def test_period_cmp_scalar() -> None:
18281828
p = pd.Period("2012-1-1", freq="D")
1829+
p2 = pd.Period("2012-1-2", freq="D")
18291830

1830-
c_period = pd.Period("2012-1-1", freq="D")
1831-
c_period_index = pd.period_range("2012-1-1", periods=10, freq="D")
1832-
c_period_series = pd.Series(c_period_index)
1831+
# >, <=
1832+
gt1 = check(assert_type(p > p2, bool), bool)
1833+
le1 = check(assert_type(p <= p2, bool), bool)
1834+
assert gt1 != le1
18331835

1834-
eq = check(assert_type(p == c_period, bool), bool)
1835-
ne = check(assert_type(p != c_period, bool), bool)
1836-
assert eq != ne
1836+
# <, >=
1837+
lt1 = check(assert_type(p < p2, bool), bool)
1838+
ge1 = check(assert_type(p >= p2, bool), bool)
1839+
assert lt1 != ge1
18371840

1838-
eq_a = check(
1839-
assert_type(p == c_period_index, np_ndarray_bool), np.ndarray, np.bool_
1840-
)
1841-
ne_q = check(
1842-
assert_type(p != c_period_index, np_ndarray_bool), np.ndarray, np.bool_
1843-
)
1844-
assert (eq_a != ne_q).all()
1841+
# ==, !=
1842+
eq1 = check(assert_type(p == p2, bool), bool)
1843+
ne1 = check(assert_type(p != p2, bool), bool)
1844+
assert eq1 != ne1
1845+
eq2 = check(assert_type(p == 1, Literal[False]), bool)
1846+
ne2 = check(assert_type(p != 1, Literal[True]), bool)
1847+
assert eq2 != ne2
18451848

1846-
eq_s = check(
1847-
assert_type(p == c_period_series, "pd.Series[bool]"), pd.Series, np.bool_
1848-
)
1849-
ne_s = check(
1850-
assert_type(p != c_period_series, "pd.Series[bool]"), pd.Series, np.bool_
1851-
)
1852-
assert (eq_s != ne_s).all()
18531849

1854-
eq = check(assert_type(c_period == p, bool), bool)
1855-
ne = check(assert_type(c_period != p, bool), bool)
1856-
assert eq != ne
1850+
def test_period_cmp_series() -> None:
1851+
p = pd.Period("2012-1-1", freq="D")
1852+
p_ser = pd.Series(pd.period_range("2012-1-1", periods=10, freq="D"))
18571853

1858-
eq_a = check(
1859-
assert_type(c_period_index == p, np_1darray[np.bool]), np_1darray[np.bool]
1860-
)
1861-
ne_a = check(
1862-
assert_type(c_period_index != p, np_1darray[np.bool]), np_1darray[np.bool]
1863-
)
1864-
assert (eq_a != ne_a).all()
1854+
# >, <=
1855+
gt1 = check(assert_type(p > p_ser, "pd.Series[bool]"), pd.Series, np.bool)
1856+
le1 = check(assert_type(p <= p_ser, "pd.Series[bool]"), pd.Series, np.bool)
1857+
assert (gt1 != le1).all()
1858+
gt2 = check(assert_type(p_ser > p, "pd.Series[bool]"), pd.Series, np.bool)
1859+
le2 = check(assert_type(p_ser <= p, "pd.Series[bool]"), pd.Series, np.bool)
1860+
assert (gt2 != le2).all()
18651861

1866-
eq_s = check(
1867-
assert_type(c_period_series == p, "pd.Series[bool]"), pd.Series, np.bool_
1868-
)
1869-
ne_s = check(
1870-
assert_type(c_period_series != p, "pd.Series[bool]"), pd.Series, np.bool_
1871-
)
1872-
assert (eq_s != ne_s).all()
1862+
# <, >=
1863+
lt1 = check(assert_type(p < p_ser, "pd.Series[bool]"), pd.Series, np.bool)
1864+
ge1 = check(assert_type(p >= p_ser, "pd.Series[bool]"), pd.Series, np.bool)
1865+
assert (lt1 != ge1).all()
1866+
lt2 = check(assert_type(p_ser < p, "pd.Series[bool]"), pd.Series, np.bool)
1867+
ge2 = check(assert_type(p_ser >= p, "pd.Series[bool]"), pd.Series, np.bool)
1868+
assert (lt2 != ge2).all()
18731869

1874-
gt = check(assert_type(p > c_period, bool), bool)
1875-
le = check(assert_type(p <= c_period, bool), bool)
1876-
assert gt != le
1870+
# ==, !=
1871+
eq1 = check(assert_type(p == p_ser, "pd.Series[bool]"), pd.Series, np.bool)
1872+
ne1 = check(assert_type(p != p_ser, "pd.Series[bool]"), pd.Series, np.bool)
1873+
assert (eq1 != ne1).all()
18771874

1878-
gt_a = check(assert_type(p > c_period_index, np_ndarray_bool), np.ndarray, np.bool_)
1879-
le_a = check(
1880-
assert_type(p <= c_period_index, np_ndarray_bool), np.ndarray, np.bool_
1881-
)
1882-
assert (gt_a != le_a).all()
1875+
# ==, != (p on the rhs, use == and != of lhs)
1876+
eq_rhs1 = check(assert_type(p_ser == p, "pd.Series[bool]"), pd.Series, np.bool)
1877+
ne_rhs1 = check(assert_type(p_ser != p, "pd.Series[bool]"), pd.Series, np.bool)
1878+
assert (eq_rhs1 != ne_rhs1).all()
18831879

1884-
gt_s = check(
1885-
assert_type(p > c_period_series, "pd.Series[bool]"), pd.Series, np.bool_
1886-
)
1887-
le_s = check(
1888-
assert_type(p <= c_period_series, "pd.Series[bool]"), pd.Series, np.bool_
1889-
)
1890-
assert (gt_s != le_s).all()
18911880

1892-
gt = check(assert_type(c_period > p, bool), bool)
1893-
le = check(assert_type(c_period <= p, bool), bool)
1894-
assert gt != le
1881+
def test_period_cmp_index() -> None:
1882+
p = pd.Period("2012-1-1", freq="D")
1883+
p_idx = pd.period_range("2012-1-1", periods=10, freq="D")
18951884

1896-
gt_a = check(
1897-
assert_type(c_period_index > p, np_1darray[np.bool]), np_1darray[np.bool]
1898-
)
1899-
le_a = check(
1900-
assert_type(c_period_index <= p, np_1darray[np.bool]), np_1darray[np.bool]
1901-
)
1902-
assert (gt_a != le_a).all()
1885+
# >, <=
1886+
gt1 = check(assert_type(p > p_idx, np_1darray[np.bool]), np_1darray[np.bool])
1887+
le1 = check(assert_type(p <= p_idx, np_1darray[np.bool]), np_1darray[np.bool])
1888+
assert (gt1 != le1).all()
1889+
gt2 = check(assert_type(p_idx > p, np_1darray[np.bool]), np_1darray[np.bool])
1890+
le2 = check(assert_type(p_idx <= p, np_1darray[np.bool]), np_1darray[np.bool])
1891+
assert (gt2 != le2).all()
19031892

1904-
gt_s = check(
1905-
assert_type(c_period_series > p, "pd.Series[bool]"), pd.Series, np.bool_
1906-
)
1907-
le_s = check(
1908-
assert_type(c_period_series <= p, "pd.Series[bool]"), pd.Series, np.bool_
1909-
)
1910-
assert (gt_s != le_s).all()
1893+
# <, >=
1894+
lt1 = check(assert_type(p < p_idx, np_1darray[np.bool]), np_1darray[np.bool])
1895+
ge1 = check(assert_type(p >= p_idx, np_1darray[np.bool]), np_1darray[np.bool])
1896+
assert (lt1 != ge1).all()
1897+
lt2 = check(assert_type(p_idx < p, np_1darray[np.bool]), np_1darray[np.bool])
1898+
ge2 = check(assert_type(p_idx >= p, np_1darray[np.bool]), np_1darray[np.bool])
1899+
assert (lt2 != ge2).all()
19111900

1912-
lt = check(assert_type(p < c_period, bool), bool)
1913-
ge = check(assert_type(p >= c_period, bool), bool)
1914-
assert lt != ge
1901+
# ==, !=
1902+
eq1 = check(assert_type(p == p_idx, np_1darray[np.bool]), np_1darray[np.bool])
1903+
ne1 = check(assert_type(p != p_idx, np_1darray[np.bool]), np_1darray[np.bool])
1904+
assert (eq1 != ne1).all()
19151905

1916-
lt_a = check(assert_type(p < c_period_index, np_ndarray_bool), np.ndarray, np.bool_)
1917-
ge_a = check(
1918-
assert_type(p >= c_period_index, np_ndarray_bool), np.ndarray, np.bool_
1919-
)
1920-
assert (lt_a != ge_a).all()
1906+
# ==, != (p on the rhs, use == and != of lhs)
1907+
eq_rhs1 = check(assert_type(p_idx == p, np_1darray[np.bool]), np_1darray[np.bool])
1908+
ne_rhs1 = check(assert_type(p_idx != p, np_1darray[np.bool]), np_1darray[np.bool])
1909+
assert (eq_rhs1 != ne_rhs1).all()
19211910

1922-
lt_s = check(
1923-
assert_type(p < c_period_series, "pd.Series[bool]"), pd.Series, np.bool_
1924-
)
1925-
ge_s = check(
1926-
assert_type(p >= c_period_series, "pd.Series[bool]"), pd.Series, np.bool_
1927-
)
1928-
assert (lt_s != ge_s).all()
19291911

1930-
lt = check(assert_type(c_period < p, bool), bool)
1931-
ge = check(assert_type(c_period >= p, bool), bool)
1932-
assert lt != ge
1912+
def test_period_cmp_array() -> None:
1913+
p = pd.Period("2012-1-1", freq="D")
1914+
arr_nd: npt.NDArray[np.object_] = pd.period_range(
1915+
"2012-1-1", periods=4, freq="D"
1916+
).to_numpy()
1917+
arr_2d = arr_nd.reshape(2, 2)
19331918

1934-
lt_a = check(
1935-
assert_type(c_period_index < p, np_1darray[np.bool]), np_1darray[np.bool]
1936-
)
1937-
ge_a = check(
1938-
assert_type(c_period_index >= p, np_1darray[np.bool]), np_1darray[np.bool]
1939-
)
1940-
assert (lt_a != ge_a).all()
1919+
# >, <=
1920+
gt_nd1 = check(assert_type(p > arr_nd, np_ndarray_bool), np.ndarray, np.bool)
1921+
le_nd1 = check(assert_type(p <= arr_nd, np_ndarray_bool), np.ndarray, np.bool)
1922+
assert (gt_nd1 != le_nd1).all()
1923+
gt_2d1 = check(assert_type(p > arr_2d, np_2darray[np.bool]), np_2darray[np.bool])
1924+
le_2d1 = check(assert_type(p <= arr_2d, np_2darray[np.bool]), np_2darray[np.bool])
1925+
assert (gt_2d1 != le_2d1).all()
1926+
# p on the rhs, type depends on np.ndarray > and <= methods
1927+
gt_nd2 = check(assert_type(arr_nd > p, np_ndarray_bool), np.ndarray, np.bool)
1928+
le_nd2 = check(assert_type(arr_nd <= p, np_ndarray_bool), np.ndarray, np.bool)
1929+
assert (gt_nd2 != le_nd2).all()
1930+
gt_2d2 = check(assert_type(arr_2d > p, np_ndarray_bool), np_2darray[np.bool])
1931+
le_2d2 = check(assert_type(arr_2d <= p, np_ndarray_bool), np_2darray[np.bool])
1932+
assert (gt_2d2 != le_2d2).all()
19411933

1942-
lt_s = check(
1943-
assert_type(c_period_series < p, "pd.Series[bool]"), pd.Series, np.bool_
1944-
)
1945-
ge_s = check(
1946-
assert_type(c_period_series >= p, "pd.Series[bool]"), pd.Series, np.bool_
1947-
)
1948-
assert (lt_s != ge_s).all()
1934+
# <, >=
1935+
lt_nd1 = check(assert_type(p < arr_nd, np_ndarray_bool), np.ndarray, np.bool)
1936+
ge_nd1 = check(assert_type(p >= arr_nd, np_ndarray_bool), np.ndarray, np.bool)
1937+
assert (lt_nd1 != ge_nd1).all()
1938+
lt_2d1 = check(assert_type(p < arr_2d, np_2darray[np.bool]), np_2darray[np.bool])
1939+
ge_2d1 = check(assert_type(p >= arr_2d, np_2darray[np.bool]), np_2darray[np.bool])
1940+
assert (lt_2d1 != ge_2d1).all()
1941+
# p on the rhs, type depends on np.ndarray < and >= methods
1942+
lt_nd2 = check(assert_type(arr_nd < p, np_ndarray_bool), np.ndarray, np.bool)
1943+
ge_nd2 = check(assert_type(arr_nd >= p, np_ndarray_bool), np.ndarray, np.bool)
1944+
assert (lt_nd2 != ge_nd2).all()
1945+
lt_2d2 = check(assert_type(arr_2d < p, np_ndarray_bool), np_2darray[np.bool])
1946+
ge_2d2 = check(assert_type(arr_2d >= p, np_ndarray_bool), np_2darray[np.bool])
1947+
assert (lt_2d2 != ge_2d2).all()
1948+
1949+
# ==, !=
1950+
eq_nd1 = check(assert_type(p == arr_nd, np_ndarray_bool), np.ndarray, np.bool)
1951+
ne_nd1 = check(assert_type(p != arr_nd, np_ndarray_bool), np.ndarray, np.bool)
1952+
assert (eq_nd1 != ne_nd1).all()
1953+
eq_2d1 = check(assert_type(p == arr_2d, np_2darray[np.bool]), np_2darray[np.bool])
1954+
ne_2d1 = check(assert_type(p != arr_2d, np_2darray[np.bool]), np_2darray[np.bool])
1955+
assert (eq_2d1 != ne_2d1).all()
1956+
1957+
# ==, != (td on the rhs, use == and != of lhs)
1958+
eq_rhs_nd1 = check(assert_type(arr_nd == p, Any), np_ndarray_bool)
1959+
ne_rhs_nd1 = check(assert_type(arr_nd != p, Any), np_ndarray_bool)
1960+
assert (eq_rhs_nd1 != ne_rhs_nd1).all()
1961+
eq_rhs_2d1 = check(assert_type(arr_2d == p, Any), np_2darray[np.bool])
1962+
ne_rhs_2d1 = check(assert_type(arr_2d != p, Any), np_2darray[np.bool])
1963+
assert (eq_rhs_2d1 != ne_rhs_2d1).all()
19491964

19501965

19511966
def test_period_methods() -> None:

0 commit comments

Comments
 (0)