Skip to content

Commit 053390d

Browse files
dangotbanned, FBruzzesi, and MarcoGorelli
authored
feat: Adds {Expr,Series}.{first,last} (#2528)
--------- Co-authored-by: Francesco Bruzzesi <[email protected]> Co-authored-by: FBruzzesi <[email protected]> Co-authored-by: Marco Gorelli <[email protected]>
1 parent 022a3d9 commit 053390d

File tree

25 files changed

+649
-50
lines changed

25 files changed

+649
-50
lines changed

docs/api-reference/expr.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
- fill_nan
2424
- fill_null
2525
- filter
26+
- first
2627
- is_between
2728
- is_close
2829
- is_duplicated
@@ -34,6 +35,7 @@
3435
- is_null
3536
- is_unique
3637
- kurtosis
38+
- last
3739
- len
3840
- log
3941
- map_batches

docs/api-reference/series.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
- fill_nan
3131
- fill_null
3232
- filter
33+
- first
3334
- from_iterable
3435
- from_numpy
3536
- gather_every
@@ -50,6 +51,7 @@
5051
- is_unique
5152
- item
5253
- kurtosis
54+
- last
5355
- len
5456
- log
5557
- max

narwhals/_arrow/expr.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING, Any
44

5+
import pyarrow as pa
56
import pyarrow.compute as pc
67

78
from narwhals._arrow.series import ArrowSeries
@@ -111,11 +112,8 @@ def _reuse_series_extra_kwargs(
111112
return {"_return_py_scalar": False} if returns_scalar else {}
112113

113114
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
114-
if (
115-
partition_by
116-
and self._metadata is not None
117-
and not self._metadata.is_scalar_like
118-
):
115+
meta = self._metadata
116+
if partition_by and meta is not None and not meta.is_scalar_like:
119117
msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
120118
raise NotImplementedError(msg)
121119

@@ -129,15 +127,24 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
129127
df = df.with_row_index(token, order_by=None).sort(
130128
*order_by, descending=False, nulls_last=False
131129
)
132-
result = self(df.drop([token], strict=True))
130+
results = self(df.drop([token], strict=True))
131+
if meta is not None and meta.is_scalar_like:
132+
# We need to broadcast the results to the original size, since
133+
# `over` is a length-preserving operation.
134+
size = len(df)
135+
return [s._with_native(pa.repeat(s.item(), size)) for s in results]
136+
133137
# TODO(marco): is there a way to do this efficiently without
134138
# doing 2 sorts? Here we're sorting the dataframe and then
135139
# again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
136140
sorting_indices = pc.sort_indices(df.get_column(token).native)
137-
return [s._with_native(s.native.take(sorting_indices)) for s in result]
141+
return [s._with_native(s.native.take(sorting_indices)) for s in results]
138142
else:
139143

140144
def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
145+
if order_by:
146+
df = df.sort(*order_by, descending=False, nulls_last=False)
147+
141148
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
142149
if overlap := set(output_names).intersection(partition_by):
143150
# E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,

narwhals/_arrow/group_by.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
1010
from narwhals._compliant import EagerGroupBy
1111
from narwhals._expression_parsing import evaluate_output_names_and_aliases
12-
from narwhals._utils import generate_temporary_column_name
12+
from narwhals._utils import generate_temporary_column_name, requires
1313

1414
if TYPE_CHECKING:
1515
from collections.abc import Iterator, Mapping, Sequence
@@ -39,12 +39,23 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
3939
"count": "count",
4040
"all": "all",
4141
"any": "any",
42+
"first": "first",
43+
"last": "last",
4244
}
4345
_REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
4446
"any": "min",
4547
"first": "min",
4648
"last": "max",
4749
}
50+
_OPTION_COUNT_ALL: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
51+
("len", "n_unique")
52+
)
53+
_OPTION_COUNT_VALID: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("count",))
54+
_OPTION_ORDERED: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
55+
("first", "last")
56+
)
57+
_OPTION_VARIANCE: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("std", "var"))
58+
_OPTION_SCALAR: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("any", "all"))
4859

4960
def __init__(
5061
self,
@@ -60,12 +71,58 @@ def __init__(
6071
self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
6172
self._drop_null_keys = drop_null_keys
6273

74+
def _configure_agg(
75+
self, grouped: pa.TableGroupBy, expr: ArrowExpr, /
76+
) -> tuple[pa.TableGroupBy, Aggregation, AggregateOptions | None]:
77+
option: AggregateOptions | None = None
78+
function_name = self._leaf_name(expr)
79+
if function_name in self._OPTION_VARIANCE:
80+
ddof = expr._scalar_kwargs.get("ddof", 1)
81+
option = pc.VarianceOptions(ddof=ddof)
82+
elif function_name in self._OPTION_COUNT_ALL:
83+
option = pc.CountOptions(mode="all")
84+
elif function_name in self._OPTION_COUNT_VALID:
85+
option = pc.CountOptions(mode="only_valid")
86+
elif function_name in self._OPTION_SCALAR:
87+
option = pc.ScalarAggregateOptions(min_count=0)
88+
elif function_name in self._OPTION_ORDERED:
89+
grouped, option = self._ordered_agg(grouped, function_name)
90+
return grouped, self._remap_expr_name(function_name), option
91+
92+
def _ordered_agg(
93+
self, grouped: pa.TableGroupBy, name: NarwhalsAggregation, /
94+
) -> tuple[pa.TableGroupBy, AggregateOptions]:
95+
"""The default behavior of `pyarrow` raises when `first` or `last` are used.
96+
97+
You'd see an error like:
98+
99+
ArrowNotImplementedError: Using ordered aggregator in multiple threaded execution is not supported
100+
101+
We need to **disable** multi-threading to use them, but the ability to do so
102+
wasn't possible before `14.0.0` ([pyarrow-36709])
103+
104+
[pyarrow-36709]: https://github.com/apache/arrow/issues/36709
105+
"""
106+
backend_version = self.compliant._backend_version
107+
if backend_version >= (14, 0) and grouped._use_threads:
108+
native = self.compliant.native
109+
grouped = pa.TableGroupBy(native, grouped.keys, use_threads=False)
110+
elif backend_version < (14, 0): # pragma: no cover
111+
msg = (
112+
f"Using `{name}()` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', "
113+
f"found version {requires._unparse_version(backend_version)!r}.\n\n"
114+
f"See https://github.com/apache/arrow/issues/36709"
115+
)
116+
raise NotImplementedError(msg)
117+
return grouped, pc.ScalarAggregateOptions(skip_nulls=False)
118+
63119
def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
64120
self._ensure_all_simple(exprs)
65121
aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
66122
expected_pyarrow_column_names: list[str] = self._keys.copy()
67123
new_column_names: list[str] = self._keys.copy()
68124
exclude = (*self._keys, *self._output_key_names)
125+
grouped = self._grouped
69126

70127
for expr in exprs:
71128
output_names, aliases = evaluate_output_names_and_aliases(
@@ -83,20 +140,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
83140
aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
84141
continue
85142

86-
function_name = self._leaf_name(expr)
87-
if function_name in {"std", "var"}:
88-
assert "ddof" in expr._scalar_kwargs # noqa: S101
89-
option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
90-
elif function_name in {"len", "n_unique"}:
91-
option = pc.CountOptions(mode="all")
92-
elif function_name == "count":
93-
option = pc.CountOptions(mode="only_valid")
94-
elif function_name in {"all", "any"}:
95-
option = pc.ScalarAggregateOptions(min_count=0)
96-
else:
97-
option = None
98-
99-
function_name = self._remap_expr_name(function_name)
143+
grouped, function_name, option = self._configure_agg(grouped, expr)
100144
new_column_names.extend(aliases)
101145
expected_pyarrow_column_names.extend(
102146
[f"{output_name}_{function_name}" for output_name in output_names]
@@ -105,7 +149,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
105149
[(output_name, function_name, option) for output_name in output_names]
106150
)
107151

108-
result_simple = self._grouped.aggregate(aggs)
152+
result_simple = grouped.aggregate(aggs)
109153

110154
# Rename columns, being very careful
111155
expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)

narwhals/_arrow/series.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,15 @@ def filter(self, predicate: ArrowSeries | list[bool | None]) -> Self:
330330
other_native = predicate
331331
return self._with_native(self.native.filter(other_native))
332332

333+
def first(self, *, _return_py_scalar: bool = True) -> PythonLiteral:
334+
result = self.native[0] if len(self.native) else None
335+
return maybe_extract_py_scalar(result, _return_py_scalar)
336+
337+
def last(self, *, _return_py_scalar: bool = True) -> PythonLiteral:
338+
ca = self.native
339+
result = ca[height - 1] if (height := len(ca)) else None
340+
return maybe_extract_py_scalar(result, _return_py_scalar)
341+
333342
def mean(self, *, _return_py_scalar: bool = True) -> float:
334343
return maybe_extract_py_scalar(pc.mean(self.native), _return_py_scalar)
335344

narwhals/_compliant/expr.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def max(self) -> Self: ...
124124
def mean(self) -> Self: ...
125125
def sum(self) -> Self: ...
126126
def median(self) -> Self: ...
127+
def first(self) -> Self: ...
128+
def last(self) -> Self: ...
127129
def skew(self) -> Self: ...
128130
def kurtosis(self) -> Self: ...
129131
def std(self, *, ddof: int) -> Self: ...
@@ -867,6 +869,12 @@ def is_close(
867869
nans_equal=nans_equal,
868870
)
869871

872+
def first(self) -> Self:
873+
return self._reuse_series("first", returns_scalar=True)
874+
875+
def last(self) -> Self:
876+
return self._reuse_series("last", returns_scalar=True)
877+
870878
@property
871879
def cat(self) -> EagerExprCatNamespace[Self]:
872880
return EagerExprCatNamespace(self)

narwhals/_compliant/series.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
Into1DArray,
4646
IntoDType,
4747
MultiIndexSelector,
48+
PythonLiteral,
4849
RollingInterpolationMethod,
4950
SizedMultiIndexSelector,
5051
_1DArray,
@@ -131,6 +132,8 @@ def arg_min(self) -> int: ...
131132
def arg_true(self) -> Self: ...
132133
def count(self) -> int: ...
133134
def filter(self, predicate: Any) -> Self: ...
135+
def first(self) -> PythonLiteral: ...
136+
def last(self) -> PythonLiteral: ...
134137
def gather_every(self, n: int, offset: int) -> Self: ...
135138
def head(self, n: int) -> Self: ...
136139
def is_empty(self) -> bool:

narwhals/_compliant/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ class ScalarKwargs(TypedDict, total=False):
194194
"quantile",
195195
"all",
196196
"any",
197+
"first",
198+
"last",
197199
]
198200
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.
199201

narwhals/_dask/expr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,8 @@ def dt(self) -> DaskExprDateTimeNamespace:
729729
return DaskExprDateTimeNamespace(self)
730730

731731
rank = not_implemented()
732+
first = not_implemented()
733+
last = not_implemented()
732734

733735
# namespaces
734736
list: not_implemented = not_implemented() # type: ignore[assignment]

narwhals/_duckdb/expr.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
DeferredTimeZone,
1414
F,
1515
col,
16+
generate_order_by_sql,
1617
lit,
1718
narwhals_to_native_dtype,
19+
sql_expression,
1820
when,
1921
window_expression,
2022
)
@@ -93,6 +95,24 @@ def _window_expression(
9395
nulls_last=nulls_last,
9496
)
9597

98+
def _first(self, expr: Expression, *order_by: str) -> Expression:
99+
# https://github.com/duckdb/duckdb/discussions/19252
100+
order_by_sql = generate_order_by_sql(
101+
*order_by,
102+
descending=[False] * len(order_by),
103+
nulls_last=[False] * len(order_by),
104+
)
105+
return sql_expression(f"first({expr} {order_by_sql})")
106+
107+
def _last(self, expr: Expression, *order_by: str) -> Expression:
108+
# https://github.com/duckdb/duckdb/discussions/19252
109+
order_by_sql = generate_order_by_sql(
110+
*order_by,
111+
descending=[False] * len(order_by),
112+
nulls_last=[False] * len(order_by),
113+
)
114+
return sql_expression(f"last({expr} {order_by_sql})")
115+
96116
def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
97117
from narwhals._duckdb.namespace import DuckDBNamespace
98118

0 commit comments

Comments
 (0)