Skip to content

Commit c743de3

Browse files
committed
chore(typing): Kinda type pandas_like.utils.select_columns_by_name`
- Somewhat of a resurrection of #2227 - But this time building on #2693
1 parent 5168fe0 commit c743de3

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

narwhals/_dask/dataframe.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from types import ModuleType
2525

2626
import dask.dataframe.dask_expr as dx
27-
from typing_extensions import Self, TypeIs
27+
from typing_extensions import Self, TypeAlias, TypeIs
2828

2929
from narwhals._compliant.typing import CompliantDataFrameAny
3030
from narwhals._dask.expr import DaskExpr
@@ -35,6 +35,13 @@
3535
from narwhals.dtypes import DType
3636
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
3737

38+
Incomplete: TypeAlias = "Any"
39+
"""Using `_pandas_like` utils with `_dask`.
40+
41+
Typing this correctly will complicate the `_pandas_like`-side.
42+
Very low priority until `dask` adds typing.
43+
"""
44+
3845

3946
class DaskLazyFrame(
4047
CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"]
@@ -158,8 +165,9 @@ def filter(self, predicate: DaskExpr) -> Self:
158165
return self._with_native(self.native.loc[mask])
159166

160167
def simple_select(self, *column_names: str) -> Self:
168+
df: Incomplete = self.native
161169
native = select_columns_by_name(
162-
self.native, list(column_names), self._backend_version, self._implementation
170+
df, list(column_names), self._backend_version, self._implementation
163171
)
164172
return self._with_native(native)
165173

@@ -170,8 +178,9 @@ def aggregate(self, *exprs: DaskExpr) -> Self:
170178

171179
def select(self, *exprs: DaskExpr) -> Self:
172180
new_series = evaluate_exprs(self, *exprs)
181+
df: Incomplete = self.native
173182
df = select_columns_by_name(
174-
self.native.assign(**dict(new_series)),
183+
df.assign(**dict(new_series)),
175184
[s[0] for s in new_series],
176185
self._backend_version,
177186
self._implementation,
@@ -269,6 +278,7 @@ def join( # noqa: C901
269278
)
270279
.drop(columns=key_token)
271280
)
281+
other_native: Incomplete = other.native
272282

273283
if how == "anti":
274284
indicator_token = generate_temporary_column_name(
@@ -280,7 +290,7 @@ def join( # noqa: C901
280290
raise TypeError(msg)
281291
other_native = (
282292
select_columns_by_name(
283-
other.native,
293+
other_native,
284294
list(right_on),
285295
self._backend_version,
286296
self._implementation,
@@ -307,7 +317,7 @@ def join( # noqa: C901
307317
raise TypeError(msg)
308318
other_native = (
309319
select_columns_by_name(
310-
other.native,
320+
other_native,
311321
list(right_on),
312322
self._backend_version,
313323
self._implementation,

narwhals/_namespace.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,14 @@
101101
class _NativeDask(Protocol):
102102
_partition_type: type[pd.DataFrame]
103103

104-
class _CuDFDataFrame(NativeFrame, Protocol):
104+
class _BasePandasLikeFrame(NativeFrame, Protocol):
105+
@property
106+
def shape(self) -> tuple[int, int]: ...
107+
def __getitem__(self, key: Any, /) -> Any: ...
108+
@property
109+
def loc(self) -> Any: ...
110+
111+
class _CuDFDataFrame(_BasePandasLikeFrame, Protocol):
105112
def to_pylibcudf(self, *args: Any, **kwds: Any) -> Any: ...
106113

107114
class _CuDFSeries(NativeSeries, Protocol):
@@ -114,7 +121,7 @@ def __pyarrow_result__(self, *args: Any, **kwds: Any) -> Any: ...
114121
def __pandas_result__(self, *args: Any, **kwds: Any) -> Any: ...
115122
def __polars_result__(self, *args: Any, **kwds: Any) -> Any: ...
116123

117-
class _ModinDataFrame(NativeFrame, Protocol):
124+
class _ModinDataFrame(_BasePandasLikeFrame, Protocol):
118125
_pandas_class: type[pd.DataFrame]
119126

120127
class _ModinSeries(NativeSeries, Protocol):

narwhals/_pandas_like/utils.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from narwhals._pandas_like.expr import PandasLikeExpr
2727
from narwhals._pandas_like.series import PandasLikeSeries
28+
from narwhals._pandas_like.typing import NativeDataFrameT
2829
from narwhals.dtypes import DType
2930
from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray
3031

@@ -558,35 +559,37 @@ def calculate_timestamp_date(s: pd.Series[int], time_unit: str) -> pd.Series[int
558559

559560

560561
def select_columns_by_name(
561-
df: T,
562+
df: NativeDataFrameT,
562563
column_names: list[str] | _1DArray, # NOTE: Cannot be a tuple!
563564
backend_version: tuple[int, ...],
564565
implementation: Implementation,
565-
) -> T:
566+
) -> NativeDataFrameT | Any:
566567
"""Select columns by name.
567568
568569
Prefer this over `df.loc[:, column_names]` as it's
569570
generally more performant.
570571
"""
571-
if len(column_names) == df.shape[1] and all(column_names == df.columns): # type: ignore[attr-defined]
572-
return df
573-
if (df.columns.dtype.kind == "b") or ( # type: ignore[attr-defined]
572+
if len(column_names) == df.shape[1]: # noqa: SIM102
573+
# NOTE: I'm pretty unsure on how this doesn't trigger a runtime error
574+
if all(column_names == df.columns): # type: ignore[arg-type]
575+
return df
576+
if (df.columns.dtype.kind == "b") or (
574577
implementation is Implementation.PANDAS and backend_version < (1, 5)
575578
):
576579
# See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122
577580
# for why we need this
578581
if error := check_columns_exist(
579582
column_names, # type: ignore[arg-type]
580-
available=df.columns.tolist(), # type: ignore[attr-defined]
583+
available=df.columns.tolist(),
581584
):
582585
raise error
583-
return df.loc[:, column_names] # type: ignore[attr-defined]
586+
return df.loc[:, column_names]
584587
try:
585-
return df[column_names] # type: ignore[index]
588+
return df[column_names]
586589
except KeyError as e:
587590
if error := check_columns_exist(
588591
column_names, # type: ignore[arg-type]
589-
available=df.columns.tolist(), # type: ignore[attr-defined]
592+
available=df.columns.tolist(),
590593
):
591594
raise error from e
592595
raise

0 commit comments

Comments
 (0)