34 changes: 17 additions & 17 deletions narwhals/_spark_like/dataframe.py
@@ -250,7 +250,7 @@ def collect(
         raise ValueError(msg)  # pragma: no cover

     def simple_select(self: Self, *column_names: str) -> Self:
-        return self._from_native_frame(self._native_frame.select(*column_names))
+        return self._from_native_frame(self._native_frame.select(*column_names))  # pyright: ignore[reportArgumentType]

     def aggregate(
         self: Self,
@@ -259,7 +259,7 @@ def aggregate(
         new_columns = evaluate_exprs(self, *exprs)

         new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
-        return self._from_native_frame(self._native_frame.agg(*new_columns_list))
+        return self._from_native_frame(self._native_frame.agg(*new_columns_list))  # pyright: ignore[reportArgumentType]

     def select(
         self: Self,
@@ -274,17 +274,17 @@ def select(
             return self._from_native_frame(spark_df)

         new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))
+        return self._from_native_frame(self._native_frame.select(*new_columns_list))  # pyright: ignore[reportArgumentType]

     def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
         new_columns = evaluate_exprs(self, *exprs)
-        return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))
+        return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))  # pyright: ignore[reportArgumentType]

     def filter(self: Self, predicate: SparkLikeExpr) -> Self:
         # `[0]` is safe as the predicate's expression only returns a single column
         condition = predicate._call(self)[0]
-        spark_df = self._native_frame.where(condition)
-        return self._from_native_frame(spark_df)
+        spark_df = self._native_frame.where(condition)  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(spark_df)  # pyright: ignore[reportArgumentType]

     @property
     def schema(self: Self) -> dict[str, DType]:
@@ -307,10 +307,10 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
         )
-        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))
+        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))  # pyright: ignore[reportArgumentType]

     def head(self: Self, n: int) -> Self:
-        return self._from_native_frame(self._native_frame.limit(num=n))
+        return self._from_native_frame(self._native_frame.limit(num=n))  # pyright: ignore[reportArgumentType]

     def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
         from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
@@ -340,18 +340,18 @@ def sort(
         )

         sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
-        return self._from_native_frame(self._native_frame.sort(*sort_cols))
+        return self._from_native_frame(self._native_frame.sort(*sort_cols))  # pyright: ignore[reportArgumentType]

     def drop_nulls(self: Self, subset: list[str] | None) -> Self:
-        return self._from_native_frame(self._native_frame.dropna(subset=subset))
+        return self._from_native_frame(self._native_frame.dropna(subset=subset))  # pyright: ignore[reportArgumentType]

     def rename(self: Self, mapping: dict[str, str]) -> Self:
         rename_mapping = {
             colname: mapping.get(colname, colname) for colname in self.columns
         }
         return self._from_native_frame(
             self._native_frame.select(
-                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
+                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]  # pyright: ignore[reportArgumentType]
             )
         )

@@ -365,7 +365,7 @@ def unique(
             msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
             raise ValueError(msg)
         check_column_exists(self.columns, subset)
-        return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
+        return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))  # pyright: ignore[reportArgumentType]

     def join(
         self: Self,
@@ -409,7 +409,7 @@ def join(
             ]
         )
         return self._from_native_frame(
-            self_native.join(other_native, on=left_on, how=how).select(col_order)
+            self_native.join(other_native, on=left_on, how=how).select(col_order)  # pyright: ignore[reportArgumentType]
         )

     def explode(self: Self, columns: list[str]) -> Self:
@@ -445,7 +445,7 @@ def explode(self: Self, columns: list[str]) -> Self:
                         else self._F.explode_outer(col_name).alias(col_name)
                         for col_name in column_names
                     ]
-                ),
+                ),  # pyright: ignore[reportArgumentType]
             )
         elif self._implementation.is_sqlframe():
             # Not every sqlframe dialect supports `explode_outer` function
@@ -466,14 +466,14 @@ def null_condition(col_name: str) -> Column:
                         for col_name in column_names
                     ]
                 ).union(
-                    native_frame.filter(null_condition(columns[0])).select(
+                    native_frame.filter(null_condition(columns[0])).select(  # pyright: ignore[reportArgumentType]
                         *[
                             self._F.col(col_name).alias(col_name)
                             if col_name != columns[0]
                             else self._F.lit(None).alias(col_name)
                             for col_name in column_names
                         ]
-                    )
+                    )  # pyright: ignore[reportArgumentType]
                 ),
             )
         else:  # pragma: no cover
@@ -508,4 +508,4 @@ def unpivot(
         )
         if index is None:
             unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
-        return self._from_native_frame(unpivoted_native_frame)
+        return self._from_native_frame(unpivoted_native_frame)  # pyright: ignore[reportArgumentType]
2 changes: 1 addition & 1 deletion narwhals/_spark_like/group_by.py
@@ -50,7 +50,7 @@ def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:

         if not agg_columns:
             return self._compliant_frame._from_native_frame(
-                self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
+                self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()  # pyright: ignore[reportArgumentType]
             )
         return self._compliant_frame._from_native_frame(
             self._compliant_frame._native_frame.groupBy(*self._keys).agg(*agg_columns)
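Every hunk in this PR applies the same fix: a per-line pyright suppression. The comment `# pyright: ignore[reportArgumentType]` silences exactly one named diagnostic rule on exactly that line; other diagnostics on the same line, and all checks elsewhere in the file, still fire. A minimal self-contained sketch of the mechanism (illustrative names, not from this PR):

from __future__ import annotations


def takes_int(x: int) -> int:
    return x + 1


def get_value() -> int | str:
    # Declared as a union; this particular call happens to return an int.
    return 3


# Without the trailing comment, pyright reports on the next line:
#   Argument of type "int | str" cannot be assigned to parameter "x"
#   of type "int" (reportArgumentType)
# The scoped ignore silences only that rule, only here; any other error
# on the same line would still be reported.
result = takes_int(get_value())  # pyright: ignore[reportArgumentType]
print(result)

Unlike a blanket `# type: ignore`, the bracketed form names the single rule being waived, so the rest of pyright's coverage on that line stays intact.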
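Why these particular arguments fail to type-check is not stated in the diff, but the `is_sqlframe()` branch in `explode` shows that the wrapped `_native_frame` can be one of two backend DataFrame types (PySpark or SQLFrame). A toy reproduction of how such a union trips `reportArgumentType`; every class and function name below is hypothetical, for illustration only:

from __future__ import annotations


class PySparkDF:
    # Stand-in for pyspark.sql.DataFrame (hypothetical).
    def select(self, *cols: str) -> PySparkDF:
        return self


class SQLFrameDF:
    # Stand-in for sqlframe's DataFrame (hypothetical).
    def select(self, *cols: str) -> SQLFrameDF:
        return self


def from_native(df: PySparkDF) -> PySparkDF:
    # A wrapper annotated against only one of the two backends (hypothetical).
    return df


native: PySparkDF | SQLFrameDF = PySparkDF()
# `native.select(...)` evaluates to `PySparkDF | SQLFrameDF`, which is not
# assignable to the `PySparkDF` parameter, so pyright flags the call even
# though it is fine at runtime, hence a suppression instead of a rewrite.
wrapped = from_native(native.select("a"))  # pyright: ignore[reportArgumentType]

Whether this is the exact cause in narwhals is an assumption; the diff shows only the suppressions, not the underlying signatures.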