|
7 | 7 | from typing import Any |
8 | 8 | from typing import Iterable |
9 | 9 | from typing import Literal |
| 10 | +from typing import Sequence |
10 | 11 |
|
11 | 12 | import pyarrow as pa |
12 | 13 | import pyarrow.compute as pc |
|
17 | 18 | from narwhals._arrow.series import ArrowSeries |
18 | 19 | from narwhals._arrow.utils import align_series_full_broadcast |
19 | 20 | from narwhals._arrow.utils import cast_to_comparable_string_types |
20 | | -from narwhals._arrow.utils import diagonal_concat |
21 | | -from narwhals._arrow.utils import horizontal_concat |
22 | | -from narwhals._arrow.utils import vertical_concat |
23 | 21 | from narwhals._compliant import CompliantThen |
24 | 22 | from narwhals._compliant import EagerNamespace |
25 | 23 | from narwhals._compliant import EagerWhen |
@@ -211,30 +209,46 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: |
211 | 209 | context=self, |
212 | 210 | ) |
213 | 211 |
|
| 212 | + def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: |
| 213 | + if self._backend_version >= (14,): |
| 214 | + return pa.concat_tables(dfs, promote_options="default") |
| 215 | + return pa.concat_tables(dfs, promote=True) |
| 216 | + |
| 217 | + def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: |
| 218 | + names = [name for df in dfs for name in df.column_names] |
| 219 | + |
| 220 | + if len(set(names)) < len(names): # pragma: no cover |
| 221 | + msg = "Expected unique column names" |
| 222 | + raise ValueError(msg) |
| 223 | + arrays = list(chain.from_iterable(df.itercolumns() for df in dfs)) |
| 224 | + return pa.Table.from_arrays(arrays, names=names) |
| 225 | + |
| 226 | + def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: |
| 227 | + cols_0 = dfs[0].column_names |
| 228 | + for i, df in enumerate(dfs[1:], start=1): |
| 229 | + cols_current = df.column_names |
| 230 | + if cols_current != cols_0: |
| 231 | + msg = ( |
| 232 | + "unable to vstack, column names don't match:\n" |
| 233 | + f" - dataframe 0: {cols_0}\n" |
| 234 | + f" - dataframe {i}: {cols_current}\n" |
| 235 | + ) |
| 236 | + raise TypeError(msg) |
| 237 | + return pa.concat_tables(dfs) |
| 238 | + |
214 | 239 | def concat( |
215 | 240 | self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod |
216 | 241 | ) -> ArrowDataFrame: |
217 | 242 | dfs = [item.native for item in items] |
218 | | - |
219 | | - if not dfs: |
220 | | - msg = "No dataframes to concatenate" # pragma: no cover |
221 | | - raise AssertionError(msg) |
222 | | - |
223 | 243 | if how == "horizontal": |
224 | | - result_table = horizontal_concat(dfs) |
| 244 | + native = self._concat_horizontal(dfs) |
225 | 245 | elif how == "vertical": |
226 | | - result_table = vertical_concat(dfs) |
| 246 | + native = self._concat_vertical(dfs) |
227 | 247 | elif how == "diagonal": |
228 | | - result_table = diagonal_concat(dfs, self._backend_version) |
| 248 | + native = self._concat_diagonal(dfs) |
229 | 249 | else: |
230 | 250 | raise NotImplementedError |
231 | | - |
232 | | - return ArrowDataFrame( |
233 | | - result_table, |
234 | | - backend_version=self._backend_version, |
235 | | - version=self._version, |
236 | | - validate_column_names=True, |
237 | | - ) |
| 251 | + return self._dataframe.from_native(native, context=self) |
238 | 252 |
|
239 | 253 | @property |
240 | 254 | def selectors(self: Self) -> ArrowSelectorNamespace: |
|
0 commit comments