Skip to content

Commit 0d81b8b

Browse files
committed
Merge remote-tracking branch 'upstream/main' into oh-nodes
2 parents dc7fb95 + 02ad4b1 commit 0d81b8b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1559
-530
lines changed

.github/workflows/downstream_tests.yml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ jobs:
3535
- name: install-altair-dev
3636
run: |
3737
cd altair
38+
# Temporary pin until it's addressed upstream.
39+
uv pip install "pandas<3" --system
3840
uv pip install -e ".[dev, all]" --system
3941
- name: install-narwhals-dev
4042
run: |
@@ -86,6 +88,10 @@ jobs:
8688
run: |
8789
cd marimo
8890
. .venv/bin/activate
91+
# Temporary pin until it's addressed upstream.
92+
uv pip install "pandas<3"
93+
# Temporary pin to get CI green
94+
uv pip install "sqlglot<28.7.0"
8995
uv pip install -e ".[dev]"
9096
which python
9197
- name: install-narwhals-dev
@@ -237,6 +243,8 @@ jobs:
237243
run: |
238244
cd tea-tasting
239245
uv sync --group test
246+
# Temporary pin to get CI green
247+
uv pip install "sqlglot<28.7.0"
240248
- name: install-narwhals-dev
241249
run: |
242250
cd tea-tasting
@@ -282,6 +290,8 @@ jobs:
282290
- name: install-tubular-dev
283291
run: |
284292
cd tubular
293+
# Temporary pin until it's addressed upstream.
294+
uv pip install "pandas<3" --system
285295
uv pip install -e ".[dev]" --system
286296
- name: install-narwhals-dev
287297
run: |
@@ -367,6 +377,8 @@ jobs:
367377
- name: install-deps
368378
run: |
369379
cd hierarchicalforecast
380+
# Temporary pin until it's addressed upstream.
381+
uv pip install "pandas<3" --system
370382
uv pip install . --group dev --group polars --system
371383
- name: install-narwhals-dev
372384
run: |
@@ -409,6 +421,8 @@ jobs:
409421
- name: install-formulaic-dev
410422
run: |
411423
cd formulaic
424+
# Temporary pin until it's addressed upstream.
425+
hatch run uv pip install "pandas<3"
412426
hatch run uv pip install -e ".[arrow,calculus]"
413427
- name: install-narwhals-dev
414428
run: |
@@ -452,6 +466,8 @@ jobs:
452466
uv venv -p ${{ matrix.python-version }}
453467
. .venv/bin/activate
454468
uv pip install . --group dev
469+
# Temporary pin to get CI green
470+
uv pip install "sqlglot<28.7.0"
455471
uv pip install pytest pytest-cov pytest-snapshot pandas polars "ibis-framework[duckdb,mysql,postgres,sqlite]>=9.5.0" chatlas shiny
456472
- name: install-narwhals-dev
457473
run: |
@@ -499,6 +515,9 @@ jobs:
499515
run: |
500516
cd validoopsie
501517
uv sync --dev --upgrade
518+
# Temporary pin until it's addressed upstream.
519+
uv remove pandas --group dev
520+
uv add "pandas<3"
502521
- name: install-narwhals-dev
503522
run: |
504523
cd validoopsie
@@ -545,6 +564,8 @@ jobs:
545564
- name: install-deps
546565
run: |
547566
cd darts
567+
# Temporary pin until it's addressed upstream.
568+
uv pip install "pandas<3" --system
548569
uv pip install \
549570
-r requirements/core.txt \
550571
-r requirements/dev.txt \

.github/workflows/extremes.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ jobs:
169169
- name: Assert nightlies dependencies
170170
run: |
171171
DEPS=$(uv pip freeze)
172-
echo "$DEPS" | grep -E 'pandas.*(dev|rc)'
172+
echo "$DEPS" | grep -E 'pandas.*(dev|rc|\+)'
173173
echo "$DEPS" | grep 'pyarrow.*dev'
174174
echo "$DEPS" | grep 'numpy.*dev'
175175
echo "$DEPS" | grep 'dask.*@'

.github/workflows/pytest.yml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ jobs:
5353
cache-dependency-glob: "pyproject.toml"
5454
- name: install-reqs
5555
# we are not testing pyspark, modin, or dask on Windows here because nobody got time for that
56-
run: uv pip install -e ".[ibis]" --group core-tests --group extra --system
56+
# TODO(FBruzzesi): Temporarily pin sqlglot to <28.6.0 to avoid breaking changes in SQLFrame
57+
# See https://github.com/eakmanrq/sqlframe/issues/577
58+
run: uv pip install -e ".[ibis]" --group core-tests --group extra "sqlglot<28.6.0" --system
5759
- name: install-test-plugin
5860
run: uv pip install -e test-plugin/. --system
5961
- name: show-deps
@@ -85,7 +87,9 @@ jobs:
8587
cache-suffix: pytest-full-coverage-${{ matrix.python-version }}
8688
cache-dependency-glob: "pyproject.toml"
8789
- name: install-reqs
88-
run: uv pip install -e ".[dask, modin, ibis]" --group core-tests --group extra --system
90+
# TODO(FBruzzesi): Temporarily pin sqlglot to <28.6.0 to avoid breaking changes in SQLFrame
91+
# See https://github.com/eakmanrq/sqlframe/issues/577
92+
run: uv pip install -e ".[dask, modin, ibis]" --group core-tests --group extra "sqlglot<28.6.0" --system
8993
- name: install-test-plugin
9094
run: uv pip install -e test-plugin/. --system
9195
- name: show-deps
@@ -153,8 +157,9 @@ jobs:
153157
cache-suffix: python-314-${{ matrix.python-version }}
154158
cache-dependency-glob: "pyproject.toml"
155159
- name: install-reqs
156-
# Use `--pre` as duckdb stable not compatible with 3.14
157-
run: uv pip install -e . --group tests --pre pandas polars pyarrow duckdb sqlframe --system
160+
# TODO(FBruzzesi): Temporarily pin sqlglot to <28.6.0 to avoid breaking changes in SQLFrame
161+
# See https://github.com/eakmanrq/sqlframe/issues/577
162+
run: uv pip install -e . --group tests pandas polars pyarrow duckdb sqlframe "sqlglot<28.6.0" --system
158163
- name: show-deps
159164
run: uv pip freeze
160165
- name: Run pytest

docs/api-reference/sql.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# `narwhals.sql`
2+
3+
::: narwhals.sql
4+
handler: python
5+
options:
6+
members:
7+
- table
8+
9+
::: narwhals.sql.SQLTable
10+
handler: python
11+
options:
12+
members:
13+
- to_sql
14+
show_source: false
15+
show_bases: false

docs/concepts/order_dependence.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ such as:
1717
- `cum_sum`, `cum_min`, ...
1818
- `rolling_sum`, `rolling_min`, ...
1919
- `is_first_distinct`, `is_last_distinct`.
20+
- `first`, `last`.
2021

2122
When row-order is defined, as is the case for `DataFrame`, these operations pose
2223
no issue.
@@ -50,3 +51,14 @@ When writing an order-dependent function, if you want it to be executable by `La
5051
(and not just `DataFrame`), make sure that all order-dependent expressions are followed
5152
by `over` with `order_by` specified. If you forget to, don't worry, Narwhals will
5253
give you a loud and clear error message.
54+
55+
## Aggregations
56+
57+
To make `nw.col('a').first()` valid in the lazy case, you have the choice between writing:
58+
59+
- `nw.col('a').first().over(order_by='i')`.
60+
- `nw.col('a').first(order_by='i')`.
61+
62+
The first produces a new column of the same length as the original dataframe, whereas
63+
the other one produces a scalar. If you're using `first` in a group-by context, where
64+
you're required to provide aggregations, then we recommend using the latter.

docs/generating_sql.md

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,62 @@ For example, what's the SQL equivalent to:
55

66
```python exec="1" source="above" session="generating-sql"
77
import narwhals as nw
8-
from narwhals.typing import IntoFrameT
8+
from narwhals.typing import FrameT
99

1010

11-
def avg_monthly_price(df_native: IntoFrameT) -> IntoFrameT:
11+
def avg_monthly_price(df: FrameT) -> FrameT:
1212
return (
13-
nw.from_native(df_native)
14-
.group_by(nw.col("date").dt.truncate("1mo"))
13+
df.group_by(nw.col("date").dt.truncate("1mo"))
1514
.agg(nw.col("price").mean())
1615
.sort("date")
17-
.to_native()
1816
)
1917
```
2018

2119
?
2220

23-
There are several ways to find out.
21+
Narwhals provides you with a `narwhals.sql` module to do just that!
2422

25-
## Via DuckDB
23+
!!! info
24+
`narwhals.sql` currently requires DuckDB to be installed.
25+
26+
## `narwhals.sql`
2627

2728
You can generate SQL directly from DuckDB.
2829

2930
```python exec="1" source="above" session="generating-sql" result="sql"
30-
import duckdb
31+
import narwhals as nw
32+
from narwhals.sql import table
3133

32-
conn = duckdb.connect()
33-
conn.sql("""CREATE TABLE prices (date DATE, price DOUBLE);""")
34+
prices = table("prices", {"date": nw.Date, "price": nw.Float64})
3435

35-
df = nw.from_native(conn.table("prices"))
36-
print(avg_monthly_price(df).sql_query())
36+
result = (
37+
prices.group_by(nw.col("date").dt.truncate("1mo"))
38+
.agg(nw.col("price").mean())
39+
.sort("date")
40+
)
41+
print(result.to_sql())
3742
```
3843

39-
To make it look a bit prettier, or to then transpile it to other SQL dialects, we can pass it to [SQLGlot](https://github.com/tobymao/sqlglot):
44+
To make it look a bit prettier, you can pass `pretty=True`, but
45+
note that this currently requires [sqlparse](https://github.com/andialbrecht/sqlparse) to be installed.
4046

4147
```python exec="1" source="above" session="generating-sql" result="sql"
42-
import sqlglot
43-
44-
print(sqlglot.transpile(avg_monthly_price(df).sql_query(), pretty=True)[0])
48+
print(result.to_sql(pretty=True))
4549
```
4650

51+
Note that the generated SQL follows DuckDB's dialect. To translate it to other dialects,
52+
you may want to look into [sqlglot](https://github.com/tobymao/sqlglot), or use one of the
53+
solutions below (which also use sqlglot).
54+
4755
## Via Ibis
4856

49-
We can also use Ibis to generate SQL:
57+
You can also use Ibis or SQLFrame to generate SQL:
5058

5159
```python exec="1" source="above" session="generating-sql" result="sql"
5260
import ibis
5361

54-
t = ibis.table({"date": "date", "price": "double"}, name="prices")
55-
print(ibis.to_sql(avg_monthly_price(t)))
62+
df = nw.from_native(ibis.table({"date": "date", "price": "double"}, name="prices"))
63+
print(ibis.to_sql(avg_monthly_price(df).to_native()))
5664
```
5765

5866
## Via SQLFrame
@@ -66,11 +74,5 @@ session = StandaloneSession.builder.getOrCreate()
6674
session.catalog.add_table("prices", column_mapping={"date": "date", "price": "float"})
6775
df = nw.from_native(session.read.table("prices"))
6876

69-
print(avg_monthly_price(df).sql(dialect="duckdb"))
70-
```
71-
72-
Or, to print the SQL code in a different dialect (say, databricks):
73-
74-
```python exec="1" source="above" session="generating-sql" result="sql"
75-
print(avg_monthly_price(df).sql(dialect="databricks"))
77+
print(avg_monthly_price(df).to_native().sql(dialect="duckdb"))
7678
```

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ nav:
7070
- api-reference/dtypes.md
7171
- api-reference/exceptions.md
7272
- api-reference/selectors.md
73+
- api-reference/sql.md
7374
- api-reference/testing.md
7475
- api-reference/typing.md
7576
- api-reference/utils.md

narwhals/_arrow/group_by.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import pyarrow as pa
77
import pyarrow.compute as pc
88

9-
from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
9+
from narwhals._arrow.utils import (
10+
BACKEND_VERSION,
11+
cast_to_comparable_string_types,
12+
extract_py_scalar,
13+
)
1014
from narwhals._compliant import EagerGroupBy
1115
from narwhals._expression_parsing import evaluate_output_names_and_aliases
1216
from narwhals._utils import generate_temporary_column_name, requires
@@ -71,12 +75,11 @@ def __init__(
7175
self._df = df
7276
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
7377
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
74-
self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
7578
self._drop_null_keys = drop_null_keys
7679

7780
def _configure_agg(
78-
self, grouped: pa.TableGroupBy, expr: ArrowExpr, /
79-
) -> tuple[pa.TableGroupBy, Aggregation, AggregateOptions | None]:
81+
self, expr: ArrowExpr, /
82+
) -> tuple[Aggregation, AggregateOptions | None]:
8083
option: AggregateOptions | None = None
8184
function_name = self._leaf_name(expr)
8285
kwargs = self._kwargs(expr)
@@ -91,50 +94,49 @@ def _configure_agg(
9194
option = pc.ScalarAggregateOptions(min_count=0)
9295
elif function_name in self._OPTION_ORDERED:
9396
ignore_nulls = kwargs.get("ignore_nulls", False)
94-
grouped, option = self._ordered_agg(
95-
grouped, function_name, ignore_nulls=ignore_nulls
96-
)
97-
return grouped, self._remap_expr_name(function_name), option
98-
99-
def _ordered_agg(
100-
self,
101-
grouped: pa.TableGroupBy,
102-
name: NarwhalsAggregation,
103-
/,
104-
*,
105-
ignore_nulls: bool,
106-
) -> tuple[pa.TableGroupBy, AggregateOptions]:
107-
"""The default behavior of `pyarrow` raises when `first` or `last` are used.
108-
109-
You'd see an error like:
97+
option = pc.ScalarAggregateOptions(skip_nulls=ignore_nulls)
98+
return self._remap_expr_name(function_name), option
11099

111-
ArrowNotImplementedError: Using ordered aggregator in multiple threaded execution is not supported
112-
113-
We need to **disable** multi-threading to use them, but the ability to do so
114-
wasn't possible before `14.0.0` ([pyarrow-36709])
115-
116-
[pyarrow-36709]: https://github.com/apache/arrow/issues/36709
117-
"""
118-
backend_version = self.compliant._backend_version
119-
if backend_version >= (14, 0) and grouped._use_threads:
120-
native = self.compliant.native
121-
grouped = pa.TableGroupBy(native, grouped.keys, use_threads=False)
122-
elif backend_version < (14, 0): # pragma: no cover
100+
def _configure_grouped(self, *exprs: ArrowExpr) -> pa.TableGroupBy:
101+
order_by = ()
102+
use_threads = True
103+
for expr in exprs:
104+
md = next(expr._metadata.op_nodes_reversed())
105+
if md.name not in self._OPTION_ORDERED:
106+
continue
107+
# [pyarrow-36709]: https://github.com/apache/arrow/issues/36709
108+
use_threads = False
109+
if _current_order_by := md.kwargs.get("order_by", ()):
110+
if order_by and _current_order_by != order_by:
111+
msg = f"Only one `order_by` can be specified in `group_by`. Found both {order_by} and {_current_order_by}."
112+
raise NotImplementedError(msg)
113+
order_by = _current_order_by
114+
if not use_threads and BACKEND_VERSION < (14,): # pragma: no cover
123115
msg = (
124-
f"Using `{name}()` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', "
125-
f"found version {requires._unparse_version(backend_version)!r}.\n\n"
116+
f"Using `first/last` in a `group_by().agg(...)` context is only available in 'pyarrow>=14.0.0', "
117+
f"found version {requires._unparse_version(BACKEND_VERSION)!r}.\n\n"
126118
f"See https://github.com/apache/arrow/issues/36709"
127119
)
128120
raise NotImplementedError(msg)
129-
return grouped, pc.ScalarAggregateOptions(skip_nulls=ignore_nulls)
121+
if order_by:
122+
return pa.TableGroupBy(
123+
self.compliant.sort(*order_by, descending=False, nulls_last=False).native,
124+
self._keys,
125+
use_threads=use_threads,
126+
)
127+
if not use_threads:
128+
return pa.TableGroupBy(self.compliant.native, self._keys, use_threads=False)
129+
# TODO(unassigned): combine with `return` above once PyArrow 15 is the minimum.
130+
return pa.TableGroupBy(self.compliant.native, self._keys)
130131

131132
def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
132133
self._ensure_all_simple(exprs)
134+
grouped = self._configure_grouped(*exprs)
135+
133136
aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
134137
expected_pyarrow_column_names: list[str] = self._keys.copy()
135138
new_column_names: list[str] = self._keys.copy()
136139
exclude = (*self._keys, *self._output_key_names)
137-
grouped = self._grouped
138140

139141
for expr in exprs:
140142
output_names, aliases = evaluate_output_names_and_aliases(
@@ -153,7 +155,7 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
153155
aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
154156
continue
155157

156-
grouped, function_name, option = self._configure_agg(grouped, expr)
158+
function_name, option = self._configure_agg(expr)
157159
new_column_names.extend(aliases)
158160
expected_pyarrow_column_names.extend(
159161
[f"{output_name}_{function_name}" for output_name in output_names]

0 commit comments

Comments
 (0)