diff --git a/narwhals/_pandas_like/typing.py b/narwhals/_pandas_like/typing.py index 9def250676..6e7d28cac4 100644 --- a/narwhals/_pandas_like/typing.py +++ b/narwhals/_pandas_like/typing.py @@ -5,13 +5,40 @@ if TYPE_CHECKING: import sys + from typing import Any + + if sys.version_info >= (3, 13): + from typing import TypeVar + else: + from typing_extensions import TypeVar if sys.version_info >= (3, 10): from typing import TypeAlias else: from typing_extensions import TypeAlias + import cudf + import modin.pandas as mpd + import pandas as pd + from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries IntoPandasLikeExpr: TypeAlias = Union[PandasLikeExpr, PandasLikeSeries] + + DataFrameT = TypeVar( + "DataFrameT", pd.DataFrame, mpd.DataFrame, cudf.DataFrame, default=pd.DataFrame + ) + SeriesT = TypeVar( + "SeriesT", pd.Series[Any], mpd.Series, cudf.Series[Any], default=pd.Series[Any] + ) + NDFrameT = TypeVar( + "NDFrameT", + pd.DataFrame, + mpd.DataFrame, + cudf.DataFrame, + pd.Series[Any], + mpd.Series, + cudf.Series[Any], + default=pd.DataFrame, + ) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 2191c33a8f..e11b5688ef 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -29,6 +29,8 @@ from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries + from narwhals._pandas_like.typing import DataFrameT + from narwhals._pandas_like.typing import NDFrameT from narwhals.dtypes import DType from narwhals.typing import DTypeBackend from narwhals.typing import TimeUnit @@ -262,20 +264,20 @@ def native_series_from_iterable( def set_index( - obj: T, + obj: NDFrameT, index: Any, *, implementation: Implementation, backend_version: tuple[int, ...], -) -> T: +) -> NDFrameT: """Wrapper around pandas' set_axis to set object index. We can set `copy` / `inplace` based on implementation/version. """ if implementation is Implementation.CUDF: # pragma: no cover - obj = obj.copy(deep=False) # type: ignore[attr-defined] - obj.index = index # type: ignore[attr-defined] - return obj + cudf_frame = obj.copy(deep=False) + cudf_frame.index = index + return cast("NDFrameT", cudf_frame) # type: ignore[redundant-cast] if implementation is Implementation.PANDAS and ( backend_version < (1,) ): # pragma: no cover @@ -288,24 +290,25 @@ def set_index( kwargs["copy"] = False else: # pragma: no cover pass - return obj.set_axis(index, axis=0, **kwargs) # type: ignore[attr-defined] + nd_frame = obj.set_axis(index, axis=0, **kwargs) + return cast("NDFrameT", nd_frame) # type: ignore[redundant-cast] def set_columns( - obj: T, + obj: NDFrameT, columns: list[str], *, implementation: Implementation, backend_version: tuple[int, ...], -) -> T: +) -> NDFrameT: """Wrapper around pandas' set_axis to set object columns. We can set `copy` / `inplace` based on implementation/version. """ if implementation is Implementation.CUDF: # pragma: no cover - obj = obj.copy(deep=False) # type: ignore[attr-defined] - obj.columns = columns # type: ignore[attr-defined] - return obj + cudf_frame = obj.copy(deep=False) + cudf_frame.columns = cast("pd.Index[str]", columns) + return cast("NDFrameT", cudf_frame) # type: ignore[redundant-cast] if implementation is Implementation.PANDAS and ( backend_version < (1,) ): # pragma: no cover @@ -318,22 +321,24 @@ def set_columns( kwargs["copy"] = False else: # pragma: no cover pass - return obj.set_axis(columns, axis=1, **kwargs) # type: ignore[attr-defined] + nd_frame = obj.set_axis(columns, axis=1, **kwargs) + return cast("NDFrameT", nd_frame) # type: ignore[redundant-cast] def rename( - obj: T, + obj: NDFrameT, *args: Any, implementation: Implementation, backend_version: tuple[int, ...], **kwargs: Any, -) -> T: +) -> NDFrameT: """Wrapper around pandas' rename so that we can set `copy` based on implementation/version.""" - if implementation is Implementation.PANDAS and ( - backend_version >= (3,) - ): # pragma: no cover - return obj.rename(*args, **kwargs) # type: ignore[attr-defined] - return obj.rename(*args, **kwargs, copy=False) # type: ignore[attr-defined] + nd_frame = ( + obj.rename(*args, **kwargs, inplace=False) + if implementation.is_pandas() and (backend_version >= (3,)) + else obj.rename(*args, **kwargs, copy=False, inplace=False) + ) + return cast("NDFrameT", nd_frame) # type: ignore[redundant-cast] @functools.lru_cache(maxsize=16) @@ -749,34 +754,34 @@ def calculate_timestamp_date(s: pd.Series[int], time_unit: str) -> pd.Series[int def select_columns_by_name( - df: T, + df: DataFrameT, column_names: list[str] | _1DArray, # NOTE: Cannot be a tuple! backend_version: tuple[int, ...], implementation: Implementation, -) -> T: +) -> DataFrameT: """Select columns by name. Prefer this over `df.loc[:, column_names]` as it's generally more performant. """ - if len(column_names) == df.shape[1] and all(column_names == df.columns): # type: ignore[attr-defined] + if len(column_names) == df.shape[1] and (df.columns == column_names).all(): return df - if (df.columns.dtype.kind == "b") or ( # type: ignore[attr-defined] - implementation is Implementation.PANDAS and backend_version < (1, 5) + if (df.columns.dtype.kind == "b") or ( + implementation.is_pandas() and backend_version < (1, 5) ): # See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122 # for why we need this - available_columns = df.columns.tolist() # type: ignore[attr-defined] + available_columns = df.columns.tolist() missing_columns = [x for x in column_names if x not in available_columns] if missing_columns: # pragma: no cover raise ColumnNotFoundError.from_missing_and_available_column_names( missing_columns, available_columns ) - return df.loc[:, column_names] # type: ignore[attr-defined] + return cast("DataFrameT", df.loc[:, column_names]) # type: ignore[redundant-cast] try: - return df[column_names] # type: ignore[index] + return df[column_names] except KeyError as e: - available_columns = df.columns.tolist() # type: ignore[attr-defined] + available_columns = df.columns.tolist() missing_columns = [x for x in column_names if x not in available_columns] raise ColumnNotFoundError.from_missing_and_available_column_names( missing_columns, available_columns