1010from narwhals ._plan .typing import (
1111 IntoExpr ,
1212 NativeDataFrameT ,
13+ NativeDataFrameT_co ,
1314 NativeFrameT ,
15+ NativeFrameT_co ,
1416 NativeSeriesT ,
1517 OneOrIterable ,
1618)
2123if TYPE_CHECKING :
2224 from collections .abc import Sequence
2325
24- import pyarrow as pa
2526 from typing_extensions import Self
2627
2728 from narwhals ._plan .compliant .dataframe import CompliantDataFrame , CompliantFrame
28- from narwhals .typing import NativeFrame
2929
3030
31- class BaseFrame (Generic [NativeFrameT ]):
32- _compliant : CompliantFrame [Any , NativeFrameT ]
31+ class BaseFrame (Generic [NativeFrameT_co ]):
32+ _compliant : CompliantFrame [Any , NativeFrameT_co ]
3333 _version : ClassVar [Version ] = Version .MAIN
3434
3535 @property
@@ -47,30 +47,26 @@ def columns(self) -> list[str]:
4747 def __repr__ (self ) -> str : # pragma: no cover
4848 return generate_repr (f"nw.{ type (self ).__name__ } " , self .to_native ().__repr__ ())
4949
50- @classmethod
51- def from_native (cls , native : Any , / ) -> Self :
52- raise NotImplementedError
50+ def __init__ (self , compliant : Any , / ) -> None :
51+ self ._compliant = compliant
5352
54- @classmethod
55- def _from_compliant (cls , compliant : CompliantFrame [Any , NativeFrameT ], / ) -> Self :
56- obj = cls .__new__ (cls )
57- obj ._compliant = compliant
58- return obj
53+ def _with_compliant (self , compliant : CompliantFrame [Any , NativeFrameT ], / ) -> Self :
54+ return type (self )(compliant )
5955
60- def to_native (self ) -> NativeFrameT :
56+ def to_native (self ) -> NativeFrameT_co :
6157 return self ._compliant .native
6258
6359 def select (self , * exprs : OneOrIterable [IntoExpr ], ** named_exprs : Any ) -> Self :
6460 named_irs , schema = prepare_projection (
6561 _parse .parse_into_seq_of_expr_ir (* exprs , ** named_exprs ), schema = self
6662 )
67- return self ._from_compliant (self ._compliant .select (schema .select_irs (named_irs )))
63+ return self ._with_compliant (self ._compliant .select (schema .select_irs (named_irs )))
6864
6965 def with_columns (self , * exprs : OneOrIterable [IntoExpr ], ** named_exprs : Any ) -> Self :
7066 named_irs , schema = prepare_projection (
7167 _parse .parse_into_seq_of_expr_ir (* exprs , ** named_exprs ), schema = self
7268 )
73- return self ._from_compliant (
69+ return self ._with_compliant (
7470 self ._compliant .with_columns (schema .with_columns_irs (named_irs ))
7571 )
7672
@@ -85,32 +81,33 @@ def sort(
8581 by , * more_by , descending = descending , nulls_last = nulls_last
8682 )
8783 named_irs , _ = prepare_projection (sort , schema = self )
88- return self ._from_compliant (self ._compliant .sort (named_irs , opts ))
84+ return self ._with_compliant (self ._compliant .sort (named_irs , opts ))
8985
9086 def drop (self , columns : Sequence [str ], * , strict : bool = True ) -> Self :
91- return self ._from_compliant (self ._compliant .drop (columns , strict = strict ))
87+ return self ._with_compliant (self ._compliant .drop (columns , strict = strict ))
9288
9389 def drop_nulls (self , subset : str | Sequence [str ] | None = None ) -> Self :
9490 subset = [subset ] if isinstance (subset , str ) else subset
95- return self ._from_compliant (self ._compliant .drop_nulls (subset ))
91+ return self ._with_compliant (self ._compliant .drop_nulls (subset ))
9692
9793
98- class DataFrame (BaseFrame [NativeDataFrameT ], Generic [NativeDataFrameT , NativeSeriesT ]):
99- _compliant : CompliantDataFrame [Any , NativeDataFrameT , NativeSeriesT ]
94+ class DataFrame (
95+ BaseFrame [NativeDataFrameT_co ], Generic [NativeDataFrameT_co , NativeSeriesT ]
96+ ):
97+ _compliant : CompliantDataFrame [Any , NativeDataFrameT_co , NativeSeriesT ]
10098
10199 @property
102100 def _series (self ) -> type [Series [NativeSeriesT ]]:
103101 return Series [NativeSeriesT ]
104102
105- # NOTE: Gave up on trying to get typing working for now
106103 @classmethod
107- def from_native ( # type: ignore[override]
108- cls , native : NativeFrame , /
109- ) -> DataFrame [pa . Table , pa . ChunkedArray [ Any ] ]:
104+ def from_native (
105+ cls : type [ DataFrame [ Any , Any ]], native : NativeDataFrameT , /
106+ ) -> DataFrame [NativeDataFrameT ]:
110107 if is_pyarrow_table (native ):
111108 from narwhals ._plan .arrow .dataframe import ArrowDataFrame
112109
113- return ArrowDataFrame .from_native (native , cls ._version ). to_narwhals ( )
110+ return cls ( ArrowDataFrame .from_native (native , cls ._version ))
114111
115112 raise NotImplementedError (type (native ))
116113
@@ -129,7 +126,7 @@ def to_dict(
129126 ) -> dict [str , Series [NativeSeriesT ]] | dict [str , list [Any ]]:
130127 if as_series :
131128 return {
132- key : self ._series . _from_compliant (value )
129+ key : self ._series (value )
133130 for key , value in self ._compliant .to_dict (as_series = as_series ).items ()
134131 }
135132 return self ._compliant .to_dict (as_series = as_series )
0 commit comments