Skip to content

Commit 058782d

Browse files
committed
Merge remote-tracking branch 'upstream/main' into oh-nodes
2 parents a1a1645 + cd8085b commit 058782d

File tree

19 files changed

+358
-109
lines changed

19 files changed

+358
-109
lines changed

.github/workflows/downstream_tests.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ jobs:
367367
- name: install-deps
368368
run: |
369369
cd hierarchicalforecast
370-
uv pip install --system ".[dev,polars]"
370+
uv pip install . --group dev --group polars --system
371371
- name: install-narwhals-dev
372372
run: |
373373
uv pip uninstall narwhals --system
@@ -627,7 +627,8 @@ jobs:
627627
- name: install-deps
628628
run: |
629629
cd fairlearn
630-
uv pip install -e . -r requirements.txt matplotlib polars pyarrow pytest typing-extensions --system
630+
# TODO(FBruzzesi): Align with fairlearn team to get a minimal requirement to test narwhals features
631+
uv pip install -e . -r requirements.txt lightgbm matplotlib polars pyarrow pytest typing-extensions --system
631632
- name: install-narwhals-dev
632633
run: |
633634
cd fairlearn
@@ -637,6 +638,4 @@ jobs:
637638
- name: run-pytest
638639
run: |
639640
cd fairlearn
640-
# TODO(FBruzzesi): I hope this will be simplified once there is a decision on
641-
# https://github.com/fairlearn/fairlearn/issues/1555
642-
pytest test/unit/preprocessing test/unit/metrics
641+
pytest test/unit -m "narwhals"

docs/api-reference/expr_list.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- mean
1212
- median
1313
- min
14+
- sort
1415
- sum
1516
- unique
1617
show_source: false

docs/api-reference/series_list.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- mean
1212
- median
1313
- min
14+
- sort
1415
- sum
1516
- unique
1617
show_source: false

docs/javascripts/table.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
document$.subscribe(function() {
2+
// Only run on api-completeness pages
3+
var currentPath = window.location.pathname
4+
if (!currentPath.includes('/api-completeness/')) {
5+
return
6+
}
7+
28
var tables = document.querySelectorAll("article table:not([class])")
39
tables.forEach(function(table) {
410
new Tablesort(table)

narwhals/_arrow/series_list.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pyarrow as pa
66
import pyarrow.compute as pc
77

8-
from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg
8+
from narwhals._arrow.utils import ArrowSeriesNamespace, list_agg, list_sort
99
from narwhals._compliant.any_namespace import ListNamespace
1010
from narwhals._utils import not_implemented
1111

@@ -35,5 +35,10 @@ def median(self) -> ArrowSeries:
3535
def sum(self) -> ArrowSeries:
3636
return self.with_native(list_agg(self.native, "sum"))
3737

38+
def sort(self, *, descending: bool, nulls_last: bool) -> ArrowSeries:
39+
return self.with_native(
40+
list_sort(self.native, descending=descending, nulls_last=nulls_last)
41+
)
42+
3843
unique = not_implemented()
3944
contains = not_implemented()

narwhals/_arrow/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,3 +532,35 @@ def list_agg(
532532
)
533533
]
534534
)
535+
536+
537+
def list_sort(
538+
array: ChunkedArrayAny, *, descending: bool, nulls_last: bool
539+
) -> ChunkedArrayAny:
540+
sort_direction: Literal["ascending", "descending"] = (
541+
"descending" if descending else "ascending"
542+
)
543+
nulls_position: Literal["at_start", "at_end"] = "at_end" if nulls_last else "at_start"
544+
idx, v = "idx", "values"
545+
is_not_sorted = pc.greater(pc.list_value_length(array), lit(0))
546+
indexed = pa.Table.from_arrays(
547+
[arange(start=0, end=len(array), step=1), array], names=[idx, v]
548+
)
549+
not_sorted_part = indexed.filter(is_not_sorted)
550+
pass_through = indexed.filter(pc.fill_null(pc.invert(is_not_sorted), lit(True))) # pyright: ignore[reportArgumentType]
551+
exploded = pa.Table.from_arrays(
552+
[pc.list_flatten(array), pc.list_parent_indices(array)], names=[v, idx]
553+
)
554+
sorted_indices = pc.sort_indices(
555+
exploded,
556+
sort_keys=[(idx, "ascending"), (v, sort_direction)],
557+
null_placement=nulls_position,
558+
)
559+
offsets = not_sorted_part.column(v).combine_chunks().offsets # type: ignore[attr-defined]
560+
sorted_imploded = pa.ListArray.from_arrays(
561+
offsets, pa.array(exploded.take(sorted_indices).column(v))
562+
)
563+
imploded_by_idx = pa.Table.from_arrays(
564+
[not_sorted_part.column(idx), sorted_imploded], names=[idx, v]
565+
)
566+
return pa.concat_tables([imploded_by_idx, pass_through]).sort_by(idx).column(v)

narwhals/_compliant/any_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def max(self) -> CompliantT_co: ...
7575
def mean(self) -> CompliantT_co: ...
7676
def median(self) -> CompliantT_co: ...
7777
def sum(self) -> CompliantT_co: ...
78+
def sort(self, *, descending: bool, nulls_last: bool) -> CompliantT_co: ...
7879

7980

8081
class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):

narwhals/_compliant/expr.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,11 @@ def median(self) -> EagerExprT:
10141014
def sum(self) -> EagerExprT:
10151015
return self.compliant._reuse_series_namespace("list", "sum")
10161016

1017+
def sort(self, *, descending: bool, nulls_last: bool) -> EagerExprT:
1018+
return self.compliant._reuse_series_namespace(
1019+
"list", "sort", descending=descending, nulls_last=nulls_last
1020+
)
1021+
10171022

10181023
class CompliantExprNameNamespace( # type: ignore[misc]
10191024
_ExprNamespace[CompliantExprT_co],

narwhals/_duckdb/expr_list.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,10 @@ def func(expr: Expression) -> Expression:
6464
)
6565

6666
return self.compliant._with_callable(func)
67+
68+
def sort(self, *, descending: bool, nulls_last: bool) -> DuckDBExpr:
69+
sort_direction = "DESC" if descending else "ASC"
70+
nulls_position = "NULLS LAST" if nulls_last else "NULLS FIRST"
71+
return self.compliant._with_elementwise(
72+
lambda expr: F("list_sort", expr, lit(sort_direction), lit(nulls_position))
73+
)

narwhals/_ibis/expr_list.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,18 @@ def func(expr: ir.ArrayColumn) -> ir.Value:
5252

5353
return self.compliant._with_callable(func)
5454

55+
def sort(self, *, descending: bool, nulls_last: bool) -> IbisExpr:
56+
if descending:
57+
msg = "Descending sort is not currently supported for Ibis."
58+
raise NotImplementedError(msg)
59+
60+
def func(expr: ir.ArrayColumn) -> ir.ArrayValue:
61+
if nulls_last:
62+
return expr.sort()
63+
expr_no_nulls = expr.filter(lambda x: x.notnull())
64+
expr_nulls = expr.filter(lambda x: x.isnull())
65+
return expr_nulls.concat(expr_no_nulls.sort())
66+
67+
return self.compliant._with_callable(func)
68+
5569
median = not_implemented()

0 commit comments

Comments
 (0)