Skip to content

Commit dabe5d4

Browse files
authored
Merge branch 'main' into from-numpy-2d-ns
2 parents 7a137b3 + 4053603 commit dabe5d4

File tree

14 files changed

+318
-236
lines changed

14 files changed

+318
-236
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,8 @@ def collect(
691691
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
692692
raise AssertionError(msg) # pragma: no cover
693693

694-
clone = not_implemented()
694+
def clone(self) -> Self:
695+
return self._from_native_frame(self.native, validate_column_names=False)
695696

696697
def item(self: Self, row: int | None, column: int | str | None) -> Any:
697698
from narwhals._arrow.series import maybe_extract_py_scalar

narwhals/_dask/dataframe.py

Lines changed: 42 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -71,34 +71,28 @@ def __narwhals_lazyframe__(self: Self) -> Self:
7171

7272
def _change_version(self: Self, version: Version) -> Self:
7373
return self.__class__(
74-
self._native_frame,
75-
backend_version=self._backend_version,
76-
version=version,
74+
self.native, backend_version=self._backend_version, version=version
7775
)
7876

7977
def _from_native_frame(self: Self, df: Any) -> Self:
8078
return self.__class__(
81-
df,
82-
backend_version=self._backend_version,
83-
version=self._version,
79+
df, backend_version=self._backend_version, version=self._version
8480
)
8581

8682
def _iter_columns(self) -> Iterator[dx.Series]:
87-
for _col, ser in self._native_frame.items(): # noqa: PERF102
83+
for _col, ser in self.native.items(): # noqa: PERF102
8884
yield ser
8985

9086
def with_columns(self: Self, *exprs: DaskExpr) -> Self:
91-
df = self._native_frame
9287
new_series = evaluate_exprs(self, *exprs)
93-
df = df.assign(**dict(new_series))
94-
return self._from_native_frame(df)
88+
return self._from_native_frame(self.native.assign(**dict(new_series)))
9589

9690
def collect(
9791
self: Self,
9892
backend: Implementation | None,
9993
**kwargs: Any,
10094
) -> CompliantDataFrame[Any, Any, Any]:
101-
result = self._native_frame.compute(**kwargs)
95+
result = self.native.compute(**kwargs)
10296

10397
if backend is None or backend is Implementation.PANDAS:
10498
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
@@ -143,14 +137,13 @@ def columns(self: Self) -> list[str]:
143137

144138
def filter(self: Self, predicate: DaskExpr) -> Self:
145139
# `[0]` is safe as the predicate's expression only returns a single column
146-
mask = predicate._call(self)[0]
147-
148-
return self._from_native_frame(self._native_frame.loc[mask])
140+
mask = predicate(self)[0]
141+
return self._from_native_frame(self.native.loc[mask])
149142

150143
def simple_select(self: Self, *column_names: str) -> Self:
151144
return self._from_native_frame(
152145
select_columns_by_name(
153-
self._native_frame,
146+
self.native,
154147
list(column_names),
155148
self._backend_version,
156149
self._implementation,
@@ -165,7 +158,7 @@ def aggregate(self: Self, *exprs: DaskExpr) -> Self:
165158
def select(self: Self, *exprs: DaskExpr) -> Self:
166159
new_series = evaluate_exprs(self, *exprs)
167160
df = select_columns_by_name(
168-
self._native_frame.assign(**dict(new_series)),
161+
self.native.assign(**dict(new_series)),
169162
[s[0] for s in new_series],
170163
self._backend_version,
171164
self._implementation,
@@ -174,19 +167,19 @@ def select(self: Self, *exprs: DaskExpr) -> Self:
174167

175168
def drop_nulls(self: Self, subset: Sequence[str] | None) -> Self:
176169
if subset is None:
177-
return self._from_native_frame(self._native_frame.dropna())
170+
return self._from_native_frame(self.native.dropna())
178171
plx = self.__narwhals_namespace__()
179172
return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))
180173

181174
@property
182175
def schema(self: Self) -> dict[str, DType]:
183176
if self._cached_schema is None:
184-
native_dtypes = self._native_frame.dtypes
177+
native_dtypes = self.native.dtypes
185178
self._cached_schema = {
186179
col: native_to_narwhals_dtype(
187180
native_dtypes[col], self._version, self._implementation
188181
)
189-
for col in self._native_frame.columns
182+
for col in self.native.columns
190183
}
191184
return self._cached_schema
192185

@@ -198,23 +191,21 @@ def drop(self: Self, columns: Sequence[str], *, strict: bool) -> Self:
198191
compliant_frame=self, columns=columns, strict=strict
199192
)
200193

201-
return self._from_native_frame(self._native_frame.drop(columns=to_drop))
194+
return self._from_native_frame(self.native.drop(columns=to_drop))
202195

203196
def with_row_index(self: Self, name: str) -> Self:
204197
# Implementation is based on the following StackOverflow reply:
205198
# https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
206199
return self._from_native_frame(
207-
add_row_index(
208-
self._native_frame, name, self._backend_version, self._implementation
209-
)
200+
add_row_index(self.native, name, self._backend_version, self._implementation)
210201
)
211202

212203
def rename(self: Self, mapping: Mapping[str, str]) -> Self:
213-
return self._from_native_frame(self._native_frame.rename(columns=mapping))
204+
return self._from_native_frame(self.native.rename(columns=mapping))
214205

215206
def head(self: Self, n: int) -> Self:
216207
return self._from_native_frame(
217-
self._native_frame.head(n=n, compute=False, npartitions=-1)
208+
self.native.head(n=n, compute=False, npartitions=-1)
218209
)
219210

220211
def unique(
@@ -224,17 +215,16 @@ def unique(
224215
keep: Literal["any", "none"],
225216
) -> Self:
226217
check_column_exists(self.columns, subset)
227-
native_frame = self._native_frame
228218
if keep == "none":
229219
subset = subset or self.columns
230220
token = generate_temporary_column_name(n_bytes=8, columns=subset)
231-
ser = native_frame.groupby(subset).size().rename(token)
221+
ser = self.native.groupby(subset).size().rename(token)
232222
ser = ser[ser == 1]
233223
unique = ser.reset_index().drop(columns=token)
234-
result = native_frame.merge(unique, on=subset, how="inner")
224+
result = self.native.merge(unique, on=subset, how="inner")
235225
else:
236226
mapped_keep = {"any": "first"}.get(keep, keep)
237-
result = native_frame.drop_duplicates(subset=subset, keep=mapped_keep)
227+
result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
238228
return self._from_native_frame(result)
239229

240230
def sort(
@@ -243,14 +233,13 @@ def sort(
243233
descending: bool | Sequence[bool],
244234
nulls_last: bool,
245235
) -> Self:
246-
df = self._native_frame
247236
if isinstance(descending, bool):
248237
ascending: bool | list[bool] = not descending
249238
else:
250239
ascending = [not d for d in descending]
251-
na_position = "last" if nulls_last else "first"
240+
position = "last" if nulls_last else "first"
252241
return self._from_native_frame(
253-
df.sort_values(list(by), ascending=ascending, na_position=na_position)
242+
self.native.sort_values(list(by), ascending=ascending, na_position=position)
254243
)
255244

256245
def join(
@@ -268,15 +257,15 @@ def join(
268257
)
269258

270259
return self._from_native_frame(
271-
self._native_frame.assign(**{key_token: 0})
260+
self.native.assign(**{key_token: 0})
272261
.merge(
273-
other._native_frame.assign(**{key_token: 0}),
262+
other.native.assign(**{key_token: 0}),
274263
how="inner",
275264
left_on=key_token,
276265
right_on=key_token,
277266
suffixes=("", suffix),
278267
)
279-
.drop(columns=key_token),
268+
.drop(columns=key_token)
280269
)
281270

282271
if how == "anti":
@@ -289,7 +278,7 @@ def join(
289278
raise TypeError(msg)
290279
other_native = (
291280
select_columns_by_name(
292-
other._native_frame,
281+
other.native,
293282
list(right_on),
294283
self._backend_version,
295284
self._implementation,
@@ -299,7 +288,7 @@ def join(
299288
)
300289
.drop_duplicates()
301290
)
302-
df = self._native_frame.merge(
291+
df = self.native.merge(
303292
other_native,
304293
how="outer",
305294
indicator=indicator_token, # pyright: ignore[reportArgumentType]
@@ -316,7 +305,7 @@ def join(
316305
raise TypeError(msg)
317306
other_native = (
318307
select_columns_by_name(
319-
other._native_frame,
308+
other.native,
320309
list(right_on),
321310
self._backend_version,
322311
self._implementation,
@@ -327,18 +316,14 @@ def join(
327316
.drop_duplicates() # avoids potential rows duplication from inner join
328317
)
329318
return self._from_native_frame(
330-
self._native_frame.merge(
331-
other_native,
332-
how="inner",
333-
left_on=left_on,
334-
right_on=left_on,
319+
self.native.merge(
320+
other_native, how="inner", left_on=left_on, right_on=left_on
335321
)
336322
)
337323

338324
if how == "left":
339-
other_native = other._native_frame
340-
result_native = self._native_frame.merge(
341-
other_native,
325+
result_native = self.native.merge(
326+
other.native,
342327
how="left",
343328
left_on=left_on,
344329
right_on=right_on,
@@ -361,29 +346,27 @@ def join(
361346
assert right_on is not None # noqa: S101
362347

363348
right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
364-
365-
other_native = other._native_frame
366-
other_native = other_native.rename(columns=right_on_mapper)
349+
other_native = other.native.rename(columns=right_on_mapper)
367350
check_column_names_are_unique(other_native.columns)
368351
right_on = list(right_on_mapper.values()) # we now have the suffixed keys
369352
return self._from_native_frame(
370-
self._native_frame.merge(
353+
self.native.merge(
371354
other_native,
372355
left_on=left_on,
373356
right_on=right_on,
374357
how="outer",
375358
suffixes=("", suffix),
376-
),
359+
)
377360
)
378361

379362
return self._from_native_frame(
380-
self._native_frame.merge(
381-
other._native_frame,
363+
self.native.merge(
364+
other.native,
382365
left_on=left_on,
383366
right_on=right_on,
384367
how=how,
385368
suffixes=("", suffix),
386-
),
369+
)
387370
)
388371

389372
def join_asof(
@@ -400,8 +383,8 @@ def join_asof(
400383
plx = self.__native_namespace__()
401384
return self._from_native_frame(
402385
plx.merge_asof(
403-
self._native_frame,
404-
other._native_frame,
386+
self.native,
387+
other.native,
405388
left_on=left_on,
406389
right_on=right_on,
407390
left_by=by_left,
@@ -417,11 +400,11 @@ def group_by(self: Self, *by: str, drop_null_keys: bool) -> DaskLazyGroupBy:
417400
return DaskLazyGroupBy(self, by, drop_null_keys=drop_null_keys)
418401

419402
def tail(self: Self, n: int) -> Self: # pragma: no cover
420-
native_frame = self._native_frame
403+
native_frame = self.native
421404
n_partitions = native_frame.npartitions
422405

423406
if n_partitions == 1:
424-
return self._from_native_frame(self._native_frame.tail(n=n, compute=False))
407+
return self._from_native_frame(self.native.tail(n=n, compute=False))
425408
else:
426409
msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
427410
raise NotImplementedError(msg)
@@ -446,7 +429,7 @@ def unpivot(
446429
value_name: str,
447430
) -> Self:
448431
return self._from_native_frame(
449-
self._native_frame.melt(
432+
self.native.melt(
450433
id_vars=index,
451434
value_vars=on,
452435
var_name=variable_name,

0 commit comments

Comments
 (0)