Skip to content

Commit 12afbfe

Browse files
authored
perf: only validate duplicate column names when collecting for duckdb/pyspark/dask (#2092)
1 parent 4517baa commit 12afbfe

File tree

15 files changed

+121
-166
lines changed

15 files changed

+121
-166
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def select(self: Self, *exprs: ArrowExpr) -> Self:
356356
names = [s.name for s in new_series]
357357
new_series = align_series_full_broadcast(*new_series)
358358
df = pa.Table.from_arrays([s._native_series for s in new_series], names=names)
359-
return self._from_native_frame(df, validate_column_names=False)
359+
return self._from_native_frame(df, validate_column_names=True)
360360

361361
def with_columns(self: Self, *exprs: ArrowExpr) -> Self:
362362
native_frame = self._native_frame

narwhals/_dask/dataframe.py

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from narwhals._dask.utils import add_row_index
1212
from narwhals._dask.utils import evaluate_exprs
13-
from narwhals._pandas_like.utils import check_column_names_are_unique
1413
from narwhals._pandas_like.utils import native_to_narwhals_dtype
1514
from narwhals._pandas_like.utils import select_columns_by_name
1615
from narwhals.typing import CompliantDataFrame
@@ -41,15 +40,14 @@ def __init__(
4140
*,
4241
backend_version: tuple[int, ...],
4342
version: Version,
44-
validate_column_names: bool,
43+
# Unused, just for compatibility. We only validate when collecting.
44+
validate_column_names: bool = False,
4545
) -> None:
4646
self._native_frame: dd.DataFrame = native_dataframe
4747
self._backend_version = backend_version
4848
self._implementation = Implementation.DASK
4949
self._version = version
5050
validate_backend_version(self._implementation, self._backend_version)
51-
if validate_column_names:
52-
check_column_names_are_unique(native_dataframe.columns)
5351

5452
def __native_namespace__(self: Self) -> ModuleType:
5553
if self._implementation is Implementation.DASK:
@@ -71,23 +69,19 @@ def _change_version(self: Self, version: Version) -> Self:
7169
self._native_frame,
7270
backend_version=self._backend_version,
7371
version=version,
74-
validate_column_names=False,
7572
)
7673

77-
def _from_native_frame(
78-
self: Self, df: Any, *, validate_column_names: bool = True
79-
) -> Self:
74+
def _from_native_frame(self: Self, df: Any) -> Self:
8075
return self.__class__(
8176
df,
8277
backend_version=self._backend_version,
8378
version=self._version,
84-
validate_column_names=validate_column_names,
8579
)
8680

8781
def with_columns(self: Self, *exprs: DaskExpr) -> Self:
8882
df = self._native_frame
8983
new_series = evaluate_exprs(self, *exprs)
90-
df = df.assign(**new_series)
84+
df = df.assign(**dict(new_series))
9185
return self._from_native_frame(df)
9286

9387
def collect(
@@ -107,7 +101,7 @@ def collect(
107101
implementation=Implementation.PANDAS,
108102
backend_version=parse_version(pd),
109103
version=self._version,
110-
validate_column_names=False,
104+
validate_column_names=True,
111105
)
112106

113107
if backend is Implementation.POLARS:
@@ -130,7 +124,7 @@ def collect(
130124
pa.Table.from_pandas(result),
131125
backend_version=parse_version(pa),
132126
version=self._version,
133-
validate_column_names=False,
127+
validate_column_names=True,
134128
)
135129

136130
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
@@ -144,9 +138,7 @@ def filter(self: Self, predicate: DaskExpr) -> Self:
144138
# `[0]` is safe as the predicate's expression only returns a single column
145139
mask = predicate._call(self)[0]
146140

147-
return self._from_native_frame(
148-
self._native_frame.loc[mask], validate_column_names=False
149-
)
141+
return self._from_native_frame(self._native_frame.loc[mask])
150142

151143
def simple_select(self: Self, *column_names: str) -> Self:
152144
return self._from_native_frame(
@@ -156,13 +148,12 @@ def simple_select(self: Self, *column_names: str) -> Self:
156148
self._backend_version,
157149
self._implementation,
158150
),
159-
validate_column_names=False,
160151
)
161152

162153
def aggregate(self: Self, *exprs: DaskExpr) -> Self:
163154
new_series = evaluate_exprs(self, *exprs)
164-
df = dd.concat([val.rename(name) for name, val in new_series.items()], axis=1)
165-
return self._from_native_frame(df, validate_column_names=False)
155+
df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
156+
return self._from_native_frame(df)
166157

167158
def select(self: Self, *exprs: DaskExpr) -> Self:
168159
new_series = evaluate_exprs(self, *exprs)
@@ -173,22 +164,19 @@ def select(self: Self, *exprs: DaskExpr) -> Self:
173164
dd.from_pandas(
174165
pd.DataFrame(), npartitions=self._native_frame.npartitions
175166
),
176-
validate_column_names=False,
177167
)
178168

179169
df = select_columns_by_name(
180-
self._native_frame.assign(**new_series),
181-
list(new_series.keys()),
170+
self._native_frame.assign(**dict(new_series)),
171+
[s[0] for s in new_series],
182172
self._backend_version,
183173
self._implementation,
184174
)
185-
return self._from_native_frame(df, validate_column_names=False)
175+
return self._from_native_frame(df)
186176

187177
def drop_nulls(self: Self, subset: list[str] | None) -> Self:
188178
if subset is None:
189-
return self._from_native_frame(
190-
self._native_frame.dropna(), validate_column_names=False
191-
)
179+
return self._from_native_frame(self._native_frame.dropna())
192180
plx = self.__narwhals_namespace__()
193181
return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))
194182

@@ -210,9 +198,7 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
210198
compliant_frame=self, columns=columns, strict=strict
211199
)
212200

213-
return self._from_native_frame(
214-
self._native_frame.drop(columns=to_drop), validate_column_names=False
215-
)
201+
return self._from_native_frame(self._native_frame.drop(columns=to_drop))
216202

217203
def with_row_index(self: Self, name: str) -> Self:
218204
# Implementation is based on the following StackOverflow reply:
@@ -228,8 +214,7 @@ def rename(self: Self, mapping: dict[str, str]) -> Self:
228214

229215
def head(self: Self, n: int) -> Self:
230216
return self._from_native_frame(
231-
self._native_frame.head(n=n, compute=False, npartitions=-1),
232-
validate_column_names=False,
217+
self._native_frame.head(n=n, compute=False, npartitions=-1)
233218
)
234219

235220
def unique(
@@ -250,7 +235,7 @@ def unique(
250235
else:
251236
mapped_keep = {"any": "first"}.get(keep, keep)
252237
result = native_frame.drop_duplicates(subset=subset, keep=mapped_keep)
253-
return self._from_native_frame(result, validate_column_names=False)
238+
return self._from_native_frame(result)
254239

255240
def sort(
256241
self: Self,
@@ -265,8 +250,7 @@ def sort(
265250
ascending = [not d for d in descending]
266251
na_position = "last" if nulls_last else "first"
267252
return self._from_native_frame(
268-
df.sort_values(list(by), ascending=ascending, na_position=na_position),
269-
validate_column_names=False,
253+
df.sort_values(list(by), ascending=ascending, na_position=na_position)
270254
)
271255

272256
def join(
@@ -413,9 +397,7 @@ def tail(self: Self, n: int) -> Self: # pragma: no cover
413397
n_partitions = native_frame.npartitions
414398

415399
if n_partitions == 1:
416-
return self._from_native_frame(
417-
self._native_frame.tail(n=n, compute=False), validate_column_names=False
418-
)
400+
return self._from_native_frame(self._native_frame.tail(n=n, compute=False))
419401
else:
420402
msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
421403
raise NotImplementedError(msg)

narwhals/_dask/group_by.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def _from_native_frame(self: Self, df: dd.DataFrame) -> DaskLazyFrame:
9999
df,
100100
backend_version=self._df._backend_version,
101101
version=self._df._version,
102-
validate_column_names=True,
103102
)
104103

105104

narwhals/_dask/namespace.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def concat(
192192
dd.concat(dfs, axis=0, join="inner"),
193193
backend_version=self._backend_version,
194194
version=self._version,
195-
validate_column_names=True,
196195
)
197196
if how == "horizontal":
198197
all_column_names: list[str] = [
@@ -211,14 +210,12 @@ def concat(
211210
dd.concat(dfs, axis=1, join="outer"),
212211
backend_version=self._backend_version,
213212
version=self._version,
214-
validate_column_names=True,
215213
)
216214
if how == "diagonal":
217215
return DaskLazyFrame(
218216
dd.concat(dfs, axis=0, join="outer"),
219217
backend_version=self._backend_version,
220218
version=self._version,
221-
validate_column_names=True,
222219
)
223220

224221
raise NotImplementedError

narwhals/_dask/utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,15 @@ def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series
3939
return obj
4040

4141

42-
def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> dict[str, dx.Series]:
43-
native_results: dict[str, dx.Series] = {}
42+
def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]:
43+
native_results: list[tuple[str, dx.Series]] = []
4444
for expr in exprs:
4545
native_series_list = expr._call(df)
4646
_, aliases = evaluate_output_names_and_aliases(expr, df, [])
4747
if len(aliases) != len(native_series_list): # pragma: no cover
4848
msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results"
4949
raise AssertionError(msg)
50-
native_results.update(
51-
{
52-
alias: native_series
53-
for native_series, alias in zip(native_series_list, aliases)
54-
}
55-
)
50+
native_results.extend(zip(aliases, native_series_list))
5651
return native_results
5752

5853

0 commit comments

Comments (0)