Skip to content

Commit ccae8ea

Browse files
committed
feat(expr-ir): Add date_range
Mostly just re-purposing `int_range` *so far*, as mentioned in (#2895 (comment)). Need to think about `Interval` some more.
1 parent 50bcb9c commit ccae8ea

File tree

7 files changed

+116
-24
lines changed

7 files changed

+116
-24
lines changed

narwhals/_plan/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
any_horizontal,
1010
col,
1111
concat_str,
12+
date_range,
1213
exclude,
1314
int_range,
1415
len,
@@ -38,6 +39,7 @@
3839
"any_horizontal",
3940
"col",
4041
"concat_str",
42+
"date_range",
4143
"exclude",
4244
"int_range",
4345
"len",

narwhals/_plan/arrow/functions.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from narwhals._utils import Implementation
2121

2222
if TYPE_CHECKING:
23+
import datetime as dt
2324
from collections.abc import Iterable, Mapping
2425

2526
from typing_extensions import TypeAlias, TypeIs
@@ -43,6 +44,7 @@
4344
DataType,
4445
DataTypeRemap,
4546
DataTypeT,
47+
DateScalar,
4648
IntegerScalar,
4749
IntegerType,
4850
LargeStringType,
@@ -328,12 +330,38 @@ def int_range(
328330
*,
329331
dtype: IntegerType = pa.int64(), # noqa: B008
330332
) -> ChunkedArray[IntegerScalar]:
331-
import numpy as np # ignore-banned-import
332-
333333
if end is None:
334334
end = start
335335
start = 0
336-
return pa.chunked_array([pa.array(np.arange(start, end, step), dtype)])
336+
if BACKEND_VERSION < (21, 0, 0): # pragma: no cover
337+
import numpy as np # ignore-banned-import
338+
339+
arr = pa.array(np.arange(start=start, stop=end, step=step), type=dtype)
340+
else:
341+
pa_arange: Incomplete = t.cast("Incomplete", pa.arange) # type: ignore[attr-defined]
342+
arr = t.cast("ArrayAny", pa_arange(start=start, stop=end, step=step)).cast(dtype)
343+
return pa.chunked_array([arr])
344+
345+
346+
def date_range(
    start: dt.date,
    end: dt.date,
    interval: int,  # (* assuming the `Interval` part is solved)
    *,
    closed: ClosedInterval = "both",
) -> ChunkedArray[DateScalar]:
    """Generate a range of dates from `start` towards `end`, stepping by `interval` days.

    Arguments:
        start: First candidate date of the range.
        end: Upper bound of the range.
        interval: Step between consecutive dates, in days.
            NOTE(review): assumed to be positive - confirm against callers.
        closed: Which of the bounds may appear in the result.
            `"both"` and `"left"` keep `start`, `"both"` and `"right"` keep `end`
            (when the stepping lands on it), `"none"` keeps neither.

    Returns:
        A `date32` chunked array with the generated dates.
    """
    # `date32` is stored as days since the UNIX epoch, so the range can be
    # built with plain integer arithmetic and cast back at the end.
    start_i = pa.scalar(start).cast(pa.int32()).as_py()
    end_i = pa.scalar(end).cast(pa.int32()).as_py()
    # Adjust the integer bounds *before* generating, instead of trimming the
    # first/last element afterwards: with a step > 1 the last generated value
    # is not necessarily `end`, so trimming by position drops valid dates.
    if closed in {"right", "none"}:
        start_i += interval
    stop = end_i + 1 if closed in {"both", "right"} else end_i
    return int_range(start_i, stop, interval, dtype=pa.int32()).cast(pa.date32())
337365

338366

339367
def nulls_like(n: int, native: ArrowAny) -> ArrayAny:

narwhals/_plan/arrow/namespace.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import datetime as dt
34
from functools import reduce
45
from typing import TYPE_CHECKING, Any, Literal, cast, overload
56

@@ -11,6 +12,7 @@
1112
from narwhals._plan.arrow import functions as fn
1213
from narwhals._plan.compliant.namespace import EagerNamespace
1314
from narwhals._plan.expressions.literal import is_literal_scalar
15+
from narwhals._typing_compat import TypeVar
1416
from narwhals._utils import Version
1517
from narwhals.exceptions import InvalidOperationError
1618

@@ -24,12 +26,15 @@
2426
from narwhals._plan.expressions import expr, functions as F
2527
from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal
2628
from narwhals._plan.expressions.expr import FunctionExpr, RangeExpr
27-
from narwhals._plan.expressions.ranges import IntRange
29+
from narwhals._plan.expressions.ranges import DateRange, IntRange
2830
from narwhals._plan.expressions.strings import ConcatStr
2931
from narwhals._plan.series import Series as NwSeries
3032
from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral
3133

3234

35+
PythonLiteralT = TypeVar("PythonLiteralT", bound="PythonLiteral")
36+
37+
3338
class ArrowNamespace(EagerNamespace["Frame", "Series", "Expr", "Scalar"]):
3439
def __init__(self, version: Version = Version.MAIN) -> None:
3540
self._version = version
@@ -155,12 +160,12 @@ def concat_str(
155160
return self._scalar.from_native(result, name, self.version)
156161
return self._expr.from_native(result, name, self.version)
157162

158-
def int_range(self, node: RangeExpr[IntRange], frame: Frame, name: str) -> Expr:
163+
def _range_function_inputs(
164+
self, node: RangeExpr, frame: Frame, valid_type: type[PythonLiteralT]
165+
) -> tuple[PythonLiteralT, PythonLiteralT]:
159166
start_: PythonLiteral
160167
end_: PythonLiteral
161168
start, end = node.function.unwrap_input(node)
162-
step = node.function.step
163-
dtype = node.function.dtype
164169
if is_literal_scalar(start) and is_literal_scalar(end):
165170
start_, end_ = start.unwrap(), end.unwrap()
166171
else:
@@ -172,22 +177,29 @@ def int_range(self, node: RangeExpr[IntRange], frame: Frame, name: str) -> Expr:
172177
start_, end_ = scalar_start.to_python(), scalar_end.to_python()
173178
else:
174179
msg = (
175-
f"All inputs for `int_range()` must be scalar or aggregations, but got \n"
180+
f"All inputs for `{node.function}()` must be scalar or aggregations, but got \n"
176181
f"{scalar_start.native!r}\n{scalar_end.native!r}"
177182
)
178183
raise InvalidOperationError(msg)
179-
if isinstance(start_, int) and isinstance(end_, int):
180-
pa_dtype = narwhals_to_native_dtype(dtype, self.version)
181-
if not pa.types.is_integer(pa_dtype):
182-
raise TypeError(pa_dtype)
183-
native = fn.int_range(start_, end_, step, dtype=pa_dtype)
184-
return self._expr.from_native(native, name, self.version)
185-
186-
msg = (
187-
f"All inputs for `int_range()` resolve to int, but got \n{start_!r}\n{end_!r}"
188-
)
184+
if isinstance(start_, valid_type) and isinstance(end_, valid_type):
185+
return start_, end_
186+
msg = f"All inputs for `{node.function}()` resolve to {valid_type.__name__}, but got \n{start_!r}\n{end_!r}"
189187
raise InvalidOperationError(msg)
190188

189+
def int_range(self, node: RangeExpr[IntRange], frame: Frame, name: str) -> Expr:
    """Evaluate an `IntRange` node into an expression named `name`.

    Raises:
        TypeError: If the requested dtype does not map to a pyarrow integer type.
    """
    lower, upper = self._range_function_inputs(node, frame, int)
    func = node.function
    pa_dtype = narwhals_to_native_dtype(func.dtype, self.version)
    if pa.types.is_integer(pa_dtype):
        result = fn.int_range(lower, upper, func.step, dtype=pa_dtype)
        return self._expr.from_native(result, name, self.version)
    raise TypeError(pa_dtype)
196+
197+
def date_range(self, node: RangeExpr[DateRange], frame: Frame, name: str) -> Expr:
    """Evaluate a `DateRange` node into an expression named `name`."""
    lower, upper = self._range_function_inputs(node, frame, dt.date)
    function = node.function
    result = fn.date_range(
        lower, upper, function.interval, closed=function.closed
    )
    return self._expr.from_native(result, name, self.version)
202+
191203
@overload
192204
def concat(self, items: Iterable[Frame], *, how: ConcatMethod) -> Frame: ...
193205
@overload

narwhals/_plan/arrow/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pyarrow as pa
1111
import pyarrow.compute as pc
1212
from pyarrow.lib import (
13+
Date32Type,
1314
Int8Type,
1415
Int16Type,
1516
Int32Type,
@@ -28,6 +29,7 @@
2829
StringScalar: TypeAlias = "Scalar[StringType | LargeStringType]"
2930
IntegerType: TypeAlias = "Int8Type | Int16Type | Int32Type | Int64Type | Uint8Type | Uint16Type | Uint32Type | Uint64Type"
3031
IntegerScalar: TypeAlias = "Scalar[IntegerType]"
32+
DateScalar: TypeAlias = "Scalar[Date32Type]"
3133

3234
class NativeArrowSeries(NativeSeries, Protocol):
3335
@property

narwhals/_plan/compliant/namespace.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from narwhals._plan import expressions as ir
2626
from narwhals._plan.expressions import FunctionExpr, boolean, functions as F
27-
from narwhals._plan.expressions.ranges import IntRange
27+
from narwhals._plan.expressions.ranges import DateRange, IntRange
2828
from narwhals._plan.expressions.strings import ConcatStr
2929
from narwhals._plan.series import Series
3030
from narwhals.typing import ConcatMethod, NonNestedLiteral
@@ -47,6 +47,9 @@ def col(self, node: ir.Column, frame: FrameT, name: str) -> ExprT_co: ...
4747
def concat_str(
4848
self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str
4949
) -> ExprT_co | ScalarT_co: ...
50+
def date_range(
51+
self, node: ir.RangeExpr[DateRange], frame: FrameT, name: str
52+
) -> ExprT_co: ...
5053
def int_range(
5154
self, node: ir.RangeExpr[IntRange], frame: FrameT, name: str
5255
) -> ExprT_co: ...

narwhals/_plan/expressions/ranges.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from narwhals._plan.expressions import ExprIR, RangeExpr
1212
from narwhals.dtypes import IntegerType
13+
from narwhals.typing import ClosedInterval
1314

1415

1516
class RangeFunction(Function, config=FEOptions.namespaced()):
@@ -18,6 +19,10 @@ def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]:
1819

1920
return RangeExpr(input=inputs, function=self, options=self.function_options)
2021

22+
def unwrap_input(self, node: RangeExpr[Self], /) -> tuple[ExprIR, ExprIR]:
    """Return the two input expressions of `node` as a `(start, end)` pair."""
    # Unpacking (rather than indexing) fails loudly unless there are exactly two inputs.
    lower, upper = node.input
    return lower, upper
25+
2126

2227
class IntRange(RangeFunction, options=FunctionOptions.row_separable):
2328
"""N-ary (start, end)."""
@@ -26,6 +31,10 @@ class IntRange(RangeFunction, options=FunctionOptions.row_separable):
2631
step: int
2732
dtype: IntegerType
2833

29-
def unwrap_input(self, node: RangeExpr[Self], /) -> tuple[ExprIR, ExprIR]:
30-
start, end = node.input
31-
return start, end
34+
35+
class DateRange(RangeFunction, options=FunctionOptions.row_separable):
    """Range function over dates, taking two inputs: (start, end)."""

    __slots__ = ("interval", "closed")  # noqa: RUF023
    # Step between consecutive dates, as a number of days.
    interval: int
    # Which of the `start`/`end` bounds may be included in the output.
    closed: ClosedInterval

narwhals/_plan/functions.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
from __future__ import annotations
22

33
import builtins
4+
import datetime as dt
45
import typing as t
56
from typing import TYPE_CHECKING
67

8+
from narwhals._duration import Interval
79
from narwhals._plan import _guards, _parse, common, expressions as ir, selectors as cs
810
from narwhals._plan.expressions import functions as F
911
from narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral
10-
from narwhals._plan.expressions.ranges import IntRange
12+
from narwhals._plan.expressions.ranges import DateRange, IntRange
1113
from narwhals._plan.expressions.strings import ConcatStr
1214
from narwhals._plan.when_then import When
1315
from narwhals._utils import Version, flatten
16+
from narwhals.exceptions import ComputeError
1417

1518
if TYPE_CHECKING:
1619
from narwhals._plan.expr import Expr
1720
from narwhals._plan.series import Series
1821
from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT
1922
from narwhals.dtypes import IntegerType
20-
from narwhals.typing import IntoDType, NonNestedLiteral
23+
from narwhals.typing import ClosedInterval, IntoDType, NonNestedLiteral
2124

2225

2326
def col(*names: str | t.Iterable[str]) -> Expr:
@@ -161,3 +164,36 @@ def int_range(
161164
.to_function_expr(*_parse.parse_into_seq_of_expr_ir(start, end))
162165
.to_narwhals()
163166
)
167+
168+
169+
def date_range(
    start: dt.date | IntoExprColumn,
    end: dt.date | IntoExprColumn,
    interval: str | dt.timedelta = "1d",
    *,
    closed: ClosedInterval = "both",
    eager: bool = False,
) -> Expr:
    """Build an expression producing a range of dates between `start` and `end`.

    Arguments:
        start: Lower bound, either a date or something that parses into an expression.
        end: Upper bound, either a date or something that parses into an expression.
        interval: Step between dates, as a duration string or a `timedelta`.
        closed: Which of the bounds are included.
        eager: Not supported yet; any truthy value raises.

    Raises:
        NotImplementedError: If `eager` is truthy.
    """
    if eager:
        msg = f"{eager=}"
        raise NotImplementedError(msg)
    step_days = _interval_days(interval)
    function = DateRange(interval=step_days, closed=closed)
    bounds = _parse.parse_into_seq_of_expr_ir(start, end)
    return function.to_function_expr(*bounds).to_narwhals()
185+
186+
187+
def _interval_days(interval: str | dt.timedelta, /) -> int:
188+
if not isinstance(interval, dt.timedelta):
189+
if interval == "1d":
190+
return 1
191+
parsed = Interval.parse_no_constraints(interval)
192+
if parsed.unit not in {"d", "mo", "q", "y"}:
193+
msg = f"`interval` input for `date_range` must consist of full days, got: {parsed.multiple}{parsed.unit}"
194+
raise ComputeError(msg)
195+
if parsed.unit in {"mo", "q", "y"}:
196+
msg = f"`interval` input for `date_range` does not support {parsed.unit!r} yet, got: {parsed.multiple}{parsed.unit}"
197+
raise NotImplementedError(msg)
198+
interval = parsed.to_timedelta()
199+
return interval.days

0 commit comments

Comments
 (0)