Skip to content

Commit be6ec3e

Browse files
committed
feat: Add expression methods for everything in `functions`
Still needed: (1) reprs; (2) a fix for the hierarchy issue (#2572 (comment)); (3) flag summing (#2572 (comment)).
1 parent 98292dc commit be6ec3e

File tree

2 files changed

+276
-3
lines changed

2 files changed

+276
-3
lines changed

narwhals/_plan/dummy.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@
88
from narwhals._plan import aggregation as agg
99
from narwhals._plan import boolean
1010
from narwhals._plan import expr
11+
from narwhals._plan import expr_parsing as parse
12+
from narwhals._plan import functions as F # noqa: N812
1113
from narwhals._plan import operators as ops
14+
from narwhals._plan.options import EWMOptions
15+
from narwhals._plan.options import RankOptions
16+
from narwhals._plan.options import RollingOptionsFixedWindow
17+
from narwhals._plan.options import RollingVarParams
1218
from narwhals._plan.options import SortMultipleOptions
1319
from narwhals._plan.options import SortOptions
1420
from narwhals._plan.window import Over
1521
from narwhals.dtypes import DType
22+
from narwhals.exceptions import ComputeError
1623
from narwhals.utils import Version
1724
from narwhals.utils import _hasattr_static
1825
from narwhals.utils import flatten
@@ -21,9 +28,16 @@
2128
from typing_extensions import Self
2229

2330
from narwhals._plan.common import ExprIR
31+
from narwhals._plan.common import IntoExpr
32+
from narwhals._plan.common import IntoExprColumn
2433
from narwhals._plan.common import Seq
34+
from narwhals._plan.common import Udf
35+
from narwhals.typing import FillNullStrategy
2536
from narwhals.typing import NativeSeries
37+
from narwhals.typing import NumericLiteral
38+
from narwhals.typing import RankMethod
2639
from narwhals.typing import RollingInterpolationMethod
40+
from narwhals.typing import TemporalLiteral
2741

2842

2943
# NOTE: Overly simplified placeholders for mocking typing
@@ -127,6 +141,237 @@ def sort_by(
127141
options = SortMultipleOptions(descending=desc, nulls_last=nulls)
128142
return self._from_ir(expr.SortBy(expr=self._ir, by=sort_by, options=options))
129143

144+
def abs(self) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Abs` node."""
    function = F.Abs()
    return self._from_ir(function.to_function_expr(self._ir))
146+
147+
def hist(
    self,
    bins: t.Sequence[float] | None = None,
    *,
    bin_count: int | None = None,
    include_breakpoint: bool = True,
) -> Self:
    """Return a new expression that wraps this expression's IR in a histogram node.

    Arguments:
        bins: Explicit bin edges. When given, they are copied into a tuple and
            an `F.HistBins` node is built.
        bin_count: Number of bins. When given (and `bins` is not), an
            `F.HistBinCount` node is built with it.
        include_breakpoint: Forwarded unchanged to whichever node is built.

    Raises:
        ComputeError: If both `bins` and `bin_count` are provided.

    Notes:
        When neither `bins` nor `bin_count` is provided, `F.HistBinCount` is
        built without a `bin_count`, so that node's own default applies.
    """
    # The two ways of describing the bins are mutually exclusive.
    if bins is not None and bin_count is not None:
        msg = "can only provide one of `bin_count` or `bins`"
        raise ComputeError(msg)

    node: F.Hist
    if bins is not None:
        node = F.HistBins(bins=tuple(bins), include_breakpoint=include_breakpoint)
    elif bin_count is None:
        node = F.HistBinCount(include_breakpoint=include_breakpoint)
    else:
        node = F.HistBinCount(
            bin_count=bin_count, include_breakpoint=include_breakpoint
        )
    return self._from_ir(node.to_function_expr(self._ir))
167+
168+
def null_count(self) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.NullCount` node."""
    function = F.NullCount()
    return self._from_ir(function.to_function_expr(self._ir))
170+
171+
def fill_null(
    self,
    value: IntoExpr = None,
    strategy: FillNullStrategy | None = None,
    limit: int | None = None,
) -> Self:
    """Return a new expression that wraps this expression's IR in a fill-null node.

    Arguments:
        value: Replacement for nulls. Parsed with
            `parse.parse_into_expr_ir(value, str_as_lit=True)` and stored on an
            `F.FillNull` node. Only used when `strategy` is None.
        strategy: When given, an `F.FillNullWithStrategy` node is built instead.
        limit: Forwarded to `F.FillNullWithStrategy`. It is not used when
            `strategy` is None.

    Raises:
        ValueError: If both `value` and `strategy` are provided. Previously the
            `value` was silently discarded in that case.
    """
    # Reject the ambiguous call up front rather than quietly dropping `value`.
    if value is not None and strategy is not None:
        msg = "cannot specify both `value` and `strategy`"
        raise ValueError(msg)

    node: F.FillNullWithStrategy | F.FillNull
    if strategy is not None:
        node = F.FillNullWithStrategy(strategy=strategy, limit=limit)
    else:
        node = F.FillNull(value=parse.parse_into_expr_ir(value, str_as_lit=True))
    return self._from_ir(node.to_function_expr(self._ir))
183+
184+
def shift(self, n: int) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Shift` node.

    Arguments:
        n: Stored on the `F.Shift` node as `n`.
    """
    function = F.Shift(n=n)
    return self._from_ir(function.to_function_expr(self._ir))
186+
187+
def drop_nulls(self) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.DropNulls` node."""
    function = F.DropNulls()
    return self._from_ir(function.to_function_expr(self._ir))
189+
190+
def mode(self) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Mode` node."""
    function = F.Mode()
    return self._from_ir(function.to_function_expr(self._ir))
192+
193+
def skew(self) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Skew` node."""
    function = F.Skew()
    return self._from_ir(function.to_function_expr(self._ir))
195+
196+
def rank(self, method: RankMethod = "average", *, descending: bool = False) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Rank` node.

    Arguments:
        method: Stored on the node's `RankOptions` as `method`.
        descending: Stored on the node's `RankOptions` as `descending`.
    """
    function = F.Rank(options=RankOptions(method=method, descending=descending))
    return self._from_ir(function.to_function_expr(self._ir))
199+
200+
def clip(
    self,
    lower_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None,
    upper_bound: IntoExprColumn | NumericLiteral | TemporalLiteral | None = None,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Clip` node.

    Arguments:
        lower_bound: Parsed together with `upper_bound` by
            `parse.parse_into_seq_of_expr_ir`.
        upper_bound: See `lower_bound`.

    Notes:
        The parsed bounds are passed as extra inputs after this expression's
        own IR, in the order (lower, upper).
    """
    function = F.Clip()
    bounds = parse.parse_into_seq_of_expr_ir(lower_bound, upper_bound)
    return self._from_ir(function.to_function_expr(self._ir, *bounds))
210+
211+
def cum_count(self, *, reverse: bool = False) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.CumCount` node.

    Arguments:
        reverse: Stored on the node as `reverse`.
    """
    function = F.CumCount(reverse=reverse)
    return self._from_ir(function.to_function_expr(self._ir))
213+
214+
def cum_min(self, *, reverse: bool = False) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.CumMin` node.

    Arguments:
        reverse: Stored on the node as `reverse`.
    """
    function = F.CumMin(reverse=reverse)
    return self._from_ir(function.to_function_expr(self._ir))
216+
217+
def cum_max(self, *, reverse: bool = False) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.CumMax` node.

    Arguments:
        reverse: Stored on the node as `reverse`.
    """
    function = F.CumMax(reverse=reverse)
    return self._from_ir(function.to_function_expr(self._ir))
219+
220+
def cum_prod(self, *, reverse: bool = False) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.CumProd` node.

    Arguments:
        reverse: Stored on the node as `reverse`.
    """
    function = F.CumProd(reverse=reverse)
    return self._from_ir(function.to_function_expr(self._ir))
222+
223+
def rolling_sum(
    self,
    window_size: int,
    *,
    min_samples: int | None = None,
    center: bool = False,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.RollingSum` node.

    Arguments:
        window_size: Stored on the node's `RollingOptionsFixedWindow`.
        min_samples: Stored on the options; falls back to `window_size` when None.
        center: Stored on the options.
    """
    if min_samples is None:
        # An omitted `min_samples` means "the whole window".
        min_samples = window_size
    # This variant takes no extra function parameters, hence `fn_params=None`.
    options = RollingOptionsFixedWindow(
        window_size=window_size,
        min_samples=min_samples,
        center=center,
        fn_params=None,
    )
    return self._from_ir(F.RollingSum(options=options).to_function_expr(self._ir))
240+
241+
def rolling_mean(
    self,
    window_size: int,
    *,
    min_samples: int | None = None,
    center: bool = False,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.RollingMean` node.

    Arguments:
        window_size: Stored on the node's `RollingOptionsFixedWindow`.
        min_samples: Stored on the options; falls back to `window_size` when None.
        center: Stored on the options.
    """
    if min_samples is None:
        # An omitted `min_samples` means "the whole window".
        min_samples = window_size
    # This variant takes no extra function parameters, hence `fn_params=None`.
    options = RollingOptionsFixedWindow(
        window_size=window_size,
        min_samples=min_samples,
        center=center,
        fn_params=None,
    )
    return self._from_ir(F.RollingMean(options=options).to_function_expr(self._ir))
258+
259+
def rolling_var(
    self,
    window_size: int,
    *,
    min_samples: int | None = None,
    center: bool = False,
    ddof: int = 1,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.RollingVar` node.

    Arguments:
        window_size: Stored on the node's `RollingOptionsFixedWindow`.
        min_samples: Stored on the options; falls back to `window_size` when None.
        center: Stored on the options.
        ddof: Wrapped in `RollingVarParams` and stored as the options' `fn_params`.
    """
    if min_samples is None:
        # An omitted `min_samples` means "the whole window".
        min_samples = window_size
    options = RollingOptionsFixedWindow(
        window_size=window_size,
        min_samples=min_samples,
        center=center,
        fn_params=RollingVarParams(ddof=ddof),
    )
    return self._from_ir(F.RollingVar(options=options).to_function_expr(self._ir))
277+
278+
def rolling_std(
    self,
    window_size: int,
    *,
    min_samples: int | None = None,
    center: bool = False,
    ddof: int = 1,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.RollingStd` node.

    Arguments:
        window_size: Stored on the node's `RollingOptionsFixedWindow`.
        min_samples: Stored on the options; falls back to `window_size` when None.
        center: Stored on the options.
        ddof: Wrapped in `RollingVarParams` (the same parameter type
            `rolling_var` uses) and stored as the options' `fn_params`.
    """
    if min_samples is None:
        # An omitted `min_samples` means "the whole window".
        min_samples = window_size
    options = RollingOptionsFixedWindow(
        window_size=window_size,
        min_samples=min_samples,
        center=center,
        fn_params=RollingVarParams(ddof=ddof),
    )
    return self._from_ir(F.RollingStd(options=options).to_function_expr(self._ir))
296+
297+
def diff(self) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Diff` node."""
    function = F.Diff()
    return self._from_ir(function.to_function_expr(self._ir))
299+
300+
def unique(self) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Unique` node."""
    function = F.Unique()
    return self._from_ir(function.to_function_expr(self._ir))
302+
303+
def round(self, decimals: int = 0) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.Round` node.

    Arguments:
        decimals: Stored on the node as `decimals`.
    """
    function = F.Round(decimals=decimals)
    return self._from_ir(function.to_function_expr(self._ir))
305+
306+
def ewm_mean(
    self,
    *,
    com: float | None = None,
    span: float | None = None,
    half_life: float | None = None,
    alpha: float | None = None,
    adjust: bool = True,
    min_samples: int = 1,
    ignore_nulls: bool = False,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.EwmMean` node.

    Every keyword argument is forwarded unchanged, under the same name, to
    `EWMOptions`; no validation happens in this method.
    """
    function = F.EwmMean(
        options=EWMOptions(
            com=com,
            span=span,
            half_life=half_life,
            alpha=alpha,
            adjust=adjust,
            min_samples=min_samples,
            ignore_nulls=ignore_nulls,
        )
    )
    return self._from_ir(function.to_function_expr(self._ir))
327+
328+
def replace_strict(
    self,
    old: t.Sequence[t.Any] | t.Mapping[t.Any, t.Any],
    new: t.Sequence[t.Any] | None = None,
    *,
    return_dtype: DType | type[DType] | None = None,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.ReplaceStrict` node.

    Arguments:
        old: Either a mapping of old to new values, or a sequence of old values.
        new: Sequence of new values. Required when `old` is a sequence and
            not allowed when `old` is a mapping.
        return_dtype: Stored on the node as `return_dtype`.

    Raises:
        TypeError: If `old` is a mapping and `new` is given, or if `old` is
            not a mapping and `new` is omitted.

    Notes:
        Both sides are stored on the node as tuples: a mapping is split into
        its keys (`old`) and its values (`new`).
    """
    before: Seq[t.Any]
    after: Seq[t.Any]
    if isinstance(old, t.Mapping):
        if new is not None:
            # NOTE: polars raises later when this occurs
            # TypeError: cannot create expression literal for value of type dict.
            # Hint: Pass `allow_object=True` to accept any value and create a literal of type Object.
            msg = "`new` argument cannot be used if `old` argument is a Mapping type"
            raise TypeError(msg)
        # Iterating a mapping yields its keys.
        before, after = tuple(old), tuple(old.values())
    elif new is None:
        msg = "`new` argument is required if `old` argument is not a Mapping type"
        raise TypeError(msg)
    else:
        before, after = tuple(old), tuple(new)
    node = F.ReplaceStrict(old=before, new=after, return_dtype=return_dtype)
    return self._from_ir(node.to_function_expr(self._ir))
354+
355+
def gather_every(self, n: int, offset: int = 0) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.GatherEvery` node.

    Arguments:
        n: Stored on the node as `n`.
        offset: Stored on the node as `offset`.
    """
    function = F.GatherEvery(n=n, offset=offset)
    return self._from_ir(function.to_function_expr(self._ir))
357+
358+
def map_batches(
    self,
    function: Udf,
    return_dtype: DType | None = None,
    *,
    is_elementwise: bool = False,
    returns_scalar: bool = False,
) -> Self:
    """Return a new expression that wraps this expression's IR in an `F.MapBatches` node.

    Arguments:
        function: The user-defined function, stored on the node as `function`.
        return_dtype: Stored on the node as `return_dtype`.
        is_elementwise: Stored on the node as `is_elementwise`.
        returns_scalar: Stored on the node as `returns_scalar`.

    Notes:
        `function` is only stored here; this method does not call it.
    """
    node = F.MapBatches(
        function=function,
        return_dtype=return_dtype,
        is_elementwise=is_elementwise,
        returns_scalar=returns_scalar,
    )
    return self._from_ir(node.to_function_expr(self._ir))
374+
130375
def __eq__(self, other: DummyExpr) -> Self:  # type: ignore[override]
    """Return a new expression for `self == other`, built as an `ops.Eq` binary node.

    Notes:
        `other._ir` is read directly, so `other` must already be an expression.
    """
    return self._from_ir(ops.Eq().to_binary_expr(self._ir, other._ir))
@@ -186,6 +431,11 @@ def __or__(self, other: DummyExpr) -> Self:
186431
def __invert__(self) -> Self:
    """Return a new expression for `~self`, built as a `boolean.Not` function node."""
    function = boolean.Not()
    return self._from_ir(function.to_function_expr(self._ir))
188433

434+
def __pow__(self, other: IntoExpr) -> Self:
    """Return a new expression for `self ** other`, built as an `F.Pow` function node.

    Arguments:
        other: The exponent, parsed with
            `parse.parse_into_expr_ir(other, str_as_lit=True)`.

    Notes:
        The node's inputs are ordered (base, exponent), where the base is this
        expression's IR.
    """
    rhs = parse.parse_into_expr_ir(other, str_as_lit=True)
    return self._from_ir(F.Pow().to_function_expr(self._ir, rhs))
438+
189439

190440
class DummyExprV1(DummyExpr):
191441
_version: t.ClassVar[Version] = Version.V1

narwhals/_plan/functions.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
from narwhals._plan.common import Function
1313
from narwhals._plan.options import FunctionFlags
1414
from narwhals._plan.options import FunctionOptions
15+
from narwhals.exceptions import ComputeError
1516

1617
if TYPE_CHECKING:
18+
from typing import Any
19+
1720
from narwhals._plan.common import Seq
1821
from narwhals._plan.common import Udf
1922
from narwhals._plan.options import EWMOptions
@@ -44,15 +47,28 @@ def function_options(self) -> FunctionOptions:
4447
class HistBins(Hist):
    """`Hist` variant that is configured with explicit bin edges."""

    __slots__ = ("bins", *Hist.__slots__)

    bins: Seq[float]

    def __init__(self, *, bins: Seq[float], include_breakpoint: bool = True) -> None:
        """Validate `bins` and store both fields.

        Arguments:
            bins: Bin edges. Each edge must be strictly greater than the one
                before it; equal neighbours are rejected as well.
            include_breakpoint: Stored as `include_breakpoint`.

        Raises:
            ComputeError: If any edge is not strictly greater than its predecessor.
        """
        # Pair every edge with its successor; zero or one edge yields no pairs.
        if any(left >= right for left, right in zip(bins, bins[1:])):
            msg = "bins must increase monotonically"
            raise ComputeError(msg)
        # NOTE(review): `object.__setattr__` bypasses any `__setattr__` defined on
        # the class hierarchy - presumably the nodes are immutable; confirm.
        object.__setattr__(self, "bins", bins)
        object.__setattr__(self, "include_breakpoint", include_breakpoint)
61+
5162

5263
class HistBinCount(Hist):
    """`Hist` variant that is configured with a number of bins."""

    __slots__ = ("bin_count", *Hist.__slots__)

    bin_count: int
    """Polars (v1.20) sets `bin_count=10` if neither `bins` or `bin_count` are provided."""

    def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> None:
        """Store both fields.

        Arguments:
            bin_count: Stored as `bin_count`; defaults to 10.
            include_breakpoint: Stored as `include_breakpoint`.
        """
        # NOTE(review): `object.__setattr__` bypasses any `__setattr__` defined on
        # the class hierarchy - presumably the nodes are immutable; confirm.
        fields = (("bin_count", bin_count), ("include_breakpoint", include_breakpoint))
        for name, value in fields:
            object.__setattr__(self, name, value)
5672

5773

5874
class NullCount(Function):
@@ -267,8 +283,10 @@ def function_options(self) -> FunctionOptions:
267283

268284

269285
class ReplaceStrict(Function):
270-
__slots__ = ("return_dtype",)
286+
__slots__ = ("new", "old", "return_dtype")
271287

288+
old: Seq[Any]
289+
new: Seq[Any]
272290
return_dtype: DType | type[DType] | None
273291

274292
@property
@@ -277,6 +295,11 @@ def function_options(self) -> FunctionOptions:
277295

278296

279297
class GatherEvery(Function):
    """Function node with two fields, `n` and `offset`."""

    __slots__ = ("n", "offset")

    n: int
    offset: int

    @property
    def function_options(self) -> FunctionOptions:
        """Return the options produced by `FunctionOptions.groupwise()`."""
        return FunctionOptions.groupwise()

0 commit comments

Comments
 (0)