Skip to content

Commit d78d534

Browse files
authored
TYP: stronger typing for unit, as_unit (#63162)
1 parent 8eff6a8 commit d78d534

File tree

12 files changed

+52
-19
lines changed

12 files changed

+52
-19
lines changed

pandas/_libs/tslibs/nattype.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ from pandas._libs.tslibs.period import Period
1717
from pandas._typing import (
1818
Frequency,
1919
TimestampNonexistent,
20+
TimeUnit,
2021
)
2122

2223
NaT: NaTType
@@ -180,4 +181,4 @@ class NaTType:
180181
def __floordiv__(self, other: float, /) -> Self: ...
181182
# other
182183
def __hash__(self) -> int: ...
183-
def as_unit(self, unit: str, round_ok: bool = ...) -> NaTType: ...
184+
def as_unit(self, unit: TimeUnit, round_ok: bool = ...) -> NaTType: ...

pandas/_libs/tslibs/timedeltas.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ from pandas._libs.tslibs import (
1515
)
1616
from pandas._typing import (
1717
Frequency,
18+
TimeUnit,
1819
npt,
1920
)
2021

@@ -162,5 +163,5 @@ class Timedelta(timedelta):
162163
) -> np.timedelta64: ...
163164
def view(self, dtype: npt.DTypeLike) -> object: ...
164165
@property
165-
def unit(self) -> str: ...
166-
def as_unit(self, unit: str, round_ok: bool = ...) -> Timedelta: ...
166+
def unit(self) -> TimeUnit: ...
167+
def as_unit(self, unit: TimeUnit, round_ok: bool = ...) -> Timedelta: ...

pandas/_libs/tslibs/timestamps.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ from pandas._libs.tslibs import (
2323
Tick,
2424
Timedelta,
2525
)
26-
from pandas._typing import TimestampNonexistent
26+
from pandas._typing import (
27+
TimestampNonexistent,
28+
TimeUnit,
29+
)
2730

2831
_TimeZones: TypeAlias = str | _tzinfo | None | int
2932

@@ -235,5 +238,5 @@ class Timestamp(datetime):
235238
@property
236239
def daysinmonth(self) -> int: ...
237240
@property
238-
def unit(self) -> str: ...
239-
def as_unit(self, unit: str, round_ok: bool = ...) -> Timestamp: ...
241+
def unit(self) -> TimeUnit: ...
242+
def as_unit(self, unit: TimeUnit, round_ok: bool = ...) -> Timestamp: ...

pandas/core/arrays/_ranges.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@
2222
from pandas.core.construction import range_to_ndarray
2323

2424
if TYPE_CHECKING:
25-
from pandas._typing import npt
25+
from pandas._typing import (
26+
TimeUnit,
27+
npt,
28+
)
2629

2730

2831
def generate_regular_range(
2932
start: Timestamp | Timedelta | None,
3033
end: Timestamp | Timedelta | None,
3134
periods: int | None,
3235
freq: BaseOffset,
33-
unit: str = "ns",
36+
unit: TimeUnit = "ns",
3437
) -> npt.NDArray[np.intp]:
3538
"""
3639
Generate a range of dates or timestamps with the spans between dates
@@ -46,7 +49,7 @@ def generate_regular_range(
4649
Number of periods in produced date range.
4750
freq : Tick
4851
Describes space between dates in produced date range.
49-
unit : str, default "ns"
52+
unit : {'s', 'ms', 'us', 'ns'}, default "ns"
5053
The resolution the output is meant to represent.
5154
5255
Returns

pandas/core/arrays/datetimes.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,15 @@
107107

108108

109109
@overload
110-
def tz_to_dtype(tz: tzinfo, unit: str = ...) -> DatetimeTZDtype: ...
110+
def tz_to_dtype(tz: tzinfo, unit: TimeUnit = ...) -> DatetimeTZDtype: ...
111111

112112

113113
@overload
114-
def tz_to_dtype(tz: None, unit: str = ...) -> np.dtype[np.datetime64]: ...
114+
def tz_to_dtype(tz: None, unit: TimeUnit = ...) -> np.dtype[np.datetime64]: ...
115115

116116

117117
def tz_to_dtype(
118-
tz: tzinfo | None, unit: str = "ns"
118+
tz: tzinfo | None, unit: TimeUnit = "ns"
119119
) -> np.dtype[np.datetime64] | DatetimeTZDtype:
120120
"""
121121
Return a datetime64[ns] dtype appropriate for the given timezone.
@@ -393,6 +393,7 @@ def _from_sequence_not_strict(
393393
)
394394

395395
data_unit = np.datetime_data(subarr.dtype)[0]
396+
data_unit = cast("TimeUnit", data_unit)
396397
data_dtype = tz_to_dtype(tz, data_unit)
397398
result = cls._simple_new(subarr, freq=inferred_freq, dtype=data_dtype)
398399
if unit is not None and unit != result.unit:
@@ -2935,7 +2936,7 @@ def _generate_range(
29352936
periods: int | None,
29362937
offset: BaseOffset,
29372938
*,
2938-
unit: str,
2939+
unit: TimeUnit,
29392940
) -> Generator[Timestamp]:
29402941
"""
29412942
Generates a sequence of dates corresponding to the specified time

pandas/core/dtypes/cast.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
DtypeObj,
101101
NumpyIndexT,
102102
Scalar,
103+
TimeUnit,
103104
)
104105

105106
from pandas import Index
@@ -567,6 +568,7 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
567568
# different unit, e.g. passed np.timedelta64(24, "h") with dtype=m8[ns]
568569
# see if we can losslessly cast it to our dtype
569570
unit = np.datetime_data(dtype)[0]
571+
unit = cast("TimeUnit", unit)
570572
try:
571573
td = Timedelta(fill_value).as_unit(unit, round_ok=False)
572574
except OutOfBoundsTimedelta:

pandas/core/dtypes/dtypes.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
IntervalClosedType,
8686
Ordered,
8787
Scalar,
88+
TimeUnit,
8889
npt,
8990
type_t,
9091
)
@@ -780,10 +781,10 @@ def base(self) -> DtypeObj: # type: ignore[override]
780781
def str(self) -> str: # type: ignore[override]
781782
return f"|M8[{self.unit}]"
782783

783-
def __init__(self, unit: str_type | DatetimeTZDtype = "ns", tz=None) -> None:
784+
def __init__(self, unit: TimeUnit | DatetimeTZDtype = "ns", tz=None) -> None:
784785
if isinstance(unit, DatetimeTZDtype):
785786
# error: "str" has no attribute "tz"
786-
unit, tz = unit.unit, unit.tz # type: ignore[attr-defined]
787+
unit, tz = unit.unit, unit.tz # type: ignore[union-attr]
787788

788789
if unit != "ns":
789790
if isinstance(unit, str) and tz is None:
@@ -820,7 +821,7 @@ def _creso(self) -> int:
820821
return abbrev_to_npy_unit(self.unit)
821822

822823
@property
823-
def unit(self) -> str_type:
824+
def unit(self) -> TimeUnit:
824825
"""
825826
The precision of the datetime data.
826827
@@ -894,7 +895,8 @@ def construct_from_string(cls, string: str_type) -> DatetimeTZDtype:
894895
if match:
895896
d = match.groupdict()
896897
try:
897-
return cls(unit=d["unit"], tz=d["tz"])
898+
unit = cast("TimeUnit", d["unit"])
899+
return cls(unit=unit, tz=d["tz"])
898900
except (KeyError, TypeError, ValueError) as err:
899901
# KeyError if maybe_get_tz tries and fails to get a
900902
# zoneinfo timezone (actually zoneinfo.ZoneInfoNotFoundError).
@@ -971,6 +973,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
971973
if all(isinstance(t, DatetimeTZDtype) and t.tz == self.tz for t in dtypes):
972974
np_dtype = np.max([cast(DatetimeTZDtype, t).base for t in [self, *dtypes]])
973975
unit = np.datetime_data(np_dtype)[0]
976+
unit = cast("TimeUnit", unit)
974977
return type(self)(unit=unit, tz=self.tz)
975978
return super()._get_common_dtype(dtypes)
976979

pandas/core/reshape/tile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
TYPE_CHECKING,
99
Any,
1010
Literal,
11+
cast,
1112
)
1213

1314
import numpy as np
@@ -49,6 +50,7 @@
4950
from pandas._typing import (
5051
DtypeObj,
5152
IntervalLeftRight,
53+
TimeUnit,
5254
)
5355

5456

@@ -412,7 +414,7 @@ def _nbins_to_bins(x_idx: Index, nbins: int, right: bool) -> Index:
412414
# error: Argument 1 to "dtype_to_unit" has incompatible type
413415
# "dtype[Any] | ExtensionDtype"; expected "DatetimeTZDtype | dtype[Any]"
414416
unit = dtype_to_unit(x_idx.dtype) # type: ignore[arg-type]
415-
td = Timedelta(seconds=1).as_unit(unit)
417+
td = Timedelta(seconds=1).as_unit(cast("TimeUnit", unit))
416418
# Use DatetimeArray/TimedeltaArray method instead of linspace
417419
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
418420
# has no attribute "_generate_range"
@@ -595,6 +597,7 @@ def _format_labels(
595597
# error: Argument 1 to "dtype_to_unit" has incompatible type
596598
# "dtype[Any] | ExtensionDtype"; expected "DatetimeTZDtype | dtype[Any]"
597599
unit = dtype_to_unit(bins.dtype) # type: ignore[arg-type]
600+
unit = cast("TimeUnit", unit)
598601
formatter = lambda x: x
599602
adjust = lambda x: x - Timedelta(1, unit=unit).as_unit(unit)
600603
else:

pandas/core/tools/datetimes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787

8888
from pandas._libs.tslibs.nattype import NaTType
8989
from pandas._libs.tslibs.timedeltas import UnitChoices
90+
from pandas._typing import TimeUnit
9091

9192
from pandas import (
9293
DataFrame,
@@ -447,6 +448,7 @@ def _convert_listlike_datetimes(
447448
# We can take a shortcut since the datetime64 numpy array
448449
# is in UTC
449450
out_unit = np.datetime_data(result.dtype)[0]
451+
out_unit = cast("TimeUnit", out_unit)
450452
dtype = tz_to_dtype(tz_parsed, out_unit)
451453
dt64_values = result.view(f"M8[{dtype.unit}]")
452454
dta = DatetimeArray._simple_new(dt64_values, dtype=dtype)
@@ -469,13 +471,15 @@ def _array_strptime_with_fallback(
469471
result, tz_out = array_strptime(arg, fmt, exact=exact, errors=errors, utc=utc)
470472
if tz_out is not None:
471473
unit = np.datetime_data(result.dtype)[0]
474+
unit = cast("TimeUnit", unit)
472475
dtype = DatetimeTZDtype(tz=tz_out, unit=unit)
473476
dta = DatetimeArray._simple_new(result, dtype=dtype)
474477
if utc:
475478
dta = dta.tz_convert("UTC")
476479
return Index(dta, name=name)
477480
elif result.dtype != object and utc:
478481
unit = np.datetime_data(result.dtype)[0]
482+
unit = cast("TimeUnit", unit)
479483
res = Index(result, dtype=f"M8[{unit}, UTC]", name=name)
480484
return res
481485
return Index(result, dtype=result.dtype, name=name)

pandas/core/window/ewm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import datetime
44
from functools import partial
55
from textwrap import dedent
6-
from typing import TYPE_CHECKING
6+
from typing import (
7+
TYPE_CHECKING,
8+
cast,
9+
)
710

811
import numpy as np
912

@@ -60,6 +63,7 @@
6063
if TYPE_CHECKING:
6164
from pandas._typing import (
6265
TimedeltaConvertibleTypes,
66+
TimeUnit,
6367
npt,
6468
)
6569

@@ -125,6 +129,7 @@ def _calculate_deltas(
125129
Diff of the times divided by the half-life
126130
"""
127131
unit = dtype_to_unit(times.dtype)
132+
unit = cast("TimeUnit", unit)
128133
if isinstance(times, ABCSeries):
129134
times = times._values
130135
_times = np.asarray(times.view(np.int64), dtype=np.float64)

0 commit comments

Comments
 (0)