Skip to content

Commit caabc0e

Browse files
authored
refactor: Simplify PandasLikeNamespace.concat (#2368)
1 parent b1d40d2 commit caabc0e

File tree

5 files changed

+76
-162
lines changed

5 files changed

+76
-162
lines changed

narwhals/_pandas_like/dataframe.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from narwhals._pandas_like.utils import check_column_names_are_unique
2222
from narwhals._pandas_like.utils import convert_str_slice_to_int_slice
2323
from narwhals._pandas_like.utils import get_dtype_backend
24-
from narwhals._pandas_like.utils import horizontal_concat
2524
from narwhals._pandas_like.utils import native_to_narwhals_dtype
2625
from narwhals._pandas_like.utils import object_native_to_narwhals_dtype
2726
from narwhals._pandas_like.utils import pivot_table
@@ -504,11 +503,8 @@ def select(self: PandasLikeDataFrame, *exprs: PandasLikeExpr) -> PandasLikeDataF
504503
# return empty dataframe, like Polars does
505504
return self._with_native(self.native.__class__(), validate_column_names=False)
506505
new_series = align_series_full_broadcast(*new_series)
507-
df = horizontal_concat(
508-
[s.native for s in new_series],
509-
implementation=self._implementation,
510-
backend_version=self._backend_version,
511-
)
506+
namespace = self.__narwhals_namespace__()
507+
df = namespace._concat_horizontal([s.native for s in new_series])
512508
return self._with_native(df, validate_column_names=True)
513509

514510
def drop_nulls(
@@ -531,13 +527,7 @@ def with_row_index(self: Self, name: str) -> Self:
531527
row_index = namespace._series.from_iterable(
532528
range(len(frame)), context=self, index=frame.index
533529
).alias(name)
534-
return self._with_native(
535-
horizontal_concat(
536-
[row_index.native, frame],
537-
implementation=self._implementation,
538-
backend_version=self._backend_version,
539-
)
540-
)
530+
return self._with_native(namespace._concat_horizontal([row_index.native, frame]))
541531

542532
def row(self: Self, index: int) -> tuple[Any, ...]:
543533
return tuple(x for x in self.native.iloc[index])
@@ -571,11 +561,8 @@ def with_columns(
571561
series = self.native[name]
572562
to_concat.append(series)
573563
to_concat.extend(self._extract_comparand(s) for s in name_columns.values())
574-
df = horizontal_concat(
575-
to_concat,
576-
implementation=self._implementation,
577-
backend_version=self._backend_version,
578-
)
564+
namespace = self.__narwhals_namespace__()
565+
df = namespace._concat_horizontal(to_concat)
579566
return self._with_native(df, validate_column_names=False)
580567

581568
def rename(self: Self, mapping: Mapping[str, str]) -> Self:

narwhals/_pandas_like/group_by.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from narwhals._compliant import EagerGroupBy
1313
from narwhals._expression_parsing import evaluate_output_names_and_aliases
14-
from narwhals._pandas_like.utils import horizontal_concat
1514
from narwhals._pandas_like.utils import select_columns_by_name
1615
from narwhals._pandas_like.utils import set_columns
1716
from narwhals.utils import find_stacklevel
@@ -233,11 +232,8 @@ def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR
233232
pass
234233
msg = f"Expected unique output names, got:{msg}"
235234
raise ValueError(msg)
236-
result = horizontal_concat(
237-
dfs=result_aggs,
238-
implementation=implementation,
239-
backend_version=backend_version,
240-
)
235+
namespace = self.compliant.__narwhals_namespace__()
236+
result = namespace._concat_horizontal(result_aggs)
241237
else:
242238
# No aggregation provided
243239
result = self.compliant.__native_namespace__().DataFrame(

narwhals/_pandas_like/namespace.py

Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
import operator
4+
import warnings
45
from functools import reduce
56
from typing import TYPE_CHECKING
67
from typing import Any
78
from typing import Iterable
9+
from typing import Literal
10+
from typing import Sequence
811

912
from narwhals._compliant import CompliantThen
1013
from narwhals._compliant import EagerNamespace
@@ -16,20 +19,21 @@
1619
from narwhals._pandas_like.selectors import PandasSelectorNamespace
1720
from narwhals._pandas_like.series import PandasLikeSeries
1821
from narwhals._pandas_like.utils import align_series_full_broadcast
19-
from narwhals._pandas_like.utils import diagonal_concat
20-
from narwhals._pandas_like.utils import horizontal_concat
21-
from narwhals._pandas_like.utils import vertical_concat
2222
from narwhals.utils import import_dtypes_module
2323

2424
if TYPE_CHECKING:
2525
import pandas as pd
2626
from typing_extensions import Self
2727

28+
from narwhals._pandas_like.typing import NDFrameT
2829
from narwhals.dtypes import DType
2930
from narwhals.typing import ConcatMethod
3031
from narwhals.utils import Implementation
3132
from narwhals.utils import Version
3233

34+
VERTICAL: Literal[0] = 0
35+
HORIZONTAL: Literal[1] = 1
36+
3337

3438
class PandasLikeNamespace(
3539
EagerNamespace[
@@ -223,48 +227,66 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
223227
context=self,
224228
)
225229

230+
@property
231+
def _concat(self): # type: ignore[no-untyped-def] # noqa: ANN202
232+
"""Return the **native** equivalent of `pd.concat`."""
233+
# NOTE: Leave un-annotated to allow `@overload` matching via inference.
234+
if TYPE_CHECKING:
235+
import pandas as pd
236+
237+
return pd.concat
238+
return self._implementation.to_native_namespace().concat
239+
240+
def _concat_diagonal(self, dfs: Sequence[pd.DataFrame], /) -> pd.DataFrame:
241+
if self._implementation.is_pandas() and self._backend_version < (3,):
242+
if self._backend_version < (1,):
243+
return self._concat(dfs, axis=VERTICAL, copy=False, sort=False)
244+
return self._concat(dfs, axis=VERTICAL, copy=False)
245+
return self._concat(dfs, axis=VERTICAL)
246+
247+
def _concat_horizontal(self, dfs: Sequence[NDFrameT], /) -> pd.DataFrame:
248+
if self._implementation.is_cudf():
249+
with warnings.catch_warnings():
250+
warnings.filterwarnings(
251+
"ignore",
252+
message="The behavior of array concatenation with empty entries is deprecated",
253+
category=FutureWarning,
254+
)
255+
return self._concat(dfs, axis=HORIZONTAL)
256+
elif self._implementation.is_pandas() and self._backend_version < (3,):
257+
return self._concat(dfs, axis=HORIZONTAL, copy=False)
258+
return self._concat(dfs, axis=HORIZONTAL)
259+
260+
def _concat_vertical(self, dfs: Sequence[pd.DataFrame], /) -> pd.DataFrame:
261+
cols_0 = dfs[0].columns
262+
for i, df in enumerate(dfs[1:], start=1):
263+
cols_current = df.columns
264+
if not (
265+
(len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
266+
):
267+
msg = (
268+
"unable to vstack, column names don't match:\n"
269+
f" - dataframe 0: {cols_0.to_list()}\n"
270+
f" - dataframe {i}: {cols_current.to_list()}\n"
271+
)
272+
raise TypeError(msg)
273+
if self._implementation.is_pandas() and self._backend_version < (3,):
274+
return self._concat(dfs, axis=VERTICAL, copy=False)
275+
return self._concat(dfs, axis=VERTICAL)
276+
226277
def concat(
227278
self, items: Iterable[PandasLikeDataFrame], *, how: ConcatMethod
228279
) -> PandasLikeDataFrame:
229-
dfs: list[Any] = [item._native_frame for item in items]
280+
dfs: list[pd.DataFrame] = [item.native for item in items]
230281
if how == "horizontal":
231-
return PandasLikeDataFrame(
232-
horizontal_concat(
233-
dfs,
234-
implementation=self._implementation,
235-
backend_version=self._backend_version,
236-
),
237-
implementation=self._implementation,
238-
backend_version=self._backend_version,
239-
version=self._version,
240-
validate_column_names=True,
241-
)
242-
if how == "vertical":
243-
return PandasLikeDataFrame(
244-
vertical_concat(
245-
dfs,
246-
implementation=self._implementation,
247-
backend_version=self._backend_version,
248-
),
249-
implementation=self._implementation,
250-
backend_version=self._backend_version,
251-
version=self._version,
252-
validate_column_names=True,
253-
)
254-
255-
if how == "diagonal":
256-
return PandasLikeDataFrame(
257-
diagonal_concat(
258-
dfs,
259-
implementation=self._implementation,
260-
backend_version=self._backend_version,
261-
),
262-
implementation=self._implementation,
263-
backend_version=self._backend_version,
264-
version=self._version,
265-
validate_column_names=True,
266-
)
267-
raise NotImplementedError
282+
native = self._concat_horizontal(dfs)
283+
elif how == "vertical":
284+
native = self._concat_vertical(dfs)
285+
elif how == "diagonal":
286+
native = self._concat_diagonal(dfs)
287+
else:
288+
raise NotImplementedError
289+
return self._dataframe.from_native(native, context=self)
268290

269291
def when(self: Self, predicate: PandasLikeExpr) -> PandasWhen:
270292
return PandasWhen.from_expr(predicate, context=self)

narwhals/_pandas_like/typing.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from __future__ import annotations # pragma: no cover
22

33
from typing import TYPE_CHECKING # pragma: no cover
4-
from typing import Union # pragma: no cover
54

65
if TYPE_CHECKING:
7-
import sys
6+
from typing import Any
7+
from typing import TypeVar
88

9-
if sys.version_info >= (3, 10):
10-
from typing import TypeAlias
11-
else:
12-
from typing_extensions import TypeAlias
9+
import pandas as pd
10+
from typing_extensions import TypeAlias
1311

1412
from narwhals._pandas_like.expr import PandasLikeExpr
1513
from narwhals._pandas_like.series import PandasLikeSeries
1614

17-
IntoPandasLikeExpr: TypeAlias = Union[PandasLikeExpr, PandasLikeSeries]
15+
IntoPandasLikeExpr: TypeAlias = "PandasLikeExpr | PandasLikeSeries"
16+
NDFrameT = TypeVar("NDFrameT", "pd.DataFrame", "pd.Series[Any]")

narwhals/_pandas_like/utils.py

Lines changed: 0 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import functools
44
import re
5-
import warnings
65
from contextlib import suppress
76
from typing import TYPE_CHECKING
87
from typing import Any
@@ -130,95 +129,6 @@ def align_and_extract_native(
130129
return lhs.native, rhs
131130

132131

133-
def horizontal_concat(
134-
dfs: list[Any], *, implementation: Implementation, backend_version: tuple[int, ...]
135-
) -> Any:
136-
"""Concatenate (native) DataFrames horizontally.
137-
138-
Should be in namespace.
139-
"""
140-
if implementation is Implementation.CUDF:
141-
with warnings.catch_warnings():
142-
warnings.filterwarnings(
143-
"ignore",
144-
message="The behavior of array concatenation with empty entries is deprecated",
145-
category=FutureWarning,
146-
)
147-
return implementation.to_native_namespace().concat(dfs, axis=1)
148-
149-
if implementation.is_pandas_like():
150-
extra_kwargs = (
151-
{"copy": False}
152-
if implementation is Implementation.PANDAS and backend_version < (3,)
153-
else {}
154-
)
155-
return implementation.to_native_namespace().concat(dfs, axis=1, **extra_kwargs)
156-
157-
else: # pragma: no cover
158-
msg = f"Expected pandas-like implementation ({PANDAS_LIKE_IMPLEMENTATION}), found {implementation}"
159-
raise TypeError(msg)
160-
161-
162-
def vertical_concat(
163-
dfs: list[Any], *, implementation: Implementation, backend_version: tuple[int, ...]
164-
) -> Any:
165-
"""Concatenate (native) DataFrames vertically.
166-
167-
Should be in namespace.
168-
"""
169-
if not dfs:
170-
msg = "No dataframes to concatenate" # pragma: no cover
171-
raise AssertionError(msg)
172-
cols_0 = dfs[0].columns
173-
for i, df in enumerate(dfs[1:], start=1):
174-
cols_current = df.columns
175-
if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()):
176-
msg = (
177-
"unable to vstack, column names don't match:\n"
178-
f" - dataframe 0: {cols_0.to_list()}\n"
179-
f" - dataframe {i}: {cols_current.to_list()}\n"
180-
)
181-
raise TypeError(msg)
182-
183-
if implementation in PANDAS_LIKE_IMPLEMENTATION:
184-
extra_kwargs = (
185-
{"copy": False}
186-
if implementation is Implementation.PANDAS and backend_version < (3,)
187-
else {}
188-
)
189-
return implementation.to_native_namespace().concat(dfs, axis=0, **extra_kwargs)
190-
191-
else: # pragma: no cover
192-
msg = f"Expected pandas-like implementation ({PANDAS_LIKE_IMPLEMENTATION}), found {implementation}"
193-
raise TypeError(msg)
194-
195-
196-
def diagonal_concat(
197-
dfs: list[Any], *, implementation: Implementation, backend_version: tuple[int, ...]
198-
) -> Any:
199-
"""Concatenate (native) DataFrames diagonally.
200-
201-
Should be in namespace.
202-
"""
203-
if not dfs:
204-
msg = "No dataframes to concatenate" # pragma: no cover
205-
raise AssertionError(msg)
206-
207-
if implementation in PANDAS_LIKE_IMPLEMENTATION:
208-
extra_kwargs = (
209-
{"copy": False, "sort": False}
210-
if implementation is Implementation.PANDAS and backend_version < (1,)
211-
else {"copy": False}
212-
if implementation is Implementation.PANDAS and backend_version < (3,)
213-
else {}
214-
)
215-
return implementation.to_native_namespace().concat(dfs, axis=0, **extra_kwargs)
216-
217-
else: # pragma: no cover
218-
msg = f"Expected pandas-like implementation ({PANDAS_LIKE_IMPLEMENTATION}), found {implementation}"
219-
raise TypeError(msg)
220-
221-
222132
def set_index(
223133
obj: T,
224134
index: Any,

0 commit comments

Comments
 (0)