
Commit 1581500

Merge branch 'fea/polars/rank' into branch-25.10
2 parents 5f83c84 + 35c9b5f

File tree

5 files changed: +139 -5 lines changed

python/cudf_polars/cudf_polars/dsl/expressions/unary.py

Lines changed: 77 additions & 4 deletions
@@ -106,12 +106,13 @@ class UnaryFunction(Expr):
             "drop_nulls",
             "fill_null",
             "mask_nans",
+            "null_count",
+            "rank",
             "round",
             "set_sorted",
+            "top_k",
             "unique",
             "value_counts",
-            "null_count",
-            "top_k",
         }
     )
     _supported_cum_aggs = frozenset(
@@ -135,13 +136,14 @@ def __init__(
         self.children = children
         self.is_pointwise = self.name not in (
             "as_struct",
-            "cum_min",
             "cum_max",
+            "cum_min",
             "cum_prod",
             "cum_sum",
             "drop_nulls",
-            "unique",
+            "rank",
             "top_k",
+            "unique",
         )
 
         if self.name not in UnaryFunction._supported_fns:
@@ -152,6 +154,12 @@ def __init__(
                 raise NotImplementedError(
                     "reverse=True is not supported for cumulative aggregations"
                 )
+        if self.name == "rank":
+            method, _, _ = self.options
+            if method not in {"average", "min", "max", "dense", "ordinal"}:
+                raise NotImplementedError(
+                    f"ranking with {method=} is not yet supported"
+                )
 
     def do_evaluate(
         self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
@@ -342,6 +350,71 @@ def do_evaluate(
                 ),
                 dtype=self.dtype,
             )
+        elif self.name == "rank":
+            (column,) = (child.evaluate(df, context=context) for child in self.children)
+            method_str, descending, _ = self.options
+
+            method = {
+                "average": plc.aggregation.RankMethod.AVERAGE,
+                "min": plc.aggregation.RankMethod.MIN,
+                "max": plc.aggregation.RankMethod.MAX,
+                "dense": plc.aggregation.RankMethod.DENSE,
+                "ordinal": plc.aggregation.RankMethod.FIRST,
+            }[method_str]
+
+            order = (
+                plc.types.Order.DESCENDING if descending else plc.types.Order.ASCENDING
+            )
+
+            ranked: plc.Column = plc.sorting.rank(
+                column.obj,
+                method,
+                order,
+                plc.types.NullPolicy.EXCLUDE,
+                plc.types.NullOrder.BEFORE,
+                percentage=False,
+            )
+
+            null_count = column.null_count
+            if null_count and not descending:
+                # libcudf rank is offset when nulls would sort first and are excluded:
+                # - dense: +1 (nulls count as a skipped leading group)
+                # - min/max/ordinal/average: +k (nulls counted before all valid rows)
+                rank_dtype = ranked.type()
+                if method_str == "dense":
+                    one = plc.Scalar.from_py(
+                        1.0
+                        if rank_dtype.id() in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
+                        else 1,
+                        rank_dtype,
+                    )
+                    ranked = plc.binaryop.binary_operation(
+                        ranked, one, plc.binaryop.BinaryOperator.SUB, rank_dtype
+                    )
+                else:
+                    k_scalar = plc.Scalar.from_py(
+                        float(null_count)
+                        if rank_dtype.id() in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
+                        else int(null_count),
+                        rank_dtype,
+                    )
+                    ranked = plc.binaryop.binary_operation(
+                        ranked, k_scalar, plc.binaryop.BinaryOperator.SUB, rank_dtype
+                    )
+
+            # Min/Max/Dense/Ordinal -> IDX_DTYPE
+            # See https://github.com/pola-rs/polars/blob/main/crates/polars-ops/src/series/ops/rank.rs
+            if method_str in {"min", "max", "dense", "ordinal"}:
+                dest = self.dtype.plc.id()
+                src = ranked.type().id()
+                if dest == plc.TypeId.UINT32 and src != plc.TypeId.UINT32:
+                    ranked = plc.unary.cast(ranked, plc.DataType(plc.TypeId.UINT32))
+                elif (
+                    dest == plc.TypeId.UINT64 and src != plc.TypeId.UINT64
+                ):  # pragma: no cover
+                    ranked = plc.unary.cast(ranked, plc.DataType(plc.TypeId.UINT64))
+
+            return Column(ranked, dtype=self.dtype)
         elif self.name == "top_k":
             (column, k) = (
                 child.evaluate(df, context=context) for child in self.children
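
For context, a minimal usage sketch of the behaviour this hunk implements (the column name and data below are illustrative, and running on the GPU assumes a working cudf-polars install):

import polars as pl

# Polars rank semantics that the offset/cast logic above reproduces:
# nulls keep a null rank, valid rows are ranked from 1, and
# min/max/dense/ordinal return UInt32 while average returns Float64.
q = pl.LazyFrame({"a": [None, 10, 20, 20, 5]}).select(
    pl.col("a").rank(method="min").alias("min_rank"),      # [null, 2, 3, 3, 1]
    pl.col("a").rank(method="average").alias("avg_rank"),   # [null, 2.0, 3.5, 3.5, 1.0]
)
print(q.collect(engine="gpu"))  # should match the CPU result of q.collect()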

python/cudf_polars/cudf_polars/dsl/utils/aggregations.py

Lines changed: 5 additions & 0 deletions
@@ -89,6 +89,11 @@ def decompose_single_agg(
     """
     agg = named_expr.value
     name = named_expr.name
+    if isinstance(agg, expr.UnaryFunction) and agg.name in {"rank"}:
+        name = agg.name
+        raise NotImplementedError(
+            f"UnaryFunction {name=} not supported in groupby context"
+        )
     if isinstance(agg, expr.UnaryFunction) and agg.name == "null_count":
         (child,) = agg.children
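
To illustrate the guard above, a hedged sketch (the query is illustrative; the fallback behaviour described is the usual path when IR translation raises):

import polars as pl

q = (
    pl.LazyFrame({"key": [1, 1, 2], "x": [3, 1, 2]})
    .group_by("key")
    .agg(pl.col("x").rank())
)
# Translating rank inside a group_by raises NotImplementedError, so with the
# default pl.GPUEngine(raise_on_fail=False) this query is expected to fall
# back to the CPU engine instead of running the groupby rank on the GPU.
print(q.collect(engine="gpu"))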

python/cudf_polars/tests/expressions/test_numeric_unaryops.py

Lines changed: 44 additions & 1 deletion
@@ -8,7 +8,11 @@
 
 import polars as pl
 
-from cudf_polars.testing.asserts import assert_gpu_result_equal
+from cudf_polars.testing.asserts import (
+    assert_gpu_result_equal,
+    assert_ir_translation_raises,
+)
+from cudf_polars.utils.versions import POLARS_VERSION_LT_132
 
 
 @pytest.fixture(
@@ -112,3 +116,42 @@ def test_null_count():
         pl.col("baz").is_null().sum(),
     )
     assert_gpu_result_equal(q)
+
+
+@pytest.mark.parametrize("method", ["ordinal", "dense", "min", "max", "average"])
+@pytest.mark.parametrize("descending", [False, True])
+def test_rank_supported(request, ldf: pl.LazyFrame, method: str, *, descending: bool):
+    request.applymarker(
+        pytest.mark.xfail(condition=POLARS_VERSION_LT_132, reason="nested loop join")
+    )
+    expr = pl.col("a").rank(method=method, descending=descending)
+    q = ldf.select(expr)
+    assert_gpu_result_equal(q)
+
+
+@pytest.mark.parametrize("method", ["ordinal", "dense", "min", "max", "average"])
+@pytest.mark.parametrize("descending", [False, True])
+@pytest.mark.parametrize("test", ["with_nulls", "with_ties"])
+def test_rank_methods_with_nulls_or_ties(
+    request, ldf: pl.LazyFrame, method: str, *, descending: bool, test: str
+) -> None:
+    request.applymarker(
+        pytest.mark.xfail(condition=POLARS_VERSION_LT_132, reason="nested loop join")
+    )
+
+    base = pl.col("a")
+    if test == "with_nulls":
+        expr = pl.when((base % 2) == 0).then(None).otherwise(base)
+    else:
+        expr = pl.when((base % 2) == 0).then(pl.lit(-5)).otherwise(base)
+
+    q = ldf.select(expr.rank(method=method, descending=descending))
+    assert_gpu_result_equal(q)
+
+
+@pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("method", ["random"])
+def test_rank_unsupported(ldf: pl.LazyFrame, method: str, seed: int) -> None:
+    expr = pl.col("a").rank(method=method, seed=seed)
+    q = ldf.select(expr)
+    assert_ir_translation_raises(q, NotImplementedError)

python/cudf_polars/tests/test_groupby.py

Lines changed: 6 additions & 0 deletions
@@ -384,3 +384,9 @@ def test_groupby_aggs_keep_unsupported_as_null(df: pl.LazyFrame, agg_expr) -> None:
 def test_groupby_ternary_supported(df: pl.LazyFrame, expr: pl.Expr) -> None:
     q = df.group_by("key1").agg(expr)
     assert_gpu_result_equal(q, check_row_order=False)
+
+
+def test_groupby_rank_raises(df: pl.LazyFrame) -> None:
+    q = df.group_by("key1").agg(pl.col("int").rank())
+
+    assert_ir_translation_raises(q, NotImplementedError)

python/cudf_polars/tests/test_rolling.py

Lines changed: 7 additions & 0 deletions
@@ -272,3 +272,10 @@ def test_rolling_ternary_supported(df, expr):
 def test_rolling_ternary_unsupported(df, expr):
     q = df.rolling("dt", period="48h", closed="both").agg(expr.alias("out"))
     assert_ir_translation_raises(q, NotImplementedError)
+
+
+def test_rolling_rank_unsupported(df):
+    q = df.rolling("dt", period="48h", closed="both").agg(
+        pl.col("values").rank(method="dense", descending=False)
+    )
+    assert_ir_translation_raises(q, NotImplementedError)
