Skip to content

Commit 78f27af

Browse files
authored
refactor: Simplify (Arrow|PandasLike)Namespace.concat (#2381)
* refactor: Simplify `ArrowNamespace.concat`
  Related #2368
* chore(typing): Ignore stub issues
  Fixed in zen-xu/pyarrow-stubs#203
* cov
  https://github.com/narwhals-dev/narwhals/actions/runs/14422362924/job/40446526457?pr=2381
* refactor: Implement `EagerNamespace.concat`
* refactor: Avoid redundant unique check
  Has no coverage, would be caught later with a nicer error anyway with https://github.com/narwhals-dev/narwhals/blob/41871f775589359c563150526d03935449709d7d/narwhals/utils.py#L1466-L1475
* refactor: Consistently use `chain.from_iterable`
* chore: Remove notes
  One became a type, the other is moving to GitHub
1 parent e943a2a commit 78f27af

File tree

4 files changed: +47 additions, -98 deletions (lines changed)

narwhals/_arrow/namespace.py

Lines changed: 24 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55
from itertools import chain
66
from typing import TYPE_CHECKING
77
from typing import Any
8-
from typing import Iterable
98
from typing import Literal
9+
from typing import Sequence
1010

1111
import pyarrow as pa
1212
import pyarrow.compute as pc
@@ -17,9 +17,6 @@
1717
from narwhals._arrow.series import ArrowSeries
1818
from narwhals._arrow.utils import align_series_full_broadcast
1919
from narwhals._arrow.utils import cast_to_comparable_string_types
20-
from narwhals._arrow.utils import diagonal_concat
21-
from narwhals._arrow.utils import horizontal_concat
22-
from narwhals._arrow.utils import vertical_concat
2320
from narwhals._compliant import CompliantThen
2421
from narwhals._compliant import EagerNamespace
2522
from narwhals._compliant import EagerWhen
@@ -34,7 +31,6 @@
3431
from narwhals._arrow.typing import ArrowChunkedArray
3532
from narwhals._arrow.typing import Incomplete
3633
from narwhals.dtypes import DType
37-
from narwhals.typing import ConcatMethod
3834
from narwhals.utils import Version
3935

4036

@@ -211,30 +207,29 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
211207
context=self,
212208
)
213209

214-
def concat(
215-
self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod
216-
) -> ArrowDataFrame:
217-
dfs = [item.native for item in items]
218-
219-
if not dfs:
220-
msg = "No dataframes to concatenate" # pragma: no cover
221-
raise AssertionError(msg)
222-
223-
if how == "horizontal":
224-
result_table = horizontal_concat(dfs)
225-
elif how == "vertical":
226-
result_table = vertical_concat(dfs)
227-
elif how == "diagonal":
228-
result_table = diagonal_concat(dfs, self._backend_version)
229-
else:
230-
raise NotImplementedError
231-
232-
return ArrowDataFrame(
233-
result_table,
234-
backend_version=self._backend_version,
235-
version=self._version,
236-
validate_column_names=True,
237-
)
210+
# NOTE: Stub issue fixed in https://github.com/zen-xu/pyarrow-stubs/pull/203
211+
def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
212+
if self._backend_version >= (14,):
213+
return pa.concat_tables(dfs, promote_options="default") # type: ignore[arg-type]
214+
return pa.concat_tables(dfs, promote=True) # type: ignore[arg-type] # pragma: no cover
215+
216+
def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
217+
names = list(chain.from_iterable(df.column_names for df in dfs))
218+
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
219+
return pa.Table.from_arrays(arrays, names=names)
220+
221+
def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
222+
cols_0 = dfs[0].column_names
223+
for i, df in enumerate(dfs[1:], start=1):
224+
cols_current = df.column_names
225+
if cols_current != cols_0:
226+
msg = (
227+
"unable to vstack, column names don't match:\n"
228+
f" - dataframe 0: {cols_0}\n"
229+
f" - dataframe {i}: {cols_current}\n"
230+
)
231+
raise TypeError(msg)
232+
return pa.concat_tables(dfs) # type: ignore[arg-type]
238233

239234
@property
240235
def selectors(self: Self) -> ArrowSelectorNamespace:

narwhals/_arrow/utils.py

Lines changed: 0 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from functools import lru_cache
4-
from itertools import chain
54
from typing import TYPE_CHECKING
65
from typing import Any
76
from typing import Iterable
@@ -280,52 +279,6 @@ def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]:
280279
return reshaped
281280

282281

283-
def horizontal_concat(dfs: list[pa.Table]) -> pa.Table:
284-
"""Concatenate (native) DataFrames horizontally.
285-
286-
Should be in namespace.
287-
"""
288-
names = [name for df in dfs for name in df.column_names]
289-
290-
if len(set(names)) < len(names): # pragma: no cover
291-
msg = "Expected unique column names"
292-
raise ValueError(msg)
293-
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
294-
return pa.Table.from_arrays(arrays, names=names)
295-
296-
297-
def vertical_concat(dfs: list[pa.Table]) -> pa.Table:
298-
"""Concatenate (native) DataFrames vertically.
299-
300-
Should be in namespace.
301-
"""
302-
cols_0 = dfs[0].column_names
303-
for i, df in enumerate(dfs[1:], start=1):
304-
cols_current = df.column_names
305-
if cols_current != cols_0:
306-
msg = (
307-
"unable to vstack, column names don't match:\n"
308-
f" - dataframe 0: {cols_0}\n"
309-
f" - dataframe {i}: {cols_current}\n"
310-
)
311-
raise TypeError(msg)
312-
313-
return pa.concat_tables(dfs)
314-
315-
316-
def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa.Table:
317-
"""Concatenate (native) DataFrames diagonally.
318-
319-
Should be in namespace.
320-
"""
321-
kwargs: dict[str, Any] = (
322-
{"promote": True}
323-
if backend_version < (14, 0, 0)
324-
else {"promote_options": "default"}
325-
)
326-
return pa.concat_tables(dfs, **kwargs)
327-
328-
329282
def floordiv_compat(left: Any, right: Any) -> Any:
330283
# The following lines are adapted from pandas' pyarrow implementation.
331284
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154

narwhals/_compliant/namespace.py

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,8 @@
1818
from narwhals._compliant.typing import EagerExprT
1919
from narwhals._compliant.typing import EagerSeriesT
2020
from narwhals._compliant.typing import LazyExprT
21+
from narwhals._compliant.typing import NativeFrameT
2122
from narwhals._compliant.typing import NativeFrameT_co
22-
from narwhals._compliant.typing import NativeFrameT_contra
2323
from narwhals._compliant.typing import NativeSeriesT
2424
from narwhals.dependencies import is_numpy_array_2d
2525
from narwhals.utils import exclude_column_names
@@ -130,9 +130,7 @@ def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
130130

131131
class EagerNamespace(
132132
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
133-
Protocol[
134-
EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT_contra, NativeSeriesT
135-
],
133+
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
136134
):
137135
@property
138136
def _dataframe(self) -> type[EagerDataFrameT]: ...
@@ -143,11 +141,11 @@ def when(
143141
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
144142

145143
@overload
146-
def from_native(self, data: NativeFrameT_contra, /) -> EagerDataFrameT: ...
144+
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
147145
@overload
148146
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
149147
def from_native(
150-
self, data: NativeFrameT_contra | NativeSeriesT | Any, /
148+
self, data: NativeFrameT | NativeSeriesT | Any, /
151149
) -> EagerDataFrameT | EagerSeriesT:
152150
if self._dataframe._is_native(data):
153151
return self._dataframe.from_native(data, context=self)
@@ -181,3 +179,22 @@ def from_numpy(
181179
if is_numpy_array_2d(data):
182180
return self._dataframe.from_numpy(data, schema=schema, context=self)
183181
return self._series.from_numpy(data, context=self)
182+
183+
def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
184+
def _concat_horizontal(
185+
self, dfs: Sequence[NativeFrameT | Any], /
186+
) -> NativeFrameT: ...
187+
def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
188+
def concat(
189+
self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
190+
) -> EagerDataFrameT:
191+
dfs = [item.native for item in items]
192+
if how == "horizontal":
193+
native = self._concat_horizontal(dfs)
194+
elif how == "vertical":
195+
native = self._concat_vertical(dfs)
196+
elif how == "diagonal":
197+
native = self._concat_diagonal(dfs)
198+
else:
199+
raise NotImplementedError
200+
return self._dataframe.from_native(native, context=self)

narwhals/_pandas_like/namespace.py

Lines changed: 0 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
from functools import reduce
66
from typing import TYPE_CHECKING
77
from typing import Any
8-
from typing import Iterable
98
from typing import Literal
109
from typing import Sequence
1110

@@ -27,7 +26,6 @@
2726

2827
from narwhals._pandas_like.typing import NDFrameT
2928
from narwhals.dtypes import DType
30-
from narwhals.typing import ConcatMethod
3129
from narwhals.utils import Implementation
3230
from narwhals.utils import Version
3331

@@ -274,20 +272,6 @@ def _concat_vertical(self, dfs: Sequence[pd.DataFrame], /) -> pd.DataFrame:
274272
return self._concat(dfs, axis=VERTICAL, copy=False)
275273
return self._concat(dfs, axis=VERTICAL)
276274

277-
def concat(
278-
self, items: Iterable[PandasLikeDataFrame], *, how: ConcatMethod
279-
) -> PandasLikeDataFrame:
280-
dfs: list[pd.DataFrame] = [item.native for item in items]
281-
if how == "horizontal":
282-
native = self._concat_horizontal(dfs)
283-
elif how == "vertical":
284-
native = self._concat_vertical(dfs)
285-
elif how == "diagonal":
286-
native = self._concat_diagonal(dfs)
287-
else:
288-
raise NotImplementedError
289-
return self._dataframe.from_native(native, context=self)
290-
291275
def when(self: Self, predicate: PandasLikeExpr) -> PandasWhen:
292276
return PandasWhen.from_expr(predicate, context=self)
293277

0 commit comments

Comments
 (0)