2323 from narwhals ._compliant .typing import AliasNames , EvalNames , EvalSeries
2424 from narwhals ._utils import Version , _LimitedContext
2525 from narwhals .dtypes import DType
26+ from narwhals .typing import RankMethod
2627 from typing_extensions import TypeIs
2728
2829 from narwhals_daft .dataframe import DaftLazyFrame
@@ -783,6 +784,67 @@ def func(df: DaftLazyFrame, inputs: WindowInputs) -> Sequence[Expression]:
783784
784785 return self ._with_window_function (func )
785786
def rank(self, method: RankMethod, *, descending: bool) -> DaftExpr:
    """Build a ranking expression using Daft window functions.

    The base ranking function is chosen from `method`:

    - "min", "max", "average": `F.rank()`
    - "dense": `F.dense_rank()`
    - anything else (expected to be "ordinal"): `F.row_number()`

    For "max" and "average" the base rank is adjusted using the number of
    rows in the window partitioned by `(*partition_by, expr)`:

    - "max": `base + count - 1`
    - "average": `base + (count - 1) / 2.0`

    Rows where the ranked expression is null yield the expression itself
    instead of a rank.

    Arguments:
        method: Which ranking strategy to apply.
        descending: Passed to the window's `descending` ordering flag.

    Returns:
        The result of `self._with_callable`, given an unpartitioned and a
        partitioned implementation.

    Raises:
        NotImplementedError: Raised by the partitioned implementation when
            the window inputs carry an `order_by`.
    """
    # Choose the base ranking function once, when `rank` is called.
    if method == "dense":
        func = F.dense_rank()
    elif method in {"min", "max", "average"}:
        func = F.rank()
    else:  # method == "ordinal"
        func = F.row_number()

    def _rank(
        expr: Expression,
        partition_by: Sequence[str | Expression] = (),
        *,
        descending: bool,
    ) -> Expression:
        tie_count = expr.count(mode="all")
        # Window over which the base rank is computed: ordered by the
        # ranked expression itself.
        ordered_spec: dict[str, Any] = {
            "partition_by": partition_by,
            "order_by": (expr,),
            "descending": [descending],
            "nulls_first": [False],
        }
        # Window used for the "max"/"average" adjustment: also
        # partitioned by the ranked expression.
        ties_spec: dict[str, Any] = {"partition_by": (*partition_by, expr)}
        over = self._window_expression

        # Every method starts from the same base window expression.
        base_rank = over(func, **ordered_spec)
        if method == "max":
            ties = over(tie_count, **ties_spec)
            rank_expr = op.sub(op.add(base_rank, ties), lit(1))
        elif method == "average":
            ties = over(tie_count, **ties_spec)
            half_spread = op.truediv(op.sub(ties, lit(1)), lit(2.0))
            rank_expr = op.add(base_rank, half_spread)
        else:
            rank_expr = base_rank

        # Null inputs pass through unchanged rather than receiving a rank.
        is_missing = F.is_null(expr)
        return F.when(is_missing, expr).otherwise(rank_expr)

    def _unpartitioned_rank(expr: Expression) -> Expression:
        return _rank(expr, descending=descending)

    def _partitioned_rank(
        df: DaftLazyFrame, inputs: WindowInputs
    ) -> Sequence[Expression]:
        if inputs.order_by:
            msg = "`rank` followed by `over` with `order_by` specified is not supported for Daft."
            raise NotImplementedError(msg)
        native_exprs = self(df)
        return [
            _rank(native, inputs.partition_by, descending=descending)
            for native in native_exprs
        ]

    return self._with_callable(_unpartitioned_rank, _partitioned_rank)
847+
@property
def name(self) -> ExprNameNamespace:
    """Return an `ExprNameNamespace` constructed from this expression."""
    namespace = ExprNameNamespace(self)
    return namespace
@@ -796,7 +858,6 @@ def str(self) -> ExprStringNamespace:
796858 filter = not_implemented ()
797859 ewm_mean = not_implemented ()
798860 kurtosis = not_implemented ()
799- rank = not_implemented ()
800861 map_batches = not_implemented ()
801862 median = not_implemented ()
802863 mode = not_implemented ()
0 commit comments