Skip to content

Commit 4adc6ce

Browse files
authored
enh: support offset_by for cudf (#2823)
1 parent ebbd7fa commit 4adc6ce

File tree

3 files changed

+69
-61
lines changed

3 files changed

+69
-61
lines changed

narwhals/_arrow/series_dt.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,14 @@ def offset_by(self, by: str) -> ArrowSeries:
213213
if interval.unit in {"y", "q", "mo"}:
214214
msg = f"Offsetting by {interval.unit} is not yet supported for pyarrow."
215215
raise NotImplementedError(msg)
216-
if interval.unit == "d":
216+
dtype = self.compliant.dtype
217+
datetime_dtype = self.version.dtypes.Datetime
218+
if interval.unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
217219
offset: pa.DurationScalar[Any] = lit(interval.to_timedelta())
218-
if time_zone := native.type.tz:
219-
native_naive = pc.local_timestamp(native)
220-
result = pc.assume_timezone(pc.add(native_naive, offset), time_zone)
221-
return self.with_native(result)
222-
elif interval.unit == "ns": # pragma: no cover
220+
native_naive = pc.local_timestamp(native)
221+
result = pc.assume_timezone(pc.add(native_naive, offset), dtype.time_zone)
222+
return self.with_native(result)
223+
if interval.unit == "ns": # pragma: no cover
223224
offset = lit(interval.multiple, pa.duration("ns")) # type: ignore[assignment]
224225
else:
225226
offset = lit(interval.to_timedelta())

narwhals/_pandas_like/series_dt.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from typing import TYPE_CHECKING, Any
44

5-
import pandas as pd
6-
75
from narwhals._compliant.any_namespace import DateTimeNamespace
86
from narwhals._constants import (
97
EPOCH_YEAR,
@@ -25,6 +23,10 @@
2523
)
2624

2725
if TYPE_CHECKING:
26+
from datetime import timedelta
27+
28+
import pandas as pd
29+
2830
from narwhals._pandas_like.series import PandasLikeSeries
2931
from narwhals.typing import TimeUnit
3032

@@ -246,10 +248,8 @@ def truncate(self, every: str) -> PandasLikeSeries:
246248
)
247249

248250
def offset_by(self, by: str) -> PandasLikeSeries:
249-
if self.implementation.is_cudf():
250-
msg = "Not implemented for cuDF."
251-
raise NotImplementedError(msg)
252251
native = self.native
252+
pdx = self.compliant.__native_namespace__()
253253
if self._is_pyarrow():
254254
import pyarrow as pa # ignore-banned-import
255255

@@ -270,19 +270,21 @@ def offset_by(self, by: str) -> PandasLikeSeries:
270270
if unit == "q":
271271
multiple *= 3
272272
unit = "mo"
273-
offset: pd.DateOffset | pd.Timedelta
273+
offset: pd.DateOffset | timedelta
274274
if unit == "y":
275-
offset = pd.DateOffset(years=multiple)
275+
offset = pdx.DateOffset(years=multiple)
276276
elif unit == "mo":
277-
offset = pd.DateOffset(months=multiple)
277+
offset = pdx.DateOffset(months=multiple)
278+
elif unit == "ns":
279+
offset = pdx.Timedelta(multiple, unit=UNITS_DICT[unit])
278280
else:
279-
offset = pd.Timedelta(multiple, unit=UNITS_DICT[unit]) # type: ignore[arg-type]
280-
if unit == "d":
281-
original_timezone = native.dt.tz
281+
offset = interval.to_timedelta()
282+
dtype = self.compliant.dtype
283+
datetime_dtype = self.version.dtypes.Datetime
284+
if unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
282285
native_without_timezone = native.dt.tz_localize(None)
283286
result_pd = native_without_timezone + offset
284-
if original_timezone is not None:
285-
result_pd = result_pd.dt.tz_localize(original_timezone)
287+
result_pd = result_pd.dt.tz_localize(dtype.time_zone)
286288
else:
287289
result_pd = native + offset
288290

tests/expr_and_series/dt/offset_by_test.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from datetime import datetime, timezone
3+
from datetime import date, datetime, timezone
44

55
import pytest
66

@@ -32,6 +32,13 @@
3232
datetime(2020, 1, 2, 2, 4, 14, 715125),
3333
],
3434
),
35+
(
36+
"2000ns",
37+
[
38+
datetime(2021, 3, 1, 12, 34, 56, 49014),
39+
datetime(2020, 1, 2, 2, 4, 14, 715125),
40+
],
41+
),
3542
(
3643
"2ms",
3744
[
@@ -107,29 +114,27 @@ def test_offset_by(
107114
x in str(constructor) for x in ("dask", "pyarrow", "ibis")
108115
):
109116
request.applymarker(pytest.mark.xfail())
117+
if "ns" in by and any(
118+
x in str(constructor) for x in ("dask", "pyspark", "ibis", "cudf", "duckdb")
119+
):
120+
request.applymarker(pytest.mark.xfail())
110121
if by.endswith("d") and any(x in str(constructor) for x in ("dask", "ibis")):
111122
request.applymarker(pytest.mark.xfail())
112-
if "cudf" in str(constructor):
113-
# https://github.com/rapidsai/cudf/issues/19292
114-
request.applymarker(pytest.mark.xfail)
115123
result = df.select(nw.col("a").dt.offset_by(by))
116124
assert_equal_data(result, {"a": expected})
117125

118126

119127
@pytest.mark.parametrize(
120128
("by", "expected"),
121129
[
122-
("2d", ["2024-01-03T05:45+0545"]),
123-
("5mo", ["2024-06-01T05:45+0545"]),
124-
("7q", ["2025-10-01T05:45+0545"]),
125-
("5y", ["2029-01-01T05:45+0545"]),
130+
("2d", "2024-01-03T05:45+0545"),
131+
("5mo", "2024-06-01T05:45+0545"),
132+
("7q", "2025-10-01T05:45+0545"),
133+
("5y", "2029-01-01T05:45+0545"),
126134
],
127135
)
128136
def test_offset_by_tz(
129-
request: pytest.FixtureRequest,
130-
constructor: Constructor,
131-
by: str,
132-
expected: list[datetime],
137+
request: pytest.FixtureRequest, constructor: Constructor, by: str, expected: str
133138
) -> None:
134139
if (
135140
("pyarrow" in str(constructor) and is_windows())
@@ -142,35 +147,31 @@ def test_offset_by_tz(
142147
# pyspark,duckdb don't support changing time zones.
143148
# convert_time_zone is not supported for ibis.
144149
request.applymarker(pytest.mark.xfail())
150+
if any(x in str(constructor) for x in ("cudf",)) and "d" not in by:
151+
# cudf: https://github.com/rapidsai/cudf/issues/19363
152+
request.applymarker(pytest.mark.xfail())
145153
if any(x in by for x in ("y", "q", "mo")) and any(
146154
x in str(constructor) for x in ("dask", "pyarrow", "ibis")
147155
):
148156
request.applymarker(pytest.mark.xfail())
149157
if by.endswith("d") and any(x in str(constructor) for x in ("dask",)):
150158
request.applymarker(pytest.mark.xfail())
151-
if "cudf" in str(constructor):
152-
# https://github.com/rapidsai/cudf/issues/19292
153-
request.applymarker(pytest.mark.xfail)
154159
df = nw.from_native(constructor(data_tz))
155160
df = df.select(nw.col("a").dt.convert_time_zone("Asia/Kathmandu"))
156161
result = df.select(nw.col("a").dt.offset_by(by))
157-
result_str = result.select(nw.col("a").dt.to_string("%Y-%m-%dT%H:%M%z"))
158-
assert_equal_data(result_str, {"a": expected})
162+
assert_equal_data(result, {"a": [datetime.strptime(expected, "%Y-%m-%dT%H:%M%z")]})
159163

160164

161165
@pytest.mark.parametrize(
162166
("by", "expected"),
163167
[
164-
("2d", ["2020-10-27T02:00+0100"]),
165-
("5mo", ["2021-03-25T02:00+0100"]),
166-
("1q", ["2021-01-25T02:00+0100"]),
168+
("2d", "2020-10-27T02:00+0100"),
169+
("5mo", "2021-03-25T02:00+0100"),
170+
("1q", "2021-01-25T02:00+0100"),
167171
],
168172
)
169173
def test_offset_by_dst(
170-
request: pytest.FixtureRequest,
171-
constructor: Constructor,
172-
by: str,
173-
expected: list[datetime],
174+
request: pytest.FixtureRequest, constructor: Constructor, by: str, expected: str
174175
) -> None:
175176
if (
176177
("pyarrow" in str(constructor) and is_windows())
@@ -179,12 +180,12 @@ def test_offset_by_dst(
179180
or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1))
180181
):
181182
pytest.skip()
182-
if any(
183-
x in str(constructor) for x in ("duckdb", "sqlframe", "pyspark", "ibis", "cudf")
184-
):
183+
if any(x in str(constructor) for x in ("duckdb", "sqlframe", "pyspark", "ibis")):
185184
# pyspark,duckdb don't support changing time zones.
186185
# convert_time_zone is not supported for ibis.
187-
# cudf https://github.com/rapidsai/cudf/issues/19292
186+
request.applymarker(pytest.mark.xfail())
187+
if any(x in str(constructor) for x in ("cudf",)) and "d" not in by:
188+
# cudf: https://github.com/rapidsai/cudf/issues/19363
188189
request.applymarker(pytest.mark.xfail())
189190
if any(x in by for x in ("y", "q", "mo")) and any(
190191
x in str(constructor) for x in ("dask", "pyarrow")
@@ -195,16 +196,10 @@ def test_offset_by_dst(
195196
df = nw.from_native(constructor(data_dst))
196197
df = df.with_columns(a=nw.col("a").dt.convert_time_zone("Europe/Amsterdam"))
197198
result = df.select(nw.col("a").dt.offset_by(by))
198-
result_str = result.select(nw.col("a").dt.to_string("%Y-%m-%dT%H:%M%z"))
199-
assert_equal_data(result_str, {"a": expected})
199+
assert_equal_data(result, {"a": [datetime.strptime(expected, "%Y-%m-%dT%H:%M%z")]})
200200

201201

202-
def test_offset_by_series(
203-
constructor_eager: ConstructorEager, request: pytest.FixtureRequest
204-
) -> None:
205-
if "cudf" in str(constructor_eager):
206-
# https://github.com/rapidsai/cudf/issues/19292
207-
request.applymarker(pytest.mark.xfail)
202+
def test_offset_by_series(constructor_eager: ConstructorEager) -> None:
208203
df = nw.from_native(constructor_eager(data), eager_only=True)
209204
result = df.select(df["a"].dt.offset_by("1h"))
210205
expected = {
@@ -216,13 +211,23 @@ def test_offset_by_series(
216211
assert_equal_data(result, expected)
217212

218213

219-
def test_offset_by_invalid_interval(
220-
constructor: Constructor, request: pytest.FixtureRequest
221-
) -> None:
222-
if "cudf" in str(constructor):
223-
# https://github.com/rapidsai/cudf/issues/19292
224-
request.applymarker(pytest.mark.xfail)
214+
def test_offset_by_invalid_interval(constructor: Constructor) -> None:
225215
df = nw.from_native(constructor(data))
226216
msg = "Invalid `every` string"
227217
with pytest.raises(ValueError, match=msg):
228218
df.select(nw.col("a").dt.offset_by("1r"))
219+
220+
221+
@pytest.mark.skipif(PANDAS_VERSION < (2, 2), reason="too old for pyarrow date type")
222+
def test_offset_by_date_pandas() -> None:
223+
pytest.importorskip("pandas")
224+
import pandas as pd
225+
226+
df = nw.from_native(pd.DataFrame({"a": [date(2020, 1, 1)]}, dtype="date32[pyarrow]"))
227+
result = df.select(nw.col("a").dt.offset_by("1d"))
228+
expected = {"a": [date(2020, 1, 2)]}
229+
assert_equal_data(result, expected)
230+
df = nw.from_native(pd.DataFrame({"a": [date(2020, 1, 1)]}))
231+
result = df.select(nw.col("a").dt.offset_by("1d"))
232+
expected = {"a": [date(2020, 1, 2)]}
233+
assert_equal_data(result, expected)

0 commit comments

Comments
 (0)