Skip to content

Commit 91cb28b

Browse files
committed
refactor: Simplify ArrowNamespace.concat
Related #2368
1 parent b1d40d2 commit 91cb28b

File tree

2 files changed

+32
-65
lines changed

2 files changed

+32
-65
lines changed

narwhals/_arrow/namespace.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88
from typing import Iterable
99
from typing import Literal
10+
from typing import Sequence
1011

1112
import pyarrow as pa
1213
import pyarrow.compute as pc
@@ -17,9 +18,6 @@
1718
from narwhals._arrow.series import ArrowSeries
1819
from narwhals._arrow.utils import align_series_full_broadcast
1920
from narwhals._arrow.utils import cast_to_comparable_string_types
20-
from narwhals._arrow.utils import diagonal_concat
21-
from narwhals._arrow.utils import horizontal_concat
22-
from narwhals._arrow.utils import vertical_concat
2321
from narwhals._compliant import CompliantThen
2422
from narwhals._compliant import EagerNamespace
2523
from narwhals._compliant import EagerWhen
@@ -211,30 +209,46 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
211209
context=self,
212210
)
213211

212+
def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
213+
if self._backend_version >= (14,):
214+
return pa.concat_tables(dfs, promote_options="default")
215+
return pa.concat_tables(dfs, promote=True)
216+
217+
def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
218+
names = [name for df in dfs for name in df.column_names]
219+
220+
if len(set(names)) < len(names): # pragma: no cover
221+
msg = "Expected unique column names"
222+
raise ValueError(msg)
223+
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
224+
return pa.Table.from_arrays(arrays, names=names)
225+
226+
def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
227+
cols_0 = dfs[0].column_names
228+
for i, df in enumerate(dfs[1:], start=1):
229+
cols_current = df.column_names
230+
if cols_current != cols_0:
231+
msg = (
232+
"unable to vstack, column names don't match:\n"
233+
f" - dataframe 0: {cols_0}\n"
234+
f" - dataframe {i}: {cols_current}\n"
235+
)
236+
raise TypeError(msg)
237+
return pa.concat_tables(dfs)
238+
214239
def concat(
215240
self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod
216241
) -> ArrowDataFrame:
217242
dfs = [item.native for item in items]
218-
219-
if not dfs:
220-
msg = "No dataframes to concatenate" # pragma: no cover
221-
raise AssertionError(msg)
222-
223243
if how == "horizontal":
224-
result_table = horizontal_concat(dfs)
244+
native = self._concat_horizontal(dfs)
225245
elif how == "vertical":
226-
result_table = vertical_concat(dfs)
246+
native = self._concat_vertical(dfs)
227247
elif how == "diagonal":
228-
result_table = diagonal_concat(dfs, self._backend_version)
248+
native = self._concat_diagonal(dfs)
229249
else:
230250
raise NotImplementedError
231-
232-
return ArrowDataFrame(
233-
result_table,
234-
backend_version=self._backend_version,
235-
version=self._version,
236-
validate_column_names=True,
237-
)
251+
return self._dataframe.from_native(native, context=self)
238252

239253
@property
240254
def selectors(self: Self) -> ArrowSelectorNamespace:

narwhals/_arrow/utils.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from functools import lru_cache
4-
from itertools import chain
54
from typing import TYPE_CHECKING
65
from typing import Any
76
from typing import Iterable
@@ -280,52 +279,6 @@ def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]:
280279
return reshaped
281280

282281

283-
def horizontal_concat(dfs: list[pa.Table]) -> pa.Table:
284-
"""Concatenate (native) DataFrames horizontally.
285-
286-
Should be in namespace.
287-
"""
288-
names = [name for df in dfs for name in df.column_names]
289-
290-
if len(set(names)) < len(names): # pragma: no cover
291-
msg = "Expected unique column names"
292-
raise ValueError(msg)
293-
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
294-
return pa.Table.from_arrays(arrays, names=names)
295-
296-
297-
def vertical_concat(dfs: list[pa.Table]) -> pa.Table:
298-
"""Concatenate (native) DataFrames vertically.
299-
300-
Should be in namespace.
301-
"""
302-
cols_0 = dfs[0].column_names
303-
for i, df in enumerate(dfs[1:], start=1):
304-
cols_current = df.column_names
305-
if cols_current != cols_0:
306-
msg = (
307-
"unable to vstack, column names don't match:\n"
308-
f" - dataframe 0: {cols_0}\n"
309-
f" - dataframe {i}: {cols_current}\n"
310-
)
311-
raise TypeError(msg)
312-
313-
return pa.concat_tables(dfs)
314-
315-
316-
def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa.Table:
317-
"""Concatenate (native) DataFrames diagonally.
318-
319-
Should be in namespace.
320-
"""
321-
kwargs: dict[str, Any] = (
322-
{"promote": True}
323-
if backend_version < (14, 0, 0)
324-
else {"promote_options": "default"}
325-
)
326-
return pa.concat_tables(dfs, **kwargs)
327-
328-
329282
def floordiv_compat(left: Any, right: Any) -> Any:
330283
# The following lines are adapted from pandas' pyarrow implementation.
331284
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154

0 commit comments

Comments
 (0)