Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,6 @@ def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
return pa.Table.from_arrays(arrays, names=names)

def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
cols_0 = dfs[0].column_names
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.column_names
if cols_current != cols_0:
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)
return pa.concat_tables(dfs)

@property
Expand Down
2 changes: 2 additions & 0 deletions narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
exclude_column_names,
get_column_names,
passthrough_column_names,
validate_concat_vertical_schemas,
)
from narwhals.dependencies import is_numpy_array, is_numpy_array_2d

Expand Down Expand Up @@ -239,6 +240,7 @@ def concat(
if how == "horizontal":
native = self._concat_horizontal(dfs)
elif how == "vertical":
validate_concat_vertical_schemas(item.schema for item in items)
native = self._concat_vertical(dfs)
elif how == "diagonal":
native = self._concat_diagonal(dfs)
Expand Down
37 changes: 11 additions & 26 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation, zip_strict
from narwhals._utils import Implementation, validate_concat_vertical_schemas, zip_strict

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
Expand Down Expand Up @@ -148,32 +148,17 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
def concat(
self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
) -> DaskLazyFrame:
if not items:
msg = "No items to concatenate" # pragma: no cover
raise AssertionError(msg)
dfs = [i._native_frame for i in items]
cols_0 = dfs[0].columns
items = list(items)
dfs = [item._native_frame for item in items]
if how == "vertical":
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not (
(len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0.to_list()}\n"
f" - dataframe {i}: {cols_current.to_list()}\n"
)
raise TypeError(msg)
return DaskLazyFrame(
dd.concat(dfs, axis=0, join="inner"), version=self._version
)
if how == "diagonal":
return DaskLazyFrame(
dd.concat(dfs, axis=0, join="outer"), version=self._version
)

raise NotImplementedError
validate_concat_vertical_schemas(item.schema for item in items)
native_result = dd.concat(dfs, axis=0, join="inner")
elif how == "diagonal":
native_result = dd.concat(dfs, axis=0, join="outer")
else: # pragma: no cover
raise NotImplementedError

return self._lazyframe.from_native(native_result, context=self)

def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
Expand Down
24 changes: 11 additions & 13 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import Implementation
from narwhals._utils import Implementation, validate_concat_vertical_schemas

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -80,23 +80,21 @@ def _coalesce(self, *exprs: Expression) -> Expression:
def concat(
self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
) -> DuckDBLazyFrame:
native_items = [item._native_frame for item in items]
items = list(items)
first = items[0]
schema = first.schema
if how == "vertical" and not all(x.schema == schema for x in items[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg)
if how == "diagonal":
res = first.native
for _item in native_items[1:]:
if how == "vertical":
validate_concat_vertical_schemas(item.schema for item in items)
res = reduce(lambda x, y: x.union(y), (item._native_frame for item in items))
elif how == "diagonal":
res, *others = (item._native_frame for item in items)
for _item in others:
# TODO(unassigned): use relational API when available https://github.com/duckdb/duckdb/discussions/16996
res = duckdb.sql("""
from res select * union all by name from _item select *
""")
return first._with_native(res)
res = reduce(lambda x, y: x.union(y), native_items)
return first._with_native(res)
else: # pragma: no cover
raise NotImplementedError

return self._lazyframe.from_native(res, context=self)

def concat_str(
self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool
Expand Down
12 changes: 5 additions & 7 deletions narwhals/_ibis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import Implementation, requires
from narwhals._utils import Implementation, requires, validate_concat_vertical_schemas

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -69,12 +69,10 @@ def concat(
raise NotImplementedError(msg)

items = list(items)
native_items = [item.native for item in items]
schema = items[0].schema
if not all(x.schema == schema for x in items[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg)
return self._lazyframe.from_native(ibis.union(*native_items), context=self)
validate_concat_vertical_schemas(item.schema for item in items)
return self._lazyframe.from_native(
ibis.union(*(item.native for item in items)), context=self
)

def concat_str(
self, *exprs: IbisExpr, separator: str, ignore_nulls: bool
Expand Down
12 changes: 0 additions & 12 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,18 +306,6 @@ def _concat_horizontal(
return self._concat(dfs, axis=HORIZONTAL)

def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFrameT:
cols_0 = dfs[0].columns
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not (
(len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0.to_list()}\n"
f" - dataframe {i}: {cols_current.to_list()}\n"
)
raise TypeError(msg)
if self._implementation.is_pandas() and self._backend_version < (3,):
return self._concat(dfs, axis=VERTICAL, copy=False)
return self._concat(dfs, axis=VERTICAL)
Expand Down
35 changes: 9 additions & 26 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import zip_strict
from narwhals._utils import validate_concat_vertical_schemas, zip_strict

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -142,32 +142,15 @@ def concat(
) -> SparkLikeLazyFrame:
dfs = [item._native_frame for item in items]
if how == "vertical":
cols_0 = dfs[0].columns
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)

return SparkLikeLazyFrame(
native_dataframe=reduce(lambda x, y: x.union(y), dfs),
version=self._version,
implementation=self._implementation,
validate_concat_vertical_schemas(item.schema for item in items)
native_result = reduce(lambda x, y: x.union(y), dfs)
elif how == "diagonal":
native_result = reduce(
lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
)

if how == "diagonal":
return SparkLikeLazyFrame(
native_dataframe=reduce(
lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
),
version=self._version,
implementation=self._implementation,
)
raise NotImplementedError
else: # pragma: no cover
raise NotImplementedError
return self._lazyframe.from_native(native_result, context=self)

def concat_str(
self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool
Expand Down
19 changes: 19 additions & 0 deletions narwhals/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2109,3 +2109,22 @@ def normalize_path(source: FileSource, /) -> str:
from pathlib import Path

return str(Path(source))


def validate_concat_vertical_schemas(schemas: Iterable[Mapping[str, DType]]) -> None:
    """Check that every schema equals the first one before a vertical concat.

    Arguments:
        schemas: Schemas of the frames to concatenate, in input order. Any
            iterable is accepted and it is consumed lazily: passing a generator
            avoids evaluating all the schemas up front, and iteration stops at
            the first mismatch.

    Raises:
        InvalidOperationError: If a schema differs from the first one. The
            message shows the first schema and the first mismatching schema,
            each labelled with its position in `schemas`.
    """
    schemas_iter = iter(schemas)
    # With no schemas there is nothing to compare; the default keeps a bare
    # `StopIteration` from leaking out to the caller.
    schema_0 = next(schemas_iter, None)
    if schema_0 is None:
        return
    for idx, schema_current in enumerate(schemas_iter, start=1):
        if schema_0 != schema_current:
            msg = (
                "'union'/'concat' inputs should all have the same schema,got\n"
                f"{_pretty_format_schema(schema_0, index=0)}\n"
                # The leading space in " and" is not a typo: per the review
                # note on this change, it mirrors the polars message.
                " and\n"
                f"{_pretty_format_schema(schema_current, index=idx)}"
            )
            raise InvalidOperationError(msg)


def _pretty_format_schema(schema: Mapping[str, DType], index: int) -> str:
body = "\n".join(f"name: {name}, field: {field}" for name, field in schema.items())
return f"Schema at index: {index}:\n{body}"
24 changes: 12 additions & 12 deletions tests/frame/concat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ def test_concat_horizontal(constructor_eager: ConstructorEager) -> None:


def test_concat_vertical(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df_left = (
nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z")
)
data = {"c": [1, 3, 2], "d": [4, 4, 6]}
df_left = nw.from_native(constructor(data)).lazy()

data_right = {"c": [6, 12, -1], "d": [0, -4, 2]}
df_right = nw.from_native(constructor(data_right)).lazy()
Expand All @@ -49,17 +47,19 @@ def test_concat_vertical(constructor: Constructor) -> None:
with pytest.raises(ValueError, match="No items"):
nw.concat([], how="vertical")

with pytest.raises(
(Exception, TypeError),
match=r"unable to vstack|inputs should all have the same schema",
):
err_msg = r"unable to vstack|inputs should all have the same schema"

with pytest.raises((Exception, InvalidOperationError), match=err_msg):
nw.concat([df_left, df_right.rename({"d": "i"})], how="vertical").collect()
with pytest.raises(
(Exception, TypeError),
match=r"unable to vstack|unable to append|inputs should all have the same schema",
):

with pytest.raises((Exception, InvalidOperationError), match=err_msg):
nw.concat([df_left, df_left.select("d")], how="vertical").collect()

with pytest.raises((Exception, InvalidOperationError), match=err_msg):
nw.concat(
[df_left, df_left.select("c", nw.col("d").cast(nw.Int32))], how="vertical"
).collect()


def test_concat_diagonal(
constructor: Constructor, request: pytest.FixtureRequest
Expand Down
Loading