Skip to content

Commit 0ea787e

Browse files
raisadz, dangotbanned, MarcoGorelli
authored
feat: Add Expr.dt.offset_by() (#2733)
---------

Co-authored-by: dangotbanned <[email protected]>
Co-authored-by: Marco Edward Gorelli <[email protected]>
1 parent d31d2f4 commit 0ea787e

File tree

16 files changed: 521 additions and 41 deletions.

16 files changed: 521 additions and 41 deletions.

docs/api-reference/expr_dt.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
- minute
1414
- month
1515
- nanosecond
16+
- offset_by
1617
- ordinal_day
1718
- replace_time_zone
1819
- second

docs/api-reference/series_dt.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
- minute
1414
- month
1515
- nanosecond
16+
- offset_by
1617
- ordinal_day
1718
- replace_time_zone
1819
- second

narwhals/_arrow/series_dt.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
US_PER_MINUTE,
1919
US_PER_SECOND,
2020
)
21-
from narwhals._duration import parse_interval_string
21+
from narwhals._duration import Interval
2222

2323
if TYPE_CHECKING:
2424
from collections.abc import Mapping
@@ -202,7 +202,25 @@ def total_nanoseconds(self) -> ArrowSeries:
202202
return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
203203

204204
def truncate(self, every: str) -> ArrowSeries:
205-
multiple, unit = parse_interval_string(every)
205+
interval = Interval.parse(every)
206206
return self.with_native(
207-
pc.floor_temporal(self.native, multiple=multiple, unit=UNITS_DICT[unit])
207+
pc.floor_temporal(self.native, interval.multiple, UNITS_DICT[interval.unit])
208208
)
209+
210+
def offset_by(self, by: str) -> ArrowSeries:
211+
interval = Interval.parse_no_constraints(by)
212+
native = self.native
213+
if interval.unit in {"y", "q", "mo"}:
214+
msg = f"Offsetting by {interval.unit} is not yet supported for pyarrow."
215+
raise NotImplementedError(msg)
216+
if interval.unit == "d":
217+
offset: pa.DurationScalar[Any] = lit(interval.to_timedelta())
218+
if time_zone := native.type.tz:
219+
native_naive = pc.local_timestamp(native)
220+
result = pc.assume_timezone(pc.add(native_naive, offset), time_zone)
221+
return self.with_native(result)
222+
elif interval.unit == "ns": # pragma: no cover
223+
offset = lit(interval.multiple, pa.duration("ns")) # type: ignore[assignment]
224+
else:
225+
offset = lit(interval.to_timedelta())
226+
return self.with_native(pc.add(native, offset))

narwhals/_compliant/any_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ def total_milliseconds(self) -> CompliantT_co: ...
4848
def total_microseconds(self) -> CompliantT_co: ...
4949
def total_nanoseconds(self) -> CompliantT_co: ...
5050
def truncate(self, every: str) -> CompliantT_co: ...
51+
def offset_by(self, by: str) -> CompliantT_co: ...
5152

5253

5354
class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):

narwhals/_compliant/expr.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1112,6 +1112,9 @@ def total_nanoseconds(self) -> EagerExprT:
11121112
def truncate(self, every: str) -> EagerExprT:
11131113
return self.compliant._reuse_series_namespace("dt", "truncate", every=every)
11141114

1115+
def offset_by(self, by: str) -> EagerExprT:
1116+
return self.compliant._reuse_series_namespace("dt", "offset_by", by=by)
1117+
11151118

11161119
class EagerExprListNamespace(
11171120
EagerExprNamespace[EagerExprT], ListNamespace[EagerExprT], Generic[EagerExprT]

narwhals/_dask/expr_dt.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
from narwhals._compliant import LazyExprNamespace
66
from narwhals._compliant.any_namespace import DateTimeNamespace
77
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
8-
from narwhals._duration import parse_interval_string
8+
from narwhals._duration import Interval
99
from narwhals._pandas_like.utils import (
10-
UNIT_DICT,
10+
ALIAS_DICT,
1111
calculate_timestamp_date,
1212
calculate_timestamp_datetime,
1313
native_to_narwhals_dtype,
@@ -154,9 +154,22 @@ def total_nanoseconds(self) -> DaskExpr:
154154
)
155155

156156
def truncate(self, every: str) -> DaskExpr:
157-
multiple, unit = parse_interval_string(every)
157+
interval = Interval.parse(every)
158+
unit = interval.unit
158159
if unit in {"mo", "q", "y"}:
159-
msg = f"Truncating to {unit} is not supported yet for dask."
160+
msg = f"Truncating to {unit} is not yet supported for dask."
160161
raise NotImplementedError(msg)
161-
freq = f"{multiple}{UNIT_DICT.get(unit, unit)}"
162+
freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
162163
return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")
164+
165+
def offset_by(self, by: str) -> DaskExpr:
166+
def func(s: dx.Series, by: str) -> dx.Series:
167+
interval = Interval.parse_no_constraints(by)
168+
unit = interval.unit
169+
if unit in {"y", "q", "mo", "d", "ns"}:
170+
msg = f"Offsetting by {unit} is not yet supported for dask."
171+
raise NotImplementedError(msg)
172+
offset = interval.to_timedelta()
173+
return s.add(offset)
174+
175+
return self.compliant._with_callable(func, "offset_by", by=by)

narwhals/_duckdb/expr_dt.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
US_PER_SECOND,
1414
)
1515
from narwhals._duckdb.utils import UNITS_DICT, F, fetch_rel_time_zone, lit
16-
from narwhals._duration import parse_interval_string
16+
from narwhals._duration import Interval
1717
from narwhals._utils import not_implemented
1818

1919
if TYPE_CHECKING:
@@ -99,7 +99,8 @@ def total_microseconds(self) -> DuckDBExpr:
9999
)
100100

101101
def truncate(self, every: str) -> DuckDBExpr:
102-
multiple, unit = parse_interval_string(every)
102+
interval = Interval.parse(every)
103+
multiple, unit = interval.multiple, interval.unit
103104
if multiple != 1:
104105
# https://github.com/duckdb/duckdb/issues/17554
105106
msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}."
@@ -114,6 +115,15 @@ def _truncate(expr: Expression) -> Expression:
114115

115116
return self.compliant._with_elementwise(_truncate)
116117

118+
def offset_by(self, by: str) -> DuckDBExpr:
119+
interval = Interval.parse_no_constraints(by)
120+
format = lit(f"{interval.multiple!s} {UNITS_DICT[interval.unit]}")
121+
122+
def _offset_by(expr: Expression) -> Expression:
123+
return F("date_add", format, expr)
124+
125+
return self.compliant._with_callable(_offset_by)
126+
117127
def _no_op_time_zone(self, time_zone: str) -> DuckDBExpr:
118128
def func(df: DuckDBLazyFrame) -> Sequence[Expression]:
119129
native_series_list = self.compliant(df)

narwhals/_duration.py

Lines changed: 49 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,16 @@
22

33
from __future__ import annotations
44

5+
import datetime as dt
56
import re
67
from typing import TYPE_CHECKING, Literal, cast, get_args
78

89
if TYPE_CHECKING:
10+
from collections.abc import Container, Mapping
11+
912
from typing_extensions import TypeAlias
1013

11-
__all__ = ["IntervalUnit", "parse_interval_string"]
14+
__all__ = ["IntervalUnit"]
1215

1316
IntervalUnit: TypeAlias = Literal["ns", "us", "ms", "s", "m", "h", "d", "mo", "q", "y"]
1417
"""A Polars duration string interval unit.
@@ -24,23 +27,43 @@
2427
- 'q': quarter.
2528
- 'y': year.
2629
"""
30+
TimedeltaKwd: TypeAlias = Literal[
31+
"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"
32+
]
2733

2834
PATTERN_INTERVAL: re.Pattern[str] = re.compile(
2935
r"^(?P<multiple>\d+)(?P<unit>ns|us|ms|mo|m|s|h|d|q|y)\Z"
3036
)
3137
MONTH_MULTIPLES = frozenset([1, 2, 3, 4, 6, 12])
3238
QUARTER_MULTIPLES = frozenset([1, 2, 4])
39+
UNIT_TO_TIMEDELTA: Mapping[IntervalUnit, TimedeltaKwd] = {
40+
"d": "days",
41+
"h": "hours",
42+
"m": "minutes",
43+
"s": "seconds",
44+
"ms": "milliseconds",
45+
"us": "microseconds",
46+
}
47+
3348

49+
class Interval:
50+
def __init__(self, multiple: int, unit: IntervalUnit, /) -> None:
51+
self.multiple: int = multiple
52+
self.unit: IntervalUnit = unit
3453

35-
def parse_interval_string(every: str) -> tuple[int, IntervalUnit]:
36-
"""Parse a string like "1d", "2h", "3m" into a tuple of (number, unit).
54+
def to_timedelta(
55+
self, *, unsupported: Container[IntervalUnit] = frozenset(("ns", "mo", "q", "y"))
56+
) -> dt.timedelta:
57+
if self.unit in unsupported: # pragma: no cover
58+
msg = f"Creating timedelta with {self.unit} unit is not supported."
59+
raise NotImplementedError(msg)
60+
kwd = UNIT_TO_TIMEDELTA[self.unit]
61+
# error: Keywords must be strings (bad mypy)
62+
return dt.timedelta(**{kwd: self.multiple}) # type: ignore[misc]
3763

38-
Returns:
39-
A tuple of multiple and unit parsed from the interval string.
40-
"""
41-
if match := PATTERN_INTERVAL.match(every):
42-
multiple = int(match["multiple"])
43-
unit = cast("IntervalUnit", match["unit"])
64+
@classmethod
65+
def parse(cls, every: str) -> Interval:
66+
multiple, unit = cls._parse(every)
4467
if unit == "mo" and multiple not in MONTH_MULTIPLES:
4568
msg = f"Only the following multiples are supported for 'mo' unit: {MONTH_MULTIPLES}.\nGot: {multiple}."
4669
raise ValueError(msg)
@@ -52,9 +75,20 @@ def parse_interval_string(every: str) -> tuple[int, IntervalUnit]:
5275
f"Only multiple 1 is currently supported for 'y' unit.\nGot: {multiple}."
5376
)
5477
raise ValueError(msg)
55-
return multiple, unit
56-
msg = (
57-
f"Invalid `every` string: {every}. Expected string of kind <number><unit>, "
58-
f"where 'unit' is one of: {get_args(IntervalUnit)}."
59-
)
60-
raise ValueError(msg)
78+
return cls(multiple, unit)
79+
80+
@classmethod
81+
def parse_no_constraints(cls, every: str) -> Interval:
82+
return cls(*cls._parse(every))
83+
84+
@staticmethod
85+
def _parse(every: str) -> tuple[int, IntervalUnit]:
86+
if match := PATTERN_INTERVAL.match(every):
87+
multiple = int(match["multiple"])
88+
unit = cast("IntervalUnit", match["unit"])
89+
return multiple, unit
90+
msg = (
91+
f"Invalid `every` string: {every}. Expected string of kind <number><unit>, "
92+
f"where 'unit' is one of: {get_args(IntervalUnit)}."
93+
)
94+
raise ValueError(msg)

narwhals/_ibis/expr_dt.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
from narwhals._compliant import LazyExprNamespace
66
from narwhals._compliant.any_namespace import DateTimeNamespace
7-
from narwhals._duration import parse_interval_string
7+
from narwhals._duration import Interval
88
from narwhals._ibis.utils import UNITS_DICT_BUCKET, UNITS_DICT_TRUNCATE
99
from narwhals._utils import not_implemented
1010

@@ -68,7 +68,8 @@ def fn(expr: ir.TimestampValue) -> ir.TimestampValue:
6868
return fn
6969

7070
def truncate(self, every: str) -> IbisExpr:
71-
multiple, unit = parse_interval_string(every)
71+
interval = Interval.parse(every)
72+
multiple, unit = interval.multiple, interval.unit
7273
if unit == "q":
7374
multiple, unit = 3 * multiple, "mo"
7475
if multiple != 1:
@@ -80,6 +81,15 @@ def truncate(self, every: str) -> IbisExpr:
8081
fn = self._truncate(UNITS_DICT_TRUNCATE[unit])
8182
return self.compliant._with_callable(fn)
8283

84+
def offset_by(self, every: str) -> IbisExpr:
85+
interval = Interval.parse_no_constraints(every)
86+
unit = interval.unit
87+
if unit in {"y", "q", "mo", "d", "ns"}:
88+
msg = f"Offsetting by {unit} is not yet supported for ibis."
89+
raise NotImplementedError(msg)
90+
offset = interval.to_timedelta()
91+
return self.compliant._with_callable(lambda expr: expr.add(offset))
92+
8393
def replace_time_zone(self, time_zone: str | None) -> IbisExpr:
8494
if time_zone is None:
8595
return self.compliant._with_callable(lambda expr: expr.cast("timestamp"))

narwhals/_pandas_like/series_dt.py

Lines changed: 53 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING, Any
44

5+
import pandas as pd
6+
57
from narwhals._compliant.any_namespace import DateTimeNamespace
68
from narwhals._constants import (
79
EPOCH_YEAR,
@@ -10,9 +12,10 @@
1012
SECONDS_PER_DAY,
1113
US_PER_SECOND,
1214
)
13-
from narwhals._duration import parse_interval_string
15+
from narwhals._duration import Interval
1416
from narwhals._pandas_like.utils import (
15-
UNIT_DICT,
17+
ALIAS_DICT,
18+
UNITS_DICT,
1619
PandasLikeSeriesNamespace,
1720
calculate_timestamp_date,
1821
calculate_timestamp_datetime,
@@ -203,13 +206,14 @@ def timestamp(self, time_unit: TimeUnit) -> PandasLikeSeries:
203206
return self.with_native(result)
204207

205208
def truncate(self, every: str) -> PandasLikeSeries:
206-
multiple, unit = parse_interval_string(every)
209+
interval = Interval.parse(every)
210+
multiple, unit = interval.multiple, interval.unit
207211
native = self.native
208212
if self.implementation.is_cudf():
209213
if multiple != 1:
210214
msg = f"Only multiple `1` is supported for cuDF, got: {multiple}."
211215
raise NotImplementedError(msg)
212-
return self.with_native(self.native.dt.floor(UNIT_DICT.get(unit, unit)))
216+
return self.with_native(self.native.dt.floor(ALIAS_DICT.get(unit, unit)))
213217
dtype_backend = get_dtype_backend(native.dtype, self.compliant._implementation)
214218
if unit in {"mo", "q", "y"}:
215219
if self.implementation.is_cudf():
@@ -218,8 +222,6 @@ def truncate(self, every: str) -> PandasLikeSeries:
218222
if dtype_backend == "pyarrow":
219223
import pyarrow.compute as pc # ignore-banned-import
220224

221-
from narwhals._arrow.utils import UNITS_DICT
222-
223225
ca = native.array._pa_array
224226
result_arr = pc.floor_temporal(ca, multiple, UNITS_DICT[unit])
225227
else:
@@ -230,7 +232,7 @@ def truncate(self, every: str) -> PandasLikeSeries:
230232
np_unit = "M"
231233
else:
232234
np_unit = "Y"
233-
arr = native.values
235+
arr = native.values # noqa: PD011
234236
arr_dtype = arr.dtype
235237
result_arr = arr.astype(f"datetime64[{multiple}{np_unit}]").astype(
236238
arr_dtype
@@ -240,5 +242,48 @@ def truncate(self, every: str) -> PandasLikeSeries:
240242
)
241243
return self.with_native(result_native)
242244
return self.with_native(
243-
self.native.dt.floor(f"{multiple}{UNIT_DICT.get(unit, unit)}")
245+
self.native.dt.floor(f"{multiple}{ALIAS_DICT.get(unit, unit)}")
244246
)
247+
248+
def offset_by(self, by: str) -> PandasLikeSeries:
249+
if self.implementation.is_cudf():
250+
msg = "Not implemented for cuDF."
251+
raise NotImplementedError(msg)
252+
native = self.native
253+
if self._is_pyarrow():
254+
import pyarrow as pa # ignore-banned-import
255+
256+
compliant = self.compliant
257+
ca = pa.chunked_array([compliant.to_arrow()]) # type: ignore[arg-type]
258+
result = (
259+
compliant._version.namespace.from_backend("pyarrow")
260+
.compliant.from_native(ca)
261+
.dt.offset_by(by)
262+
.native
263+
)
264+
result_pd = native.__class__(
265+
result, dtype=native.dtype, index=native.index, name=native.name
266+
)
267+
else:
268+
interval = Interval.parse_no_constraints(by)
269+
multiple, unit = interval.multiple, interval.unit
270+
if unit == "q":
271+
multiple *= 3
272+
unit = "mo"
273+
offset: pd.DateOffset | pd.Timedelta
274+
if unit == "y":
275+
offset = pd.DateOffset(years=multiple)
276+
elif unit == "mo":
277+
offset = pd.DateOffset(months=multiple)
278+
else:
279+
offset = pd.Timedelta(multiple, unit=UNITS_DICT[unit]) # type: ignore[arg-type]
280+
if unit == "d":
281+
original_timezone = native.dt.tz
282+
native_without_timezone = native.dt.tz_localize(None)
283+
result_pd = native_without_timezone + offset
284+
if original_timezone is not None:
285+
result_pd = result_pd.dt.tz_localize(original_timezone)
286+
else:
287+
result_pd = native + offset
288+
289+
return self.with_native(result_pd)

0 commit comments

Comments
 (0)