11from __future__ import annotations
22
3+ import datetime as dt
34from functools import reduce
45from typing import TYPE_CHECKING , Any , Literal , cast , overload
56
1112from narwhals ._plan .arrow import functions as fn
1213from narwhals ._plan .compliant .namespace import EagerNamespace
1314from narwhals ._plan .expressions .literal import is_literal_scalar
15+ from narwhals ._typing_compat import TypeVar
1416from narwhals ._utils import Version
1517from narwhals .exceptions import InvalidOperationError
1618
2426 from narwhals ._plan .expressions import expr , functions as F
2527 from narwhals ._plan .expressions .boolean import AllHorizontal , AnyHorizontal
2628 from narwhals ._plan .expressions .expr import FunctionExpr , RangeExpr
27- from narwhals ._plan .expressions .ranges import IntRange
29+ from narwhals ._plan .expressions .ranges import DateRange , IntRange
2830 from narwhals ._plan .expressions .strings import ConcatStr
2931 from narwhals ._plan .series import Series as NwSeries
3032 from narwhals .typing import ConcatMethod , NonNestedLiteral , PythonLiteral
3133
3234
35+ PythonLiteralT = TypeVar ("PythonLiteralT" , bound = "PythonLiteral" )
36+
37+
3338class ArrowNamespace (EagerNamespace ["Frame" , "Series" , "Expr" , "Scalar" ]):
3439 def __init__ (self , version : Version = Version .MAIN ) -> None :
3540 self ._version = version
@@ -155,12 +160,12 @@ def concat_str(
155160 return self ._scalar .from_native (result , name , self .version )
156161 return self ._expr .from_native (result , name , self .version )
157162
158- def int_range (self , node : RangeExpr [IntRange ], frame : Frame , name : str ) -> Expr :
163+ def _range_function_inputs (
164+ self , node : RangeExpr , frame : Frame , valid_type : type [PythonLiteralT ]
165+ ) -> tuple [PythonLiteralT , PythonLiteralT ]:
159166 start_ : PythonLiteral
160167 end_ : PythonLiteral
161168 start , end = node .function .unwrap_input (node )
162- step = node .function .step
163- dtype = node .function .dtype
164169 if is_literal_scalar (start ) and is_literal_scalar (end ):
165170 start_ , end_ = start .unwrap (), end .unwrap ()
166171 else :
@@ -172,22 +177,29 @@ def int_range(self, node: RangeExpr[IntRange], frame: Frame, name: str) -> Expr:
172177 start_ , end_ = scalar_start .to_python (), scalar_end .to_python ()
173178 else :
174179 msg = (
175- f"All inputs for `int_range ()` must be scalar or aggregations, but got \n "
180+ f"All inputs for `{ node . function } ()` must be scalar or aggregations, but got \n "
176181 f"{ scalar_start .native !r} \n { scalar_end .native !r} "
177182 )
178183 raise InvalidOperationError (msg )
179- if isinstance (start_ , int ) and isinstance (end_ , int ):
180- pa_dtype = narwhals_to_native_dtype (dtype , self .version )
181- if not pa .types .is_integer (pa_dtype ):
182- raise TypeError (pa_dtype )
183- native = fn .int_range (start_ , end_ , step , dtype = pa_dtype )
184- return self ._expr .from_native (native , name , self .version )
185-
186- msg = (
187- f"All inputs for `int_range()` resolve to int, but got \n { start_ !r} \n { end_ !r} "
188- )
184+ if isinstance (start_ , valid_type ) and isinstance (end_ , valid_type ):
185+ return start_ , end_
186+ msg = f"All inputs for `{ node .function } ()` resolve to { valid_type .__name__ } , but got \n { start_ !r} \n { end_ !r} "
189187 raise InvalidOperationError (msg )
190188
189+ def int_range (self , node : RangeExpr [IntRange ], frame : Frame , name : str ) -> Expr :
190+ start , end = self ._range_function_inputs (node , frame , int )
191+ dtype = narwhals_to_native_dtype (node .function .dtype , self .version )
192+ if not pa .types .is_integer (dtype ):
193+ raise TypeError (dtype )
194+ native = fn .int_range (start , end , node .function .step , dtype = dtype )
195+ return self ._expr .from_native (native , name , self .version )
196+
197+ def date_range (self , node : RangeExpr [DateRange ], frame : Frame , name : str ) -> Expr :
198+ start , end = self ._range_function_inputs (node , frame , dt .date )
199+ func = node .function
200+ native = fn .date_range (start , end , func .interval , closed = func .closed )
201+ return self ._expr .from_native (native , name , self .version )
202+
191203 @overload
192204 def concat (self , items : Iterable [Frame ], * , how : ConcatMethod ) -> Frame : ...
193205 @overload
0 commit comments