Skip to content

Commit 538cf1f

Browse files
authored
Merge pull request #74 from raisadz/feat/rank
feat: add `rank`
2 parents bdfd9b9 + e79696d commit 538cf1f

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

narwhals_daft/expr.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries
2424
from narwhals._utils import Version, _LimitedContext
2525
from narwhals.dtypes import DType
26+
from narwhals.typing import RankMethod
2627
from typing_extensions import TypeIs
2728

2829
from narwhals_daft.dataframe import DaftLazyFrame
@@ -783,6 +784,67 @@ def func(df: DaftLazyFrame, inputs: WindowInputs) -> Sequence[Expression]:
783784

784785
return self._with_window_function(func)
785786

787+
def rank(self, method: RankMethod, *, descending: bool) -> DaftExpr:
788+
if method in {"min", "max", "average"}:
789+
func = F.rank()
790+
elif method == "dense":
791+
func = F.dense_rank()
792+
else: # method == "ordinal"
793+
func = F.row_number()
794+
795+
def _rank(
796+
expr: Expression,
797+
partition_by: Sequence[str | Expression] = (),
798+
*,
799+
descending: bool,
800+
) -> Expression:
801+
count_expr = expr.count(mode="all")
802+
window_kwargs: dict[str, Any] = {
803+
"partition_by": partition_by,
804+
"order_by": (expr,),
805+
"descending": [descending],
806+
"nulls_first": [False],
807+
}
808+
count_window_kwargs: dict[str, Any] = {
809+
"partition_by": (*partition_by, expr)
810+
}
811+
window = self._window_expression
812+
if method == "max":
813+
rank_expr = op.sub(
814+
op.add(
815+
window(func, **window_kwargs),
816+
window(count_expr, **count_window_kwargs),
817+
),
818+
lit(1),
819+
)
820+
elif method == "average":
821+
rank_expr = op.add(
822+
window(func, **window_kwargs),
823+
op.truediv(
824+
op.sub(window(count_expr, **count_window_kwargs), lit(1)),
825+
lit(2.0),
826+
),
827+
)
828+
else:
829+
rank_expr = window(func, **window_kwargs)
830+
return F.when(F.is_null(expr), expr).otherwise(rank_expr)
831+
832+
def _unpartitioned_rank(expr: Expression) -> Expression:
833+
return _rank(expr, descending=descending)
834+
835+
def _partitioned_rank(
836+
df: DaftLazyFrame, inputs: WindowInputs
837+
) -> Sequence[Expression]:
838+
if inputs.order_by:
839+
msg = "`rank` followed by `over` with `order_by` specified is not supported for Daft."
840+
raise NotImplementedError(msg)
841+
return [
842+
_rank(expr, inputs.partition_by, descending=descending)
843+
for expr in self(df)
844+
]
845+
846+
return self._with_callable(_unpartitioned_rank, _partitioned_rank)
847+
786848
@property
787849
def name(self) -> ExprNameNamespace:
788850
return ExprNameNamespace(self)
@@ -796,7 +858,6 @@ def str(self) -> ExprStringNamespace:
796858
filter = not_implemented()
797859
ewm_mean = not_implemented()
798860
kurtosis = not_implemented()
799-
rank = not_implemented()
800861
map_batches = not_implemented()
801862
median = not_implemented()
802863
mode = not_implemented()

run_tests.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@
5757
"test_joinasof_suffix",
5858
"test_joinasof_time",
5959
"test_kurtosis_expr",
60-
"test_lazy_rank_expr",
61-
"test_lazy_rank_expr_desc",
6260
"test_left_to_right_broadcasting",
6361
"test_len_expr",
6462
"test_mapping_key_not_in_expr",
@@ -89,8 +87,6 @@
8987
"test_parse_weight",
9088
"test_pipe_expr",
9189
"test_quantile_expr",
92-
"test_rank_expr_in_over_context",
93-
"test_rank_expr_in_over_desc",
9490
"test_rank_with_order_by",
9591
"test_rank_with_order_by_and_partition_by",
9692
"test_replace_strict_expr_basic",

0 commit comments

Comments
 (0)