@@ -129,17 +129,24 @@ def _Window(self) -> type[Window]: # noqa: N802
129129 return import_window (self ._implementation )
130130
131131 def _sort (
132- self , * cols : Column | str , descending : bool = False , nulls_last : bool = False
132+ self ,
133+ * cols : Column | str ,
134+ descending : Sequence [bool ] | None = None ,
135+ nulls_last : Sequence [bool ] | None = None ,
133136 ) -> Iterator [Column ]:
134137 F = self ._F # noqa: N806
138+ descending = descending or [False ] * len (cols )
139+ nulls_last = nulls_last or [False ] * len (cols )
135140 mapping = {
136141 (False , False ): F .asc_nulls_first ,
137142 (False , True ): F .asc_nulls_last ,
138143 (True , False ): F .desc_nulls_first ,
139144 (True , True ): F .desc_nulls_last ,
140145 }
141- sort = mapping [(descending , nulls_last )]
142- yield from (sort (col ) for col in cols )
146+ yield from (
147+ mapping [(_desc , _nulls_last )](col )
148+ for col , _desc , _nulls_last in zip (cols , descending , nulls_last )
149+ )
143150
144151 def partition_by (self , * cols : Column | str ) -> WindowSpec :
145152 """Wraps `Window().partitionBy`, with default and `WindowInputs` handling."""
@@ -178,7 +185,9 @@ def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
178185 window = (
179186 self .partition_by (* inputs .partition_by )
180187 .orderBy (
181- * self ._sort (* inputs .order_by , descending = reverse , nulls_last = reverse )
188+ * self ._sort (
189+ * inputs .order_by , descending = [reverse ], nulls_last = [reverse ]
190+ )
182191 )
183192 .rowsBetween (self ._Window .unboundedPreceding , 0 )
184193 )
@@ -695,7 +704,9 @@ def func(df: SparkLikeLazyFrame, inputs: SparkWindowInputs) -> Sequence[Column]:
695704 return [
696705 self ._F .row_number ().over (
697706 self .partition_by (* inputs .partition_by , expr ).orderBy (
698- * self ._sort (* inputs .order_by , descending = True , nulls_last = True )
707+ * self ._sort (
708+ * inputs .order_by , descending = [True ], nulls_last = [True ]
709+ )
699710 )
700711 )
701712 == 1
@@ -823,17 +834,17 @@ def rank(self, method: RankMethod, *, descending: bool) -> Self:
823834
824835 def _rank (
825836 expr : Column ,
837+ partition_by : Sequence [str | Column ] = (),
838+ order_by : Sequence [str | Column ] = (),
826839 * ,
827- descending : bool ,
828- partition_by : Sequence [str | Column ] | None = None ,
840+ descending : Sequence [ bool ] ,
841+ nulls_last : Sequence [bool ] ,
829842 ) -> Column :
830- order_by = self ._sort (expr , descending = descending , nulls_last = True )
831- if partition_by is not None :
832- window = self .partition_by (* partition_by ).orderBy (* order_by )
833- count_window = self .partition_by (* partition_by , expr )
834- else :
835- window = self .partition_by ().orderBy (* order_by )
836- count_window = self .partition_by (expr )
843+ _order_by = self ._sort (
844+ expr , * order_by , descending = descending , nulls_last = nulls_last
845+ )
846+ window = self .partition_by (* partition_by ).orderBy (* _order_by )
847+ count_window = self .partition_by (* partition_by , expr )
837848 if method == "max" :
838849 rank_expr = (
839850 getattr (self ._F , func_name )().over (window )
@@ -852,14 +863,21 @@ def _rank(
852863 return self ._F .when (expr .isNotNull (), rank_expr )
853864
854865 def _unpartitioned_rank (expr : Column ) -> Column :
855- return _rank (expr , descending = descending )
866+ return _rank (expr , descending = [ descending ], nulls_last = [ True ] )
856867
857868 def _partitioned_rank (
858869 df : SparkLikeLazyFrame , inputs : SparkWindowInputs
859870 ) -> Sequence [Column ]:
860- assert not inputs .order_by # noqa: S101
871+ # note: when `descending` / `nulls_last` are supported in `.over`, they should be respected here
872+ # https://github.com/narwhals-dev/narwhals/issues/2790
861873 return [
862- _rank (expr , descending = descending , partition_by = inputs .partition_by )
874+ _rank (
875+ expr ,
876+ inputs .partition_by ,
877+ inputs .order_by ,
878+ descending = [descending ] + [False ] * len (inputs .order_by ),
879+ nulls_last = [True ] + [False ] * len (inputs .order_by ),
880+ )
863881 for expr in self (df )
864882 ]
865883
0 commit comments