diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 02a2417014..db896891d8 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -204,16 +204,6 @@ def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: return pa.Table.from_arrays(arrays, names=names) def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: - cols_0 = dfs[0].column_names - for i, df in enumerate(dfs[1:], start=1): - cols_current = df.column_names - if cols_current != cols_0: - msg = ( - "unable to vstack, column names don't match:\n" - f" - dataframe 0: {cols_0}\n" - f" - dataframe {i}: {cols_current}\n" - ) - raise TypeError(msg) return pa.concat_tables(dfs) @property diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 4cc7130828..3a92506908 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -20,6 +20,7 @@ exclude_column_names, get_column_names, passthrough_column_names, + validate_concat_vertical_schemas, ) from narwhals.dependencies import is_numpy_array, is_numpy_array_2d @@ -239,6 +240,7 @@ def concat( if how == "horizontal": native = self._concat_horizontal(dfs) elif how == "vertical": + validate_concat_vertical_schemas(item.schema for item in items) native = self._concat_vertical(dfs) elif how == "diagonal": native = self._concat_diagonal(dfs) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index df66a583c4..a25608d28d 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -27,7 +27,7 @@ combine_alias_output_names, combine_evaluate_output_names, ) -from narwhals._utils import Implementation, zip_strict +from narwhals._utils import Implementation, validate_concat_vertical_schemas, zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence @@ -148,32 +148,17 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def concat( self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod ) -> DaskLazyFrame: - if not items: - msg = "No items to concatenate" # pragma: no cover - raise AssertionError(msg) - dfs = [i._native_frame for i in items] - cols_0 = dfs[0].columns + items = list(items) + dfs = [item._native_frame for item in items] if how == "vertical": - for i, df in enumerate(dfs[1:], start=1): - cols_current = df.columns - if not ( - (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all() - ): - msg = ( - "unable to vstack, column names don't match:\n" - f" - dataframe 0: {cols_0.to_list()}\n" - f" - dataframe {i}: {cols_current.to_list()}\n" - ) - raise TypeError(msg) - return DaskLazyFrame( - dd.concat(dfs, axis=0, join="inner"), version=self._version - ) - if how == "diagonal": - return DaskLazyFrame( - dd.concat(dfs, axis=0, join="outer"), version=self._version - ) - - raise NotImplementedError + validate_concat_vertical_schemas(item.schema for item in items) + native_result = dd.concat(dfs, axis=0, join="inner") + elif how == "diagonal": + native_result = dd.concat(dfs, axis=0, join="outer") + else: # pragma: no cover + raise NotImplementedError + + return self._lazyframe.from_native(native_result, context=self) def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 09b5ecd8eb..5235b2b532 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -27,7 +27,7 @@ ) from narwhals._sql.namespace import SQLNamespace from narwhals._sql.when_then import SQLThen, SQLWhen -from narwhals._utils import Implementation +from narwhals._utils import Implementation, validate_concat_vertical_schemas if TYPE_CHECKING: from collections.abc import Iterable @@ -82,23 +82,21 @@ def _coalesce(self, *exprs: Expression) -> Expression: def concat( self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod ) -> DuckDBLazyFrame: - native_items = [item._native_frame for item in items] items = list(items) - first = items[0] - schema = first.schema - if how == "vertical" and not all(x.schema == schema for x in items[1:]): - msg = "inputs should all have the same schema" - raise TypeError(msg) - if how == "diagonal": - res = first.native - for _item in native_items[1:]: + if how == "vertical": + validate_concat_vertical_schemas(item.schema for item in items) + res = reduce(lambda x, y: x.union(y), (item._native_frame for item in items)) + elif how == "diagonal": + res, *others = (item._native_frame for item in items) + for _item in others: # TODO(unassigned): use relational API when available https://github.com/duckdb/duckdb/discussions/16996 res = duckdb.sql(""" from res select * union all by name from _item select * """) - return first._with_native(res) - res = reduce(lambda x, y: x.union(y), native_items) - return first._with_native(res) + else: # pragma: no cover + raise NotImplementedError + + return self._lazyframe.from_native(res, context=self) def concat_str( self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_ibis/namespace.py b/narwhals/_ibis/namespace.py index 3509d805fc..6cb0bc8802 100644 --- a/narwhals/_ibis/namespace.py +++ b/narwhals/_ibis/namespace.py @@ -18,7 +18,7 @@ from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype from narwhals._sql.namespace import SQLNamespace from narwhals._sql.when_then import SQLThen, SQLWhen -from narwhals._utils import Implementation, requires +from narwhals._utils import Implementation, requires, validate_concat_vertical_schemas if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -69,12 +69,10 @@ def concat( raise NotImplementedError(msg) items = list(items) - native_items = [item.native for item in items] - schema = items[0].schema - if not all(x.schema == schema for x in items[1:]): - msg = "inputs should all have the same schema" - raise TypeError(msg) - return self._lazyframe.from_native(ibis.union(*native_items), context=self) + validate_concat_vertical_schemas(item.schema for item in items) + return self._lazyframe.from_native( + ibis.union(*(item.native for item in items)), context=self + ) def concat_str( self, *exprs: IbisExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index fff20e292b..d528562cf8 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -292,18 +292,6 @@ def _concat_horizontal( return self._concat(dfs, axis=HORIZONTAL) def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFrameT: - cols_0 = dfs[0].columns - for i, df in enumerate(dfs[1:], start=1): - cols_current = df.columns - if not ( - (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all() - ): - msg = ( - "unable to vstack, column names don't match:\n" - f" - dataframe 0: {cols_0.to_list()}\n" - f" - dataframe {i}: {cols_current.to_list()}\n" - ) - raise TypeError(msg) if self._implementation.is_pandas() and self._backend_version < (3,): return self._concat(dfs, axis=VERTICAL, copy=False) return self._concat(dfs, axis=VERTICAL) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 8f257b36bd..b6813e41fd 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import TYPE_CHECKING, Any, cast, overload import polars as pl @@ -26,6 +26,7 @@ from narwhals.expr import Expr from narwhals.series import Series from narwhals.typing import ( + ConcatMethod, Into1DArray, IntoDType, IntoSchema, @@ -162,10 +163,7 @@ def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr: return self._expr(pl.any_horizontal(*(expr.native for expr in it)), self._version) def concat( - self, - items: Iterable[FrameT], - *, - how: Literal["vertical", "horizontal", "diagonal"], + self, items: Iterable[FrameT], *, how: ConcatMethod ) -> PolarsDataFrame | PolarsLazyFrame: result = pl.concat((item.native for item in items), how=how) if isinstance(result, pl.DataFrame): diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 2e0be736ca..d52d822dd4 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -19,7 +19,7 @@ ) from narwhals._sql.namespace import SQLNamespace from narwhals._sql.when_then import SQLThen, SQLWhen -from narwhals._utils import zip_strict +from narwhals._utils import validate_concat_vertical_schemas, zip_strict if TYPE_CHECKING: from collections.abc import Iterable @@ -142,32 +142,15 @@ def concat( ) -> SparkLikeLazyFrame: dfs = [item._native_frame for item in items] if how == "vertical": - cols_0 = dfs[0].columns - for i, df in enumerate(dfs[1:], start=1): - cols_current = df.columns - if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)): - msg = ( - "unable to vstack, column names don't match:\n" - f" - dataframe 0: {cols_0}\n" - f" - dataframe {i}: {cols_current}\n" - ) - raise TypeError(msg) - - return SparkLikeLazyFrame( - native_dataframe=reduce(lambda x, y: x.union(y), dfs), - version=self._version, - implementation=self._implementation, + validate_concat_vertical_schemas(item.schema for item in items) + native_result = reduce(lambda x, y: x.union(y), dfs) + elif how == "diagonal": + native_result = reduce( + lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs ) - - if how == "diagonal": - return SparkLikeLazyFrame( - native_dataframe=reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs - ), - version=self._version, - implementation=self._implementation, - ) - raise NotImplementedError + else: # pragma: no cover + raise NotImplementedError + return self._lazyframe.from_native(native_result, context=self) def concat_str( self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool diff --git a/narwhals/_utils.py b/narwhals/_utils.py index 65a2c183f1..ec2a89f677 100644 --- a/narwhals/_utils.py +++ b/narwhals/_utils.py @@ -2120,3 +2120,22 @@ def extend_bool( Stolen from https://github.com/pola-rs/polars/blob/b8bfb07a4a37a8d449d6d1841e345817431142df/py-polars/polars/_utils/various.py#L580-L594 """ return (value,) * n_match if isinstance(value, bool) else tuple(value) + + +def validate_concat_vertical_schemas(schemas: Iterable[Mapping[str, DType]]) -> None: + schemas_iter = iter(schemas) + schema_0 = next(schemas_iter) + for idx, schema_current in enumerate(schemas_iter, start=1): + if schema_0 != schema_current: + msg = ( + "'union'/'concat' inputs should all have the same schema,got\n" + f"{_pretty_format_schema(schema_0, index=0)}\n" + " and\n" + f"{_pretty_format_schema(schema_current, index=idx)}" + ) + raise InvalidOperationError(msg) + + +def _pretty_format_schema(schema: Mapping[str, DType], index: int) -> str: + body = "\n".join(f"name: {name}, field: {field}" for name, field in schema.items()) + return f"Schema at index: {index}:\n{body}" diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 35ebd54d39..08eae95fa9 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -34,10 +34,8 @@ def test_concat_horizontal(constructor_eager: ConstructorEager) -> None: def test_concat_vertical(constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} - df_left = ( - nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z") - ) + data = {"c": [1, 3, 2], "d": [4, 4, 6]} + df_left = nw.from_native(constructor(data)).lazy() data_right = {"c": [6, 12, -1], "d": [0, -4, 2]} df_right = nw.from_native(constructor(data_right)).lazy() @@ -49,17 +47,19 @@ def test_concat_vertical(constructor: Constructor) -> None: with pytest.raises(ValueError, match="No items"): nw.concat([], how="vertical") - with pytest.raises( - (Exception, TypeError), - match=r"unable to vstack|inputs should all have the same schema", - ): + err_msg = r"unable to vstack|unable to append|inputs should all have the same schema|cannot extend/append" + + with pytest.raises((Exception, InvalidOperationError), match=err_msg): nw.concat([df_left, df_right.rename({"d": "i"})], how="vertical").collect() - with pytest.raises( - (Exception, TypeError), - match=r"unable to vstack|unable to append|inputs should all have the same schema", - ): + + with pytest.raises((Exception, InvalidOperationError), match=err_msg): nw.concat([df_left, df_left.select("d")], how="vertical").collect() + with pytest.raises((Exception, InvalidOperationError), match=err_msg): + nw.concat( + [df_left, df_left.select("c", nw.col("d").cast(nw.Int32))], how="vertical" + ).collect() + def test_concat_diagonal( constructor: Constructor, request: pytest.FixtureRequest