Commit 8be198e

feat: Support over expressions more freely, make expressions printable, rewrite internals (travelling pr 🌴 ) (#3152)
Co-authored-by: FBruzzesi <[email protected]>
Co-authored-by: Francesco Bruzzesi <[email protected]>
1 parent 9240530 commit 8be198e

83 files changed: +2225 -2704 lines changed

docs/how_it_works.md

Lines changed: 135 additions & 20 deletions
@@ -78,6 +78,7 @@ pn = PandasLikeNamespace(
 )
 print(nw.col("a")._to_compliant_expr(pn))
 ```
+
 The result from the last line above is the same as we'd get from `pn.col('a')`, and it's
 a `narwhals._pandas_like.expr.PandasLikeExpr` object, which we'll call `PandasLikeExpr` for
 short.
@@ -215,6 +216,7 @@ pn = PandasLikeNamespace(
 expr = (nw.col("a") + 1)._to_compliant_expr(pn)
 print(expr)
 ```
+
 If we then extract a Narwhals-compliant dataframe from `df` by
 calling `._compliant_frame`, we get a `PandasLikeDataFrame` - and that's an object which we can pass `expr` to!

@@ -228,6 +230,7 @@ We can then view the underlying pandas Dataframe which was produced by calling `
 ```python exec="1" result="python" session="pandas_api_mapping" source="above"
 print(result._native_frame)
 ```
+
 which is the same as we'd have obtained by just using the Narwhals API directly:

 ```python exec="1" result="python" session="pandas_api_mapping" source="above"
@@ -238,49 +241,98 @@ print(nw.to_native(df.select(nw.col("a") + 1)))
 Group-by is probably one of Polars' most significant innovations (on the syntax side) with respect
 to pandas. We can write something like
+
 ```python
 df: pl.DataFrame
 df.group_by("a").agg((pl.col("c") > pl.col("b").mean()).max())
 ```
+
 To do this in pandas, we need to either use `GroupBy.apply` (sloooow), or do some crazy manual
 optimisations to get it to work.

 In Narwhals, here's what we do:

 - if somebody uses a simple group-by aggregation (e.g. `df.group_by('a').agg(nw.col('b').mean())`),
 then on the pandas side we translate it to
-```python
-df: pd.DataFrame
-df.groupby("a").agg({"b": ["mean"]})
-```
+
+```python
+df: pd.DataFrame
+df.groupby("a").agg({"b": ["mean"]})
+```
+
 - if somebody passes a complex group-by aggregation, then we use `apply` and raise a `UserWarning`, warning
 users of the performance penalty and advising them to refactor their code so that the aggregation they perform
 ends up being a simple one.

-In order to tell whether an aggregation is simple, Narwhals uses the private `_depth` attribute of `PandasLikeExpr`:
+## Nodes
+
+If we have a Narwhals expression, we can look at the operations which make it up by accessing `_nodes`:
+
+```python exec="1" result="python" session="pandas_impl" source="above"
+import narwhals as nw
+
+expr = nw.col("a").abs().std(ddof=1) + nw.col("b")
+print(expr._nodes)
+```
+
+Each node represents an operation. Here, we have 4 operations:
+
+1. Given some dataframe, select column `'a'`.
+2. Take its absolute value.
+3. Take its standard deviation, with `ddof=1`.
+4. Add column `'b'`.
+
+Let's take a look at a couple of these nodes. Let's start with the third one:
+
+```python exec="1" result="python" session="pandas_impl" source="above"
+print(expr._nodes[2].as_dict())
+```
+
+This tells us a few things:
+
+- We're performing an aggregation.
+- The name of the function is `'std'`. This will be looked up in the compliant object.
+- It takes the keyword argument `ddof=1`.
+- We'll look at `exprs`, `str_as_lit`, and `allow_multi_output` later.
+
+In order for the evaluation to succeed, `PandasLikeExpr` must have a `std` method defined
+on it which takes a `ddof` argument. This is what the `CompliantExpr` Protocol is for: so
+long as a backend's implementation complies with the protocol, Narwhals will be able to
+unpack an `ExprNode` and turn it into a valid call.
+
+Let's take a look at the fourth node:
+
+```python exec="1" result="python" session="pandas_impl" source="above"
+print(expr._nodes[3].as_dict())
+```
+
+Note how the `exprs` attribute is now populated. Indeed, we are adding another expression: `col('b')`.
+The `exprs` parameter holds arguments which are either expressions, or should be interpreted as expressions.
+The `str_as_lit` parameter tells us whether string literals should be interpreted as literals (e.g. `lit('foo')`)
+or columns (e.g. `col('foo')`). Finally, `allow_multi_output` tells us whether multi-output expressions
+(more on this in the next section) are allowed to appear in `exprs`.
+
+Note that the expression in `exprs` also has its own nodes:

 ```python exec="1" result="python" session="pandas_impl" source="above"
-print(pn.col("a").mean())
-print((pn.col("a") + 1).mean())
+print(expr._nodes[3].exprs[0]._nodes)
 ```

-For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out
-which (efficient) elementary operation this corresponds to in pandas.
+It's nodes all the way down!

 ## Expression Metadata

-Let's try printing out a few expressions to the console to see what they show us:
+Let's try printing out some compliant expressions' metadata to see what it shows us:

-```python exec="1" result="python" session="metadata" source="above"
+```python exec="1" result="python" session="pandas_impl" source="above"
 import narwhals as nw

-print(nw.col("a"))
-print(nw.col("a").mean())
-print(nw.col("a").mean().over("b"))
+print(nw.col("a")._to_compliant_expr(pn)._metadata)
+print(nw.col("a").mean()._to_compliant_expr(pn)._metadata)
+print(nw.col("a").mean().over("b")._to_compliant_expr(pn)._metadata)
 ```

-Note how they tell us something about their metadata. This section is all about
-making sense of what that all means, what the rules are, and what it enables.
+This section is all about making sense of what that all means, what the rules are, and what it enables.

 Here's a brief description of each piece of metadata:
@@ -293,8 +345,6 @@ Here's a brief description of each piece of metadata:
 - `ExpansionKind.MULTI_UNNAMED`: Produces multiple outputs whose names depend
 on the input dataframe. For example, `nw.nth(0, 1)` or `nw.selectors.numeric()`.

-- `last_node`: Kind of the last operation in the expression. See
-`narwhals._expression_parsing.ExprKind` for the various options.
 - `has_windows`: Whether the expression already contains an `over(...)` statement.
 - `n_orderable_ops`: How many order-dependent operations the expression contains.
@@ -311,8 +361,9 @@ Here's a brief description of each piece of metadata:
 - `is_scalar_like`: Whether the output of the expression is always length-1.
 - `is_literal`: Whether the expression doesn't depend on any column but instead
 only on literal values, like `nw.lit(1)`.
+- `nodes`: List of operations which this expression applies when evaluated.

-#### Chaining
+### Chaining

 Say we have `expr.expr_method()`. How does `expr`'s `ExprMetadata` change?
 This depends on `expr_method`. Details can be found in `narwhals/_expression_parsing`,
@@ -356,7 +407,7 @@ is:
 then `n_orderable_ops` is decreased by 1. This is the only way that
 `n_orderable_ops` can decrease.

-### Broadcasting
+## Broadcasting

 When performing comparisons between columns and aggregations or scalars, we operate as if the
 aggregation or scalar was broadcasted to the length of the whole column. For example, if we
@@ -377,3 +428,67 @@ Narwhals triggers a broadcast in these situations:
 Each backend is then responsible for doing its own broadcasting, as defined in each
 `CompliantExpr.broadcast` method.
+
+## Elementwise push-down
+
+SQL is picky about `over` operations. For example:
+
+- `sum(a) over (partition by b)` is valid.
+- `sum(abs(a)) over (partition by b)` is valid.
+- `abs(sum(a)) over (partition by b)` is not valid.
+
+In Polars, however, all three of the following are valid:
+
+- `pl.col('a').sum().over('b')`
+- `pl.col('a').abs().sum().over('b')`
+- `pl.col('a').sum().abs().over('b')`
+
+How can we retain Polars' level of flexibility when translating to SQL engines?
+
+The answer is: by rewriting expressions. Specifically, we push down `over` nodes past elementwise ones.
+To see this, let's try printing the Narwhals equivalent of the last expression above (the one that SQL rejects):
+
+```python exec="1" result="python" session="pushdown" source="above"
+import narwhals as nw
+
+print(nw.col("a").sum().abs().over("b"))
+```
+
+Note how Narwhals automatically inserted the `over` operation _before_ the `abs` one. In other words, instead
+of doing
+
+- `sum` -> `abs` -> `over`
+
+it did
+
+- `sum` -> `over` -> `abs`
+
+thus allowing the expression to be valid for SQL engines!
+
+This is what we refer to as "pushing down `over` nodes". The idea is:
+
+- Elementwise operations operate row-by-row and don't depend on the rows around them.
+- An `over` node partitions or orders a computation.
+- Therefore, an elementwise operation followed by an `over` operation is the same
+as doing the `over` operation followed by that same elementwise operation!
+
+Note that the pushdown also applies to any arguments to the elementwise operation.
+For example, if we have
+
+```python
+(nw.col("a").sum() + nw.col("b").sum()).over("c")
+```
+
+then `+` is an elementwise operation and so can be swapped with `over`. We just need
+to take care to apply the `over` operation to all the arguments of `+`, so that we
+end up with
+
+```python
+nw.col("a").sum().over("c") + nw.col("b").sum().over("c")
+```
+
+!!! info
+    In general, query optimisation is out-of-scope for Narwhals. We consider this
+    expression rewrite acceptable because:
+
+    - It's simple.
+    - It allows us to evaluate operations which otherwise wouldn't be allowed for certain backends.
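
As an aside on the `CompliantExpr` idea described in the docs diff above (a node's name is looked up as a method on the compliant object and called with the stored keyword arguments), the dispatch can be pictured as follows. This is only an illustrative sketch: `Node`, `SupportsStd`, and `apply_nodes` are hypothetical names, not Narwhals' actual internals.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Protocol


@dataclass
class Node:
    # Hypothetical stand-in for an expression node: a method name plus kwargs.
    name: str
    kwargs: dict[str, Any] = field(default_factory=dict)


class SupportsStd(Protocol):
    # Sketch of the protocol idea: a compliant expression must expose the
    # methods that node names refer to, with matching signatures.
    def abs(self) -> SupportsStd: ...
    def std(self, ddof: int) -> SupportsStd: ...


def apply_nodes(expr: SupportsStd, nodes: list[Node]) -> SupportsStd:
    # Look each node's name up on the compliant object and call it with the
    # node's stored keyword arguments.
    for node in nodes:
        expr = getattr(expr, node.name)(**node.kwargs)
    return expr
```

So long as a backend object exposes `abs` and `std` with these signatures, the same node list can drive any backend.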

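The elementwise push-down can likewise be pictured as a small rewrite over such a node list. Again a sketch only, under the assumption that each node records whether it is elementwise; `push_down_over` is a hypothetical helper, not the function Narwhals uses (which also rewrites the arguments held in `exprs`).

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical node: operation name plus whether it is elementwise.
    name: str
    is_elementwise: bool = False


def push_down_over(nodes: list[Node], over: Node) -> list[Node]:
    # Walk back over trailing elementwise nodes and insert `over` before them,
    # so e.g. [col, sum, abs] + over becomes [col, sum, over, abs].
    insert_at = len(nodes)
    while insert_at > 0 and nodes[insert_at - 1].is_elementwise:
        insert_at -= 1
    return nodes[:insert_at] + [over] + nodes[insert_at:]


nodes = [Node("col"), Node("sum"), Node("abs", is_elementwise=True)]
print([n.name for n in push_down_over(nodes, Node("over"))])
# ['col', 'sum', 'over', 'abs']
```
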
narwhals/_arrow/dataframe.py

Lines changed: 5 additions & 9 deletions
@@ -9,7 +9,6 @@
 from narwhals._arrow.series import ArrowSeries
 from narwhals._arrow.utils import concat_tables, native_to_narwhals_dtype, repeat
 from narwhals._compliant import EagerDataFrame
-from narwhals._expression_parsing import ExprKind
 from narwhals._utils import (
     Implementation,
     Version,
@@ -330,7 +329,7 @@ def simple_select(self, *column_names: str) -> Self:
         )

     def select(self, *exprs: ArrowExpr) -> Self:
-        new_series = self._evaluate_into_exprs(*exprs)
+        new_series = self._evaluate_exprs(*exprs)
         if not new_series:
             # return empty dataframe, like Polars does
             return self._with_native(
@@ -357,7 +356,7 @@ def with_columns(self, *exprs: ArrowExpr) -> Self:
         # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame)
         # All `pyarrow` data is immutable, so this is fine
         native_frame = self.native
-        new_columns = self._evaluate_into_exprs(*exprs)
+        new_columns = self._evaluate_exprs(*exprs)
         columns = self.columns

         for col_value in new_columns:
@@ -402,12 +401,10 @@ def join(
             )

         return self._with_native(
-            self.with_columns(
-                plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
-            )
+            self.with_columns(plx.lit(0, None).alias(key_token).broadcast())
             .native.join(
                 other.with_columns(
-                    plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
+                    plx.lit(0, None).alias(key_token).broadcast()
                 ).native,
                 keys=key_token,
                 right_keys=key_token,
@@ -517,8 +514,7 @@ def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
         return self.select(row_index, plx.all())

     def filter(self, predicate: ArrowExpr) -> Self:
-        # `[0]` is safe as the predicate's expression only returns a single column
-        mask_native = self._evaluate_into_exprs(predicate)[0].native
+        mask_native = self._evaluate_single_output_expr(predicate).native
         return self._with_native(
             self.native.filter(mask_native), validate_column_names=False
         )

narwhals/_arrow/expr.py

Lines changed: 3 additions & 19 deletions
@@ -21,8 +21,7 @@

     from narwhals._arrow.dataframe import ArrowDataFrame
     from narwhals._arrow.namespace import ArrowNamespace
-    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
-    from narwhals._expression_parsing import ExprMetadata
+    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries
     from narwhals._utils import Version, _LimitedContext


@@ -33,23 +32,15 @@ def __init__(
         self,
         call: EvalSeries[ArrowDataFrame, ArrowSeries],
         *,
-        depth: int,
-        function_name: str,
         evaluate_output_names: EvalNames[ArrowDataFrame],
         alias_output_names: AliasNames | None,
         version: Version,
-        scalar_kwargs: ScalarKwargs | None = None,
-        implementation: Implementation | None = None,
+        implementation: Implementation = Implementation.PYARROW,
     ) -> None:
         self._call = call
-        self._depth = depth
-        self._function_name = function_name
-        self._depth = depth
         self._evaluate_output_names = evaluate_output_names
         self._alias_output_names = alias_output_names
         self._version = version
-        self._scalar_kwargs = scalar_kwargs or {}
-        self._metadata: ExprMetadata | None = None

     @classmethod
     def from_column_names(
@@ -58,7 +49,6 @@ def from_column_names(
         /,
         *,
         context: _LimitedContext,
-        function_name: str = "",
     ) -> Self:
         def func(df: ArrowDataFrame) -> list[ArrowSeries]:
             try:
@@ -75,8 +65,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:

         return cls(
             func,
-            depth=0,
-            function_name=function_name,
             evaluate_output_names=evaluate_column_names,
             alias_output_names=None,
             version=context._version,
@@ -94,8 +82,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:

         return cls(
             func,
-            depth=0,
-            function_name="nth",
             evaluate_output_names=cls._eval_names_indices(column_indices),
             alias_output_names=None,
             version=context._version,
@@ -113,7 +99,7 @@ def _reuse_series_extra_kwargs(

     def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
         meta = self._metadata
-        if partition_by and meta is not None and not meta.is_scalar_like:
+        if partition_by and not meta.is_scalar_like:
             msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
             raise NotImplementedError(msg)

@@ -167,8 +153,6 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:

         return self.__class__(
             func,
-            depth=self._depth + 1,
-            function_name=self._function_name + "->over",
             evaluate_output_names=self._evaluate_output_names,
             alias_output_names=self._alias_output_names,
             version=self._version,

narwhals/_arrow/group_by.py

Lines changed: 6 additions & 4 deletions
@@ -76,8 +76,9 @@ def _configure_agg(
     ) -> tuple[pa.TableGroupBy, Aggregation, AggregateOptions | None]:
         option: AggregateOptions | None = None
         function_name = self._leaf_name(expr)
+        kwargs = self._kwargs(expr)
         if function_name in self._OPTION_VARIANCE:
-            ddof = expr._scalar_kwargs.get("ddof", 1)
+            ddof = kwargs["ddof"]
             option = pc.VarianceOptions(ddof=ddof)
         elif function_name in self._OPTION_COUNT_ALL:
             option = pc.CountOptions(mode="all")
@@ -128,10 +129,11 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
             output_names, aliases = evaluate_output_names_and_aliases(
                 expr, self.compliant, exclude
             )
-
-            if expr._depth == 0:
+            md = expr._metadata
+            op_nodes_reversed = list(md.op_nodes_reversed())
+            if len(op_nodes_reversed) == 1:
                 # e.g. `agg(nw.len())`
-                if expr._function_name != "len":  # pragma: no cover
+                if op_nodes_reversed[0].name != "len":  # pragma: no cover
                     msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                     raise AssertionError(msg)

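For context on the `agg` changes above: instead of the old `_depth` counter, the group-by path now inspects the expression's node list via its metadata to decide whether an aggregation is simple enough to map onto a native pyarrow aggregation. A hypothetical version of that kind of check (an illustration only, not Narwhals' actual rule) might look like this:

```python
from __future__ import annotations

# Assumed set of aggregations with a direct native equivalent, for illustration.
SIMPLE_AGGREGATIONS = {"mean", "sum", "std", "var", "len"}


def is_simple_aggregation(node_names: list[str]) -> bool:
    # A chain like ["col", "mean"] maps directly onto a native group-by
    # aggregation; anything longer (e.g. ["col", "abs", "std"]) does not.
    return (
        len(node_names) == 2
        and node_names[0] == "col"
        and node_names[1] in SIMPLE_AGGREGATIONS
    )


print(is_simple_aggregation(["col", "mean"]))        # True
print(is_simple_aggregation(["col", "abs", "std"]))  # False
```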