1212from typing import TypeVar
1313from typing import overload
1414
15+ from narwhals ._compliant .typing import CompliantExprT_contra
1516from narwhals ._compliant .typing import CompliantSeriesT
17+ from narwhals ._compliant .typing import EagerExprT_contra
1618from narwhals ._compliant .typing import EagerSeriesT
1719from narwhals ._expression_parsing import evaluate_output_names_and_aliases
1820
3739T = TypeVar ("T" )
3840
3941
40- class CompliantDataFrame (Sized , Protocol [CompliantSeriesT ]):
42+ class CompliantDataFrame (Sized , Protocol [CompliantSeriesT , CompliantExprT_contra ]):
4143 def __narwhals_dataframe__ (self ) -> Self : ...
4244 def __narwhals_namespace__ (self ) -> Any : ...
4345 def __array__ (self , dtype : Any , * , copy : bool | None ) -> _2DArray : ...
@@ -46,7 +48,7 @@ def simple_select(self, *column_names: str) -> Self:
4648 """`select` where all args are column names."""
4749 ...
4850
49- def aggregate (self , * exprs : Any ) -> Self : # pragma: no cover
51+ def aggregate (self , * exprs : CompliantExprT_contra ) -> Self : # pragma: no cover
5052 """`select` where all args are aggregations or literals.
5153
5254 (so, no broadcasting is necessary).
@@ -62,12 +64,12 @@ def shape(self) -> tuple[int, int]: ...
6264 def clone (self ) -> Self : ...
6365 def collect (
6466 self , backend : Implementation | None , ** kwargs : Any
65- ) -> CompliantDataFrame [Any ]: ...
67+ ) -> CompliantDataFrame [Any , Any ]: ...
6668 def collect_schema (self ) -> Mapping [str , DType ]: ...
6769 def drop (self , columns : Sequence [str ], * , strict : bool ) -> Self : ...
6870 def drop_nulls (self , subset : Sequence [str ] | None ) -> Self : ...
6971 def estimated_size (self , unit : SizeUnit ) -> int | float : ...
70- def filter (self , predicate : Any ) -> Self : ...
72+ def filter (self , predicate : CompliantExprT_contra | Any ) -> Self : ...
7173 def gather_every (self , n : int , offset : int ) -> Self : ...
7274 def get_column (self , name : str ) -> CompliantSeriesT : ...
7375 def group_by (self , * keys : str , drop_null_keys : bool ) -> Any : ...
@@ -112,7 +114,7 @@ def sample(
112114 with_replacement : bool ,
113115 seed : int | None ,
114116 ) -> Self : ...
115- def select (self , * exprs : Any ) -> Self : ...
117+ def select (self , * exprs : CompliantExprT_contra ) -> Self : ...
116118 def sort (
117119 self , * by : str , descending : bool | Sequence [bool ], nulls_last : bool
118120 ) -> Self : ...
@@ -142,7 +144,7 @@ def unpivot(
142144 variable_name : str ,
143145 value_name : str ,
144146 ) -> Self : ...
145- def with_columns (self , * exprs : Any ) -> Self : ...
147+ def with_columns (self , * exprs : CompliantExprT_contra ) -> Self : ...
146148 def with_row_index (self , name : str ) -> Self : ...
147149 @overload
148150 def write_csv (self , file : None ) -> str : ...
@@ -169,10 +171,11 @@ def schema(self) -> Mapping[str, DType]: ...
169171 def _iter_columns (self ) -> Iterator [Any ]: ...
170172
171173
172- class EagerDataFrame (CompliantDataFrame [EagerSeriesT ], Protocol [EagerSeriesT ]):
173- def _maybe_evaluate_expr (
174- self , expr : EagerExpr [Self , EagerSeriesT ] | T , /
175- ) -> EagerSeriesT | T :
174+ class EagerDataFrame (
175+ CompliantDataFrame [EagerSeriesT , EagerExprT_contra ],
176+ Protocol [EagerSeriesT , EagerExprT_contra ],
177+ ):
178+ def _maybe_evaluate_expr (self , expr : EagerExprT_contra | T , / ) -> EagerSeriesT | T :
176179 if is_eager_expr (expr ):
177180 result : Sequence [EagerSeriesT ] = expr (self )
178181 if len (result ) > 1 :
@@ -184,14 +187,10 @@ def _maybe_evaluate_expr(
184187 return result [0 ]
185188 return expr
186189
187- def _evaluate_into_exprs (
188- self , * exprs : EagerExpr [Self , EagerSeriesT ]
189- ) -> Sequence [EagerSeriesT ]:
190+ def _evaluate_into_exprs (self , * exprs : EagerExprT_contra ) -> Sequence [EagerSeriesT ]:
190191 return list (chain .from_iterable (self ._evaluate_into_expr (expr ) for expr in exprs ))
191192
192- def _evaluate_into_expr (
193- self , expr : EagerExpr [Self , EagerSeriesT ], /
194- ) -> Sequence [EagerSeriesT ]:
193+ def _evaluate_into_expr (self , expr : EagerExprT_contra , / ) -> Sequence [EagerSeriesT ]:
195194 """Return list of raw columns.
196195
197196 For eager backends we alias operations at each step.
0 commit comments