Skip to content

Commit ff493ff

Browse files
chore: Rename call_kwargs to scalar_kwargs (#2555)
--------- Co-authored-by: Dan Redding <[email protected]>
1 parent 0ad881b commit ff493ff

File tree

10 files changed

+111
-68
lines changed

10 files changed

+111
-68
lines changed

narwhals/_arrow/expr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from narwhals._compliant.typing import AliasNames
2323
from narwhals._compliant.typing import EvalNames
2424
from narwhals._compliant.typing import EvalSeries
25+
from narwhals._compliant.typing import ScalarKwargs
2526
from narwhals._expression_parsing import ExprMetadata
2627
from narwhals.typing import RankMethod
2728
from narwhals.utils import Version
@@ -41,7 +42,7 @@ def __init__(
4142
alias_output_names: AliasNames | None,
4243
backend_version: tuple[int, ...],
4344
version: Version,
44-
call_kwargs: dict[str, Any] | None = None,
45+
scalar_kwargs: ScalarKwargs | None = None,
4546
implementation: Implementation | None = None,
4647
) -> None:
4748
self._call = call
@@ -52,7 +53,7 @@ def __init__(
5253
self._alias_output_names = alias_output_names
5354
self._backend_version = backend_version
5455
self._version = version
55-
self._call_kwargs = call_kwargs or {}
56+
self._scalar_kwargs = scalar_kwargs or {}
5657
self._metadata: ExprMetadata | None = None
5758

5859
@classmethod

narwhals/_arrow/group_by.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
8585

8686
function_name = self._leaf_name(expr)
8787
if function_name in {"std", "var"}:
88-
option: Any = pc.VarianceOptions(ddof=expr._call_kwargs["ddof"])
88+
assert "ddof" in expr._scalar_kwargs # noqa: S101
89+
option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
8990
elif function_name in {"len", "n_unique"}:
9091
option = pc.CountOptions(mode="all")
9192
elif function_name == "count":

narwhals/_compliant/expr.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from narwhals._compliant.typing import AliasNames
6060
from narwhals._compliant.typing import EvalNames
6161
from narwhals._compliant.typing import EvalSeries
62+
from narwhals._compliant.typing import ScalarKwargs
6263
from narwhals._expression_parsing import ExprKind
6364
from narwhals._expression_parsing import ExprMetadata
6465
from narwhals.dtypes import DType
@@ -342,7 +343,7 @@ class EagerExpr(
342343
Protocol38[EagerDataFrameT, EagerSeriesT],
343344
):
344345
_call: EvalSeries[EagerDataFrameT, EagerSeriesT]
345-
_call_kwargs: dict[str, Any]
346+
_scalar_kwargs: ScalarKwargs
346347

347348
def __init__(
348349
self,
@@ -355,7 +356,7 @@ def __init__(
355356
implementation: Implementation,
356357
backend_version: tuple[int, ...],
357358
version: Version,
358-
call_kwargs: dict[str, Any] | None = None,
359+
scalar_kwargs: ScalarKwargs | None = None,
359360
) -> None: ...
360361

361362
def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]:
@@ -376,7 +377,7 @@ def _from_callable(
376377
evaluate_output_names: EvalNames[EagerDataFrameT],
377378
alias_output_names: AliasNames | None,
378379
context: _FullContext,
379-
call_kwargs: dict[str, Any] | None = None,
380+
scalar_kwargs: ScalarKwargs | None = None,
380381
) -> Self:
381382
return cls(
382383
func,
@@ -387,7 +388,7 @@ def _from_callable(
387388
implementation=context._implementation,
388389
backend_version=context._backend_version,
389390
version=context._version,
390-
call_kwargs=call_kwargs,
391+
scalar_kwargs=scalar_kwargs,
391392
)
392393

393394
@classmethod
@@ -408,7 +409,7 @@ def _reuse_series(
408409
method_name: str,
409410
*,
410411
returns_scalar: bool = False,
411-
call_kwargs: dict[str, Any] | None = None,
412+
scalar_kwargs: ScalarKwargs | None = None,
412413
**expressifiable_args: Any,
413414
) -> Self:
414415
"""Reuse Series implementation for expression.
@@ -420,7 +421,7 @@ def _reuse_series(
420421
method_name: name of method.
421422
returns_scalar: whether the Series version returns a scalar. In this case,
422423
the expression version should return a 1-row Series.
423-
call_kwargs: non-expressifiable args which we may need to reuse in `agg` or `over`,
424+
scalar_kwargs: non-expressifiable args which we may need to reuse in `agg` or `over`,
424425
such as `ddof` for `std` and `var`.
425426
expressifiable_args: keyword arguments to pass to function, which may
426427
be expressifiable (e.g. `nw.col('a').is_between(3, nw.col('b')))`).
@@ -429,7 +430,7 @@ def _reuse_series(
429430
self._reuse_series_inner,
430431
method_name=method_name,
431432
returns_scalar=returns_scalar,
432-
call_kwargs=call_kwargs or {},
433+
scalar_kwargs=scalar_kwargs or {},
433434
expressifiable_args=expressifiable_args,
434435
)
435436
return self._from_callable(
@@ -438,7 +439,7 @@ def _reuse_series(
438439
function_name=f"{self._function_name}->{method_name}",
439440
evaluate_output_names=self._evaluate_output_names,
440441
alias_output_names=self._alias_output_names,
441-
call_kwargs=call_kwargs,
442+
scalar_kwargs=scalar_kwargs,
442443
context=self,
443444
)
444445

@@ -459,11 +460,11 @@ def _reuse_series_inner(
459460
*,
460461
method_name: str,
461462
returns_scalar: bool,
462-
call_kwargs: dict[str, Any],
463+
scalar_kwargs: ScalarKwargs,
463464
expressifiable_args: dict[str, Any],
464465
) -> Sequence[EagerSeriesT]:
465466
kwargs = {
466-
**call_kwargs,
467+
**scalar_kwargs,
467468
**{
468469
name: df._evaluate_expr(value) if self._is_expr(value) else value
469470
for name, value in expressifiable_args.items()
@@ -513,7 +514,7 @@ def _reuse_series_namespace(
513514
function_name=f"{self._function_name}->{series_namespace}.{method_name}",
514515
evaluate_output_names=self._evaluate_output_names,
515516
alias_output_names=self._alias_output_names,
516-
call_kwargs=self._call_kwargs,
517+
scalar_kwargs=self._scalar_kwargs,
517518
context=self,
518519
)
519520

@@ -537,7 +538,7 @@ def func(df: EagerDataFrameT) -> list[EagerSeriesT]:
537538
backend_version=self._backend_version,
538539
implementation=self._implementation,
539540
version=self._version,
540-
call_kwargs=self._call_kwargs,
541+
scalar_kwargs=self._scalar_kwargs,
541542
)
542543

543544
def cast(self, dtype: DType | type[DType]) -> Self:
@@ -627,10 +628,14 @@ def median(self) -> Self:
627628
return self._reuse_series("median", returns_scalar=True)
628629

629630
def std(self, *, ddof: int) -> Self:
630-
return self._reuse_series("std", returns_scalar=True, call_kwargs={"ddof": ddof})
631+
return self._reuse_series(
632+
"std", returns_scalar=True, scalar_kwargs={"ddof": ddof}
633+
)
631634

632635
def var(self, *, ddof: int) -> Self:
633-
return self._reuse_series("var", returns_scalar=True, call_kwargs={"ddof": ddof})
636+
return self._reuse_series(
637+
"var", returns_scalar=True, scalar_kwargs={"ddof": ddof}
638+
)
634639

635640
def skew(self) -> Self:
636641
return self._reuse_series("skew", returns_scalar=True)
@@ -747,7 +752,7 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:
747752
backend_version=self._backend_version,
748753
implementation=self._implementation,
749754
version=self._version,
750-
call_kwargs=self._call_kwargs,
755+
scalar_kwargs=self._scalar_kwargs,
751756
)
752757

753758
def is_unique(self) -> Self:
@@ -1087,7 +1092,7 @@ def _from_callable(self, func: AliasName, /, *, alias: bool = True) -> EagerExpr
10871092
backend_version=expr._backend_version,
10881093
implementation=expr._implementation,
10891094
version=expr._version,
1090-
call_kwargs=expr._call_kwargs,
1095+
scalar_kwargs=expr._scalar_kwargs,
10911096
)
10921097

10931098

narwhals/_compliant/selectors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import re
66
from functools import partial
77
from typing import TYPE_CHECKING
8-
from typing import Any
98
from typing import Collection
109
from typing import Iterable
1110
from typing import Iterator
@@ -49,6 +48,7 @@
4948
from narwhals._compliant.typing import CompliantSeriesOrNativeExprAny
5049
from narwhals._compliant.typing import EvalNames
5150
from narwhals._compliant.typing import EvalSeries
51+
from narwhals._compliant.typing import ScalarKwargs
5252
from narwhals.dtypes import DType
5353
from narwhals.typing import TimeUnit
5454
from narwhals.utils import Implementation
@@ -226,7 +226,7 @@ class CompliantSelector(
226226
_implementation: Implementation
227227
_backend_version: tuple[int, ...]
228228
_version: Version
229-
_call_kwargs: dict[str, Any]
229+
_scalar_kwargs: ScalarKwargs
230230

231231
@classmethod
232232
def from_callables(
@@ -245,7 +245,7 @@ def from_callables(
245245
obj._implementation = context._implementation
246246
obj._backend_version = context._backend_version
247247
obj._version = context._version
248-
obj._call_kwargs = {}
248+
obj._scalar_kwargs = {}
249249
return obj
250250

251251
@property

narwhals/_compliant/typing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55
from typing import Callable
66
from typing import Sequence
7+
from typing import TypedDict
78
from typing import TypeVar
89

910
if TYPE_CHECKING:
@@ -21,8 +22,25 @@
2122
from narwhals._compliant.namespace import EagerNamespace
2223
from narwhals._compliant.series import CompliantSeries
2324
from narwhals._compliant.series import EagerSeries
25+
from narwhals.typing import FillNullStrategy
2426
from narwhals.typing import NativeFrame
2527
from narwhals.typing import NativeSeries
28+
from narwhals.typing import RankMethod
29+
30+
class ScalarKwargs(TypedDict, total=False):
31+
"""Non-expressifiable args which we may need to reuse in `agg` or `over`."""
32+
33+
center: int
34+
ddof: int
35+
descending: bool
36+
limit: int | None
37+
method: RankMethod
38+
min_samples: int
39+
n: int
40+
reverse: bool
41+
strategy: FillNullStrategy | None
42+
window_size: int
43+
2644

2745
__all__ = [
2846
"AliasName",

narwhals/_compliant/when_then.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing_extensions import TypeAlias
2626

2727
from narwhals._compliant.typing import EvalSeries
28+
from narwhals._compliant.typing import ScalarKwargs
2829
from narwhals.typing import NonNestedLiteral
2930
from narwhals.utils import Implementation
3031
from narwhals.utils import Version
@@ -91,7 +92,7 @@ class CompliantThen(CompliantExpr[FrameT, SeriesT], Protocol38[FrameT, SeriesT,
9192
_implementation: Implementation
9293
_backend_version: tuple[int, ...]
9394
_version: Version
94-
_call_kwargs: dict[str, Any]
95+
_scalar_kwargs: ScalarKwargs
9596

9697
@classmethod
9798
def from_when(
@@ -113,7 +114,7 @@ def from_when(
113114
obj._implementation = when._implementation
114115
obj._backend_version = when._backend_version
115116
obj._version = when._version
116-
obj._call_kwargs = {}
117+
obj._scalar_kwargs = {}
117118
return obj
118119

119120
def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:

narwhals/_dask/expr.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from narwhals._compliant.typing import AliasNames
3131
from narwhals._compliant.typing import EvalNames
3232
from narwhals._compliant.typing import EvalSeries
33+
from narwhals._compliant.typing import ScalarKwargs
3334
from narwhals._dask.dataframe import DaskLazyFrame
3435
from narwhals._dask.namespace import DaskNamespace
3536
from narwhals._expression_parsing import ExprKind
@@ -60,9 +61,7 @@ def __init__(
6061
alias_output_names: AliasNames | None,
6162
backend_version: tuple[int, ...],
6263
version: Version,
63-
# Kwargs with metadata which we may need in group-by agg
64-
# (e.g. `ddof` for `std` and `var`).
65-
call_kwargs: dict[str, Any] | None = None,
64+
scalar_kwargs: ScalarKwargs | None = None,
6665
) -> None:
6766
self._call = call
6867
self._depth = depth
@@ -71,7 +70,7 @@ def __init__(
7170
self._alias_output_names = alias_output_names
7271
self._backend_version = backend_version
7372
self._version = version
74-
self._call_kwargs = call_kwargs or {}
73+
self._scalar_kwargs = scalar_kwargs or {}
7574
self._metadata: ExprMetadata | None = None
7675

7776
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
@@ -97,7 +96,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
9796
alias_output_names=self._alias_output_names,
9897
backend_version=self._backend_version,
9998
version=self._version,
100-
call_kwargs=self._call_kwargs,
99+
scalar_kwargs=self._scalar_kwargs,
101100
)
102101

103102
@classmethod
@@ -155,7 +154,7 @@ def _with_callable(
155154
call: Callable[..., dx.Series],
156155
/,
157156
expr_name: str = "",
158-
call_kwargs: dict[str, Any] | None = None,
157+
scalar_kwargs: ScalarKwargs | None = None,
159158
**expressifiable_args: Self | Any,
160159
) -> Self:
161160
def func(df: DaskLazyFrame) -> list[dx.Series]:
@@ -178,7 +177,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
178177
alias_output_names=self._alias_output_names,
179178
backend_version=self._backend_version,
180179
version=self._version,
181-
call_kwargs=call_kwargs,
180+
scalar_kwargs=scalar_kwargs,
182181
)
183182

184183
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
@@ -190,7 +189,7 @@ def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
190189
alias_output_names=func,
191190
backend_version=self._backend_version,
192191
version=self._version,
193-
call_kwargs=self._call_kwargs,
192+
scalar_kwargs=self._scalar_kwargs,
194193
)
195194

196195
def __add__(self, other: Any) -> Self:
@@ -321,14 +320,14 @@ def std(self, ddof: int) -> Self:
321320
return self._with_callable(
322321
lambda _input: _input.std(ddof=ddof).to_series(),
323322
"std",
324-
call_kwargs={"ddof": ddof},
323+
scalar_kwargs={"ddof": ddof},
325324
)
326325

327326
def var(self, ddof: int) -> Self:
328327
return self._with_callable(
329328
lambda _input: _input.var(ddof=ddof).to_series(),
330329
"var",
331-
call_kwargs={"ddof": ddof},
330+
scalar_kwargs={"ddof": ddof},
332331
)
333332

334333
def skew(self) -> Self:
@@ -626,11 +625,11 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
626625
msg = "Safety check failed, please report a bug."
627626
raise AssertionError(msg)
628627
res_native = grouped.transform(
629-
dask_function_name, **self._call_kwargs
628+
dask_function_name, **self._scalar_kwargs
630629
).to_frame(output_names[0])
631630
else:
632631
res_native = grouped[list(output_names)].transform(
633-
dask_function_name, **self._call_kwargs
632+
dask_function_name, **self._scalar_kwargs
634633
)
635634
result_frame = df._with_native(
636635
res_native.rename(columns=dict(zip(output_names, aliases)))

narwhals/_dask/group_by.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
115115
# e.g. `agg(nw.mean('a'))`
116116
agg_fn = self._remap_expr_name(self._leaf_name(expr))
117117
# deal with n_unique case in a "lazy" mode to not depend on dask globally
118-
agg_fn = agg_fn(**expr._call_kwargs) if callable(agg_fn) else agg_fn
118+
agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
119119
simple_aggregations.update(
120120
(alias, (output_name, agg_fn))
121121
for alias, output_name in zip(aliases, output_names)

0 commit comments

Comments
 (0)