Skip to content

Commit 7ae5a8c

Browse files
authored
feat!: disallow concat(..., how="horizontal") for LazyFrame (#2341)
1 parent 16b4527 commit 7ae5a8c

File tree

7 files changed: 45 additions and 48 deletions

narwhals/_dask/namespace.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -175,24 +175,6 @@ def concat(
175175
backend_version=self._backend_version,
176176
version=self._version,
177177
)
178-
if how == "horizontal":
179-
all_column_names: list[str] = [
180-
column for frame in dfs for column in frame.columns
181-
]
182-
if len(all_column_names) != len(set(all_column_names)): # pragma: no cover
183-
duplicates = [
184-
i for i in all_column_names if all_column_names.count(i) > 1
185-
]
186-
msg = (
187-
f"Columns with name(s): {', '.join(duplicates)} "
188-
"have more than one occurrence"
189-
)
190-
raise AssertionError(msg)
191-
return DaskLazyFrame(
192-
dd.concat(dfs, axis=1, join="outer"),
193-
backend_version=self._backend_version,
194-
version=self._version,
195-
)
196178
if how == "diagonal":
197179
return DaskLazyFrame(
198180
dd.concat(dfs, axis=0, join="outer"),

narwhals/_duckdb/namespace.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,16 @@ def _lazyframe(self) -> type[DuckDBLazyFrame]:
6060
return DuckDBLazyFrame
6161

6262
def concat(
63-
self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
63+
self: Self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
6464
) -> DuckDBLazyFrame:
65-
if how == "horizontal":
66-
msg = "horizontal concat not supported for duckdb. Please join instead"
67-
raise TypeError(msg)
68-
if how == "diagonal":
69-
msg = "Not implemented yet"
70-
raise NotImplementedError(msg)
65+
native_items = [item._native_frame for item in items]
7166
items = list(items)
7267
first = items[0]
7368
schema = first.schema
7469
if how == "vertical" and not all(x.schema == schema for x in items[1:]):
7570
msg = "inputs should all have the same schema"
7671
raise TypeError(msg)
77-
res = reduce(lambda x, y: x.union(y), (item._native_frame for item in items))
72+
res = reduce(lambda x, y: x.union(y), native_items)
7873
return first._with_native(res)
7974

8075
def concat_str(

narwhals/_spark_like/namespace.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,6 @@ def concat(
192192
self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod
193193
) -> SparkLikeLazyFrame:
194194
dfs = [item._native_frame for item in items]
195-
if how == "horizontal":
196-
msg = (
197-
"Horizontal concatenation is not supported for LazyFrame backed by "
198-
"a PySpark DataFrame."
199-
)
200-
raise NotImplementedError(msg)
201-
202195
if how == "vertical":
203196
cols_0 = dfs[0].columns
204197
for i, df in enumerate(dfs[1:], start=1):

narwhals/functions.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from narwhals.dependencies import is_numpy_array
2828
from narwhals.dependencies import is_numpy_array_2d
2929
from narwhals.dependencies import is_pyarrow_table
30+
from narwhals.exceptions import InvalidOperationError
3031
from narwhals.expr import Expr
3132
from narwhals.series import Series
3233
from narwhals.translate import from_native
@@ -79,12 +80,13 @@ def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT
7980
8081
- vertical: Concatenate vertically. Column names must match.
8182
- horizontal: Concatenate horizontally. If lengths don't match, then
82-
missing rows are filled with null values.
83+
missing rows are filled with null values. This is only supported
84+
when all inputs are (eager) DataFrames.
8385
- diagonal: Finds a union between the column schemas and fills missing column
8486
values with null.
8587
8688
Returns:
87-
A new DataFrame, Lazyframe resulting from the concatenation.
89+
A new DataFrame or LazyFrame resulting from the concatenation.
8890
8991
Raises:
9092
TypeError: The items to concatenate should either all be eager, or all lazy
@@ -151,15 +153,23 @@ def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT
151153
|z: [[null,null],["x","y"]]|
152154
└──────────────────────────┘
153155
"""
154-
if how not in {"horizontal", "vertical", "diagonal"}: # pragma: no cover
155-
msg = "Only vertical, horizontal and diagonal concatenations are supported."
156-
raise NotImplementedError(msg)
156+
from narwhals.dependencies import is_narwhals_lazyframe
157+
157158
if not items:
158-
msg = "No items to concatenate"
159+
msg = "No items to concatenate."
159160
raise ValueError(msg)
160161
items = list(items)
161162
validate_laziness(items)
163+
if how not in {"horizontal", "vertical", "diagonal"}: # pragma: no cover
164+
msg = "Only vertical, horizontal and diagonal concatenations are supported."
165+
raise NotImplementedError(msg)
162166
first_item = items[0]
167+
if is_narwhals_lazyframe(first_item) and how == "horizontal":
168+
msg = (
169+
"Horizontal concatenation is not supported for LazyFrames.\n\n"
170+
"Hint: you may want to use `join` instead."
171+
)
172+
raise InvalidOperationError(msg)
163173
plx = first_item.__narwhals_namespace__()
164174
return first_item._with_compliant(
165175
plx.concat([df._compliant_frame for df in items], how=how),

narwhals/stable/v1/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,12 +2069,13 @@ def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT
20692069
20702070
- vertical: Concatenate vertically. Column names must match.
20712071
- horizontal: Concatenate horizontally. If lengths don't match, then
2072-
missing rows are filled with null values.
2072+
missing rows are filled with null values. This is only supported
2073+
when all inputs are (eager) DataFrames.
20732074
- diagonal: Finds a union between the column schemas and fills missing column
20742075
values with null.
20752076
20762077
Returns:
2077-
A new DataFrame, Lazyframe resulting from the concatenation.
2078+
A new DataFrame or LazyFrame resulting from the concatenation.
20782079
20792080
Raises:
20802081
TypeError: The items to concatenate should either all be eager, or all lazy

tests/frame/concat_test.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
from __future__ import annotations
22

3+
import re
4+
35
import pytest
46

57
import narwhals.stable.v1 as nw
8+
from narwhals.exceptions import InvalidOperationError
69
from tests.utils import Constructor
10+
from tests.utils import ConstructorEager
711
from tests.utils import assert_equal_data
812

913

10-
def test_concat_horizontal(
11-
constructor: Constructor, request: pytest.FixtureRequest
12-
) -> None:
13-
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
14-
request.applymarker(pytest.mark.xfail)
14+
def test_concat_horizontal(constructor_eager: ConstructorEager) -> None:
1515
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
16-
df_left = nw.from_native(constructor(data)).lazy()
16+
df_left = nw.from_native(constructor_eager(data), eager_only=True)
1717

1818
data_right = {"c": [6, 12, -1], "d": [0, -4, 2]}
19-
df_right = nw.from_native(constructor(data_right)).lazy()
19+
df_right = nw.from_native(constructor_eager(data_right), eager_only=True)
2020

2121
result = nw.concat([df_left, df_right], how="horizontal")
2222
expected = {
@@ -30,6 +30,9 @@ def test_concat_horizontal(
3030

3131
with pytest.raises(ValueError, match="No items"):
3232
nw.concat([])
33+
pattern = re.compile(r"horizontal.+not supported.+lazyframe", re.IGNORECASE)
34+
with pytest.raises(InvalidOperationError, match=pattern):
35+
nw.concat([df_left.lazy()], how="horizontal")
3336

3437

3538
def test_concat_vertical(constructor: Constructor) -> None:

tests/series_only/hist_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from tests.utils import ConstructorEager
1414
from tests.utils import assert_equal_data
1515

16+
xfail_hist = pytest.mark.xfail(
17+
reason="https://github.com/narwhals-dev/narwhals/issues/2348", strict=False
18+
)
19+
20+
1621
data = {
1722
"int": [0, 1, 2, 3, 4, 5, 6],
1823
}
@@ -76,6 +81,7 @@
7681
]
7782

7883

84+
@xfail_hist
7985
@pytest.mark.parametrize("params", bins_and_expected)
8086
@pytest.mark.parametrize("include_breakpoint", [True, False])
8187
@pytest.mark.filterwarnings(
@@ -161,6 +167,7 @@ def test_hist_bin(
161167
assert_equal_data(result, expected)
162168

163169

170+
@xfail_hist
164171
@pytest.mark.parametrize("params", counts_and_expected)
165172
@pytest.mark.parametrize("include_breakpoint", [True, False])
166173
@pytest.mark.filterwarnings(
@@ -232,6 +239,7 @@ def test_hist_count(
232239
)
233240

234241

242+
@xfail_hist
235243
@pytest.mark.filterwarnings(
236244
"ignore:`Series.hist` is being called from the stable API although considered an unstable feature."
237245
)
@@ -268,6 +276,7 @@ def test_hist_count_no_spread(
268276
assert_equal_data(result, expected)
269277

270278

279+
@xfail_hist
271280
@pytest.mark.filterwarnings(
272281
"ignore:`Series.hist` is being called from the stable API although considered an unstable feature."
273282
)
@@ -283,6 +292,7 @@ def test_hist_bin_and_bin_count() -> None:
283292
s.hist(bins=[1, 3], bin_count=4)
284293

285294

295+
@xfail_hist
286296
@pytest.mark.filterwarnings(
287297
"ignore:`Series.hist` is being called from the stable API although considered an unstable feature."
288298
)
@@ -331,6 +341,7 @@ def test_hist_small_bins(
331341
s["values"].hist(bins=[1, 3], bin_count=4)
332342

333343

344+
@xfail_hist
334345
@pytest.mark.filterwarnings(
335346
"ignore:`Series.hist` is being called from the stable API although considered an unstable feature."
336347
)
@@ -365,6 +376,7 @@ def test_hist_non_monotonic(constructor_eager: ConstructorEager) -> None:
365376
st.floats(min_value=0.001, max_value=1_000, allow_nan=False), max_size=50
366377
),
367378
)
379+
@xfail_hist
368380
@pytest.mark.filterwarnings(
369381
"ignore:`Series.hist` is being called from the stable API although considered an unstable feature.",
370382
"ignore:invalid value encountered in cast:RuntimeWarning",
@@ -421,6 +433,7 @@ def test_hist_bin_hypotheis(
421433
),
422434
bin_count=st.integers(min_value=0, max_value=1_000),
423435
)
436+
@xfail_hist
424437
@pytest.mark.skipif(
425438
POLARS_VERSION < (1, 15),
426439
reason="hist(bin_count=...) behavior significantly changed after this version",

0 commit comments

Comments (0)