Skip to content

Commit 354406c

Browse files
authored
Support pd.cut on a series of timestamps (#507)
* Support pd.cut on a series of timestamps
* remove Series[Timestamp]
1 parent 49455ce commit 354406c

File tree

3 files changed

+84
-2
lines changed

3 files changed

+84
-2
lines changed

pandas-stubs/_typing.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ from pandas._libs.tslibs import (
3636
Timestamp,
3737
)
3838

39-
from pandas.core.dtypes.dtypes import ExtensionDtype
39+
from pandas.core.dtypes.dtypes import (
40+
CategoricalDtype,
41+
ExtensionDtype,
42+
)
4043

4144
from pandas.io.formats.format import EngFormatter
4245

@@ -210,6 +213,7 @@ S1 = TypeVar(
210213
Interval[float],
211214
Interval[Timestamp],
212215
Interval[Timedelta],
216+
CategoricalDtype,
213217
)
214218
T1 = TypeVar(
215219
"T1", str, int, np.int64, np.uint64, np.float64, float, np.dtype[np.generic]

pandas-stubs/core/reshape/tile.pyi

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@ from typing import (
77
import numpy as np
88
from pandas import (
99
Categorical,
10+
CategoricalDtype,
11+
DatetimeIndex,
1012
Float64Index,
1113
Index,
1214
Int64Index,
15+
Interval,
1316
IntervalIndex,
1417
Series,
18+
Timestamp,
1519
)
20+
from pandas.core.series import TimestampSeries
1621

1722
from pandas._typing import (
1823
Label,
@@ -46,6 +51,36 @@ def cut(
4651
ordered: bool = ...,
4752
) -> tuple[npt.NDArray[np.intp], IntervalIndex]: ...
4853
@overload
54+
def cut( # type: ignore[misc]
55+
x: TimestampSeries,
56+
bins: int
57+
| TimestampSeries
58+
| DatetimeIndex
59+
| Sequence[Timestamp]
60+
| Sequence[np.datetime64],
61+
right: bool = ...,
62+
labels: Literal[False] | Sequence[Label] | None = ...,
63+
*,
64+
retbins: Literal[True],
65+
precision: int = ...,
66+
include_lowest: bool = ...,
67+
duplicates: Literal["raise", "drop"] = ...,
68+
ordered: bool = ...,
69+
) -> tuple[Series, DatetimeIndex]: ...
70+
@overload
71+
def cut(
72+
x: TimestampSeries,
73+
bins: IntervalIndex[Interval[Timestamp]],
74+
right: bool = ...,
75+
labels: Sequence[Label] | None = ...,
76+
*,
77+
retbins: Literal[True],
78+
precision: int = ...,
79+
include_lowest: bool = ...,
80+
duplicates: Literal["raise", "drop"] = ...,
81+
ordered: bool = ...,
82+
) -> tuple[Series, DatetimeIndex]: ...
83+
@overload
4984
def cut(
5085
x: Series,
5186
bins: int | Series | Int64Index | Float64Index | Sequence[int] | Sequence[float],
@@ -61,7 +96,7 @@ def cut(
6196
@overload
6297
def cut(
6398
x: Series,
64-
bins: IntervalIndex,
99+
bins: IntervalIndex[Interval[int]] | IntervalIndex[Interval[float]],
65100
right: bool = ...,
66101
labels: Sequence[Label] | None = ...,
67102
*,
@@ -117,6 +152,23 @@ def cut(
117152
ordered: bool = ...,
118153
) -> npt.NDArray[np.intp]: ...
119154
@overload
155+
def cut(
156+
x: TimestampSeries,
157+
bins: int
158+
| TimestampSeries
159+
| DatetimeIndex
160+
| Sequence[Timestamp]
161+
| Sequence[np.datetime64]
162+
| IntervalIndex[Interval[Timestamp]],
163+
right: bool = ...,
164+
labels: Literal[False] | Sequence[Label] | None = ...,
165+
retbins: Literal[False] = ...,
166+
precision: int = ...,
167+
include_lowest: bool = ...,
168+
duplicates: Literal["raise", "drop"] = ...,
169+
ordered: bool = ...,
170+
) -> Series[CategoricalDtype]: ...
171+
@overload
120172
def cut(
121173
x: Series,
122174
bins: int

tests/test_pandas.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,32 @@ def test_cut() -> None:
870870
check(assert_type(n0, pd.Categorical), pd.Categorical)
871871
check(assert_type(n1, pd.IntervalIndex), pd.IntervalIndex)
872872

873+
s1 = pd.Series(data=pd.date_range("1/1/2020", periods=300))
874+
check(
875+
assert_type(
876+
pd.cut(s1, bins=[np.datetime64("2020-01-03"), np.datetime64("2020-09-01")]),
877+
"pd.Series[pd.CategoricalDtype]",
878+
),
879+
pd.Series,
880+
)
881+
check(
882+
assert_type(
883+
pd.cut(s1, bins=10),
884+
"pd.Series[pd.CategoricalDtype]",
885+
),
886+
pd.Series,
887+
pd.Interval,
888+
)
889+
s0r, s1r = pd.cut(s1, bins=10, retbins=True)
890+
check(assert_type(s0r, pd.Series), pd.Series, pd.Interval)
891+
check(assert_type(s1r, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp)
892+
s0rlf, s1rlf = pd.cut(s1, bins=10, labels=False, retbins=True)
893+
check(assert_type(s0rlf, pd.Series), pd.Series, int)
894+
check(assert_type(s1rlf, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp)
895+
s0rls, s1rls = pd.cut(s1, bins=4, labels=["1", "2", "3", "4"], retbins=True)
896+
check(assert_type(s0rls, pd.Series), pd.Series, str)
897+
check(assert_type(s1rls, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp)
898+
873899

874900
def test_qcut() -> None:
875901
val_list = [random.random() for _ in range(20)]

0 commit comments

Comments
 (0)