 
 from narwhals._dask.utils import add_row_index
 from narwhals._dask.utils import evaluate_exprs
-from narwhals._pandas_like.utils import check_column_names_are_unique
 from narwhals._pandas_like.utils import native_to_narwhals_dtype
 from narwhals._pandas_like.utils import select_columns_by_name
 from narwhals.typing import CompliantDataFrame
@@ -41,15 +40,14 @@ def __init__(
         *,
         backend_version: tuple[int, ...],
         version: Version,
-        validate_column_names: bool,
+        # Unused, just for compatibility. We only validate when collecting.
+        validate_column_names: bool = False,
     ) -> None:
         self._native_frame: dd.DataFrame = native_dataframe
         self._backend_version = backend_version
         self._implementation = Implementation.DASK
         self._version = version
         validate_backend_version(self._implementation, self._backend_version)
-        if validate_column_names:
-            check_column_names_are_unique(native_dataframe.columns)
 
     def __native_namespace__(self: Self) -> ModuleType:
         if self._implementation is Implementation.DASK:
@@ -71,23 +69,19 @@ def _change_version(self: Self, version: Version) -> Self:
             self._native_frame,
             backend_version=self._backend_version,
             version=version,
-            validate_column_names=False,
         )
 
-    def _from_native_frame(
-        self: Self, df: Any, *, validate_column_names: bool = True
-    ) -> Self:
+    def _from_native_frame(self: Self, df: Any) -> Self:
         return self.__class__(
             df,
             backend_version=self._backend_version,
             version=self._version,
-            validate_column_names=validate_column_names,
         )
 
     def with_columns(self: Self, *exprs: DaskExpr) -> Self:
         df = self._native_frame
         new_series = evaluate_exprs(self, *exprs)
-        df = df.assign(**new_series)
+        df = df.assign(**dict(new_series))
         return self._from_native_frame(df)
 
     def collect(
@@ -107,7 +101,7 @@ def collect(
                 implementation=Implementation.PANDAS,
                 backend_version=parse_version(pd),
                 version=self._version,
-                validate_column_names=False,
+                validate_column_names=True,
             )
 
         if backend is Implementation.POLARS:
@@ -130,7 +124,7 @@ def collect(
                 pa.Table.from_pandas(result),
                 backend_version=parse_version(pa),
                 version=self._version,
-                validate_column_names=False,
+                validate_column_names=True,
             )
 
         msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
@@ -144,9 +138,7 @@ def filter(self: Self, predicate: DaskExpr) -> Self:
         # `[0]` is safe as the predicate's expression only returns a single column
         mask = predicate._call(self)[0]
 
-        return self._from_native_frame(
-            self._native_frame.loc[mask], validate_column_names=False
-        )
+        return self._from_native_frame(self._native_frame.loc[mask])
 
     def simple_select(self: Self, *column_names: str) -> Self:
         return self._from_native_frame(
@@ -156,13 +148,12 @@ def simple_select(self: Self, *column_names: str) -> Self:
                 self._backend_version,
                 self._implementation,
             ),
-            validate_column_names=False,
         )
 
     def aggregate(self: Self, *exprs: DaskExpr) -> Self:
         new_series = evaluate_exprs(self, *exprs)
-        df = dd.concat([val.rename(name) for name, val in new_series.items()], axis=1)
-        return self._from_native_frame(df, validate_column_names=False)
+        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
+        return self._from_native_frame(df)
 
     def select(self: Self, *exprs: DaskExpr) -> Self:
         new_series = evaluate_exprs(self, *exprs)
@@ -173,22 +164,19 @@ def select(self: Self, *exprs: DaskExpr) -> Self:
                 dd.from_pandas(
                     pd.DataFrame(), npartitions=self._native_frame.npartitions
                 ),
-                validate_column_names=False,
             )
 
         df = select_columns_by_name(
-            self._native_frame.assign(**new_series),
-            list(new_series.keys()),
+            self._native_frame.assign(**dict(new_series)),
+            [s[0] for s in new_series],
             self._backend_version,
             self._implementation,
         )
-        return self._from_native_frame(df, validate_column_names=False)
+        return self._from_native_frame(df)
 
     def drop_nulls(self: Self, subset: list[str] | None) -> Self:
         if subset is None:
-            return self._from_native_frame(
-                self._native_frame.dropna(), validate_column_names=False
-            )
+            return self._from_native_frame(self._native_frame.dropna())
         plx = self.__narwhals_namespace__()
         return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))
 
@@ -210,9 +198,7 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
             compliant_frame=self, columns=columns, strict=strict
         )
 
-        return self._from_native_frame(
-            self._native_frame.drop(columns=to_drop), validate_column_names=False
-        )
+        return self._from_native_frame(self._native_frame.drop(columns=to_drop))
 
     def with_row_index(self: Self, name: str) -> Self:
         # Implementation is based on the following StackOverflow reply:
@@ -228,8 +214,7 @@ def rename(self: Self, mapping: dict[str, str]) -> Self:
 
     def head(self: Self, n: int) -> Self:
         return self._from_native_frame(
-            self._native_frame.head(n=n, compute=False, npartitions=-1),
-            validate_column_names=False,
+            self._native_frame.head(n=n, compute=False, npartitions=-1)
         )
 
     def unique(
@@ -250,7 +235,7 @@ def unique(
         else:
             mapped_keep = {"any": "first"}.get(keep, keep)
             result = native_frame.drop_duplicates(subset=subset, keep=mapped_keep)
-        return self._from_native_frame(result, validate_column_names=False)
+        return self._from_native_frame(result)
 
     def sort(
         self: Self,
@@ -265,8 +250,7 @@ def sort(
             ascending = [not d for d in descending]
         na_position = "last" if nulls_last else "first"
         return self._from_native_frame(
-            df.sort_values(list(by), ascending=ascending, na_position=na_position),
-            validate_column_names=False,
+            df.sort_values(list(by), ascending=ascending, na_position=na_position)
         )
 
     def join(
@@ -413,9 +397,7 @@ def tail(self: Self, n: int) -> Self: # pragma: no cover
         n_partitions = native_frame.npartitions
 
         if n_partitions == 1:
-            return self._from_native_frame(
-                self._native_frame.tail(n=n, compute=False), validate_column_names=False
-            )
+            return self._from_native_frame(self._native_frame.tail(n=n, compute=False))
         else:
             msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
             raise NotImplementedError(msg)
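
A minimal sketch of the behaviour this diff targets, assuming narwhals with the Dask backend installed; the data, names, and the exact error raised are illustrative, not taken from this patch. Column names are no longer checked when the Dask wrapper is constructed or when lazy operations return new frames; the check only happens when the result is collected (where `validate_column_names=True` is now passed).

import pandas as pd
import dask.dataframe as dd
import narwhals as nw

ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2], "b": [3, 4]}), npartitions=1)
lf = nw.from_native(ddf)

# Renaming "b" to "a" produces duplicate column names, but with this change the
# Dask wrapper no longer validates names eagerly at this point...
lf = lf.rename({"b": "a"})

# ...the duplicate is only reported when the frame is materialised, since
# collect() builds the eager (pandas-backed) frame with validation enabled.
lf.collect()  # expected to raise a duplicate-column-names error here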