Skip to content

Commit d09aa74

Browse files
committed
adapt old BinaryExpr tests and make them fail
1 parent 936c78b commit d09aa74

File tree

4 files changed

+49
-40
lines changed

4 files changed

+49
-40
lines changed

narwhals/_plan/_expansion.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,12 @@ def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]:
284284

285285
# TODO @dangotbanned: Relax `BinaryExpr.right`
286286
# - https://github.com/narwhals-dev/narwhals/pull/3233#discussion_r2472757798
287-
# - https://github.com/narwhals-dev/narwhals/pull/3233#discussion_r2473810664
287+
# - https://github.com/narwhals-dev/narwhals/pull/3233#discussion_r2473810664=
288+
# NOTE: Only need to raise if outputs are not:
289+
# - 1:1
290+
# - M:1
291+
# - 1:M
292+
# - N:N
288293
elif isinstance(origin, ir.BinaryExpr):
289294
binary = origin.__replace__(right=self._expand_only(origin.right))
290295
for root in self._expand_recursive(binary.left):

narwhals/_plan/expressions/operators.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from narwhals._plan._immutable import Immutable
88
from narwhals._plan.exceptions import (
99
binary_expr_length_changing_error,
10-
binary_expr_multi_output_error,
1110
binary_expr_shape_error,
1211
)
1312

@@ -46,15 +45,6 @@ def to_binary_expr(
4645
) -> BinaryExpr[LeftT, Self, RightT]:
4746
from narwhals._plan.expressions.expr import BinaryExpr
4847

49-
# TODO @dangotbanned: Leave validating this until actually expanding
50-
# Only need to raise if outputs are not:
51-
# - 1:1
52-
# - M:1
53-
# - 1:M
54-
# - N:N
55-
# https://github.com/narwhals-dev/narwhals/pull/3233#discussion_r2473810664
56-
if right.meta.has_multiple_outputs():
57-
raise binary_expr_multi_output_error(left, self, right)
5848
if _is_filtration(left):
5949
if _is_filtration(right):
6050
raise binary_expr_length_changing_error(left, self, right)

tests/plan/expr_expansion_test.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from narwhals import _plan as nwp
1010
from narwhals._plan import expressions as ir, selectors as ncs
1111
from narwhals._utils import zip_strict
12-
from narwhals.exceptions import ColumnNotFoundError, DuplicateError
12+
from narwhals.exceptions import (
13+
ColumnNotFoundError,
14+
DuplicateError,
15+
MultiOutputExpressionError,
16+
)
1317
from tests.plan.utils import Frame, assert_expr_ir_equal, named_ir, re_compile
1418

1519
if TYPE_CHECKING:
@@ -596,3 +600,41 @@ def test_prepare_projection_index_error(
596600
),
597601
):
598602
df_1.project(into_exprs)
603+
604+
605+
@pytest.mark.xfail(
606+
reason="TODO: binary_expr_combination", raises=MultiOutputExpressionError
607+
)
608+
def test_expand_binary_expr_combination(df_1: Frame) -> None: # pragma: no cover
609+
three_to_three = nwp.nth(range(3)) * nwp.nth(3, 4, 5).max()
610+
611+
expecteds = [
612+
named_ir("a", nwp.col("a") * nwp.col("d")),
613+
named_ir("b", nwp.col("b") * nwp.col("e")),
614+
named_ir("c", nwp.col("c") * nwp.col("f")),
615+
]
616+
actuals = df_1.project(three_to_three)
617+
for actual, expected in zip_strict(actuals, expecteds):
618+
assert_expr_ir_equal(actual, expected)
619+
620+
621+
@pytest.mark.xfail(reason="TODO: Move fancy error message", raises=AssertionError)
622+
def test_expand_binary_expr_combination_invalid(df_1: Frame) -> None: # pragma: no cover
623+
pattern = re.escape(
624+
"ncs.all() + ncs.by_name('b', 'c', require_all=True)\n"
625+
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
626+
)
627+
all_to_two = nwp.all() + nwp.col("b", "c")
628+
with pytest.raises(MultiOutputExpressionError, match=pattern):
629+
df_1.project(all_to_two)
630+
631+
pattern = re.escape(
632+
"ncs.by_name('a', 'b', 'c', require_all=True).abs().fill_null([lit(int: 0)]).round() * ncs.by_index([9, 10], require_all=True).cast(Int64).sort(asc)\n"
633+
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
634+
)
635+
three_to_two = (
636+
nwp.col("a", "b", "c").abs().fill_null(0).round(2)
637+
* nwp.nth(9, 10).cast(nw.Int64).sort()
638+
)
639+
with pytest.raises(MultiOutputExpressionError, match=pattern):
640+
df_1.project(three_to_two)

tests/plan/expr_parsing_test.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
InvalidIntoExprError,
2121
InvalidOperationError,
2222
InvalidOperationError as LengthChangingExprError,
23-
MultiOutputExpressionError,
2423
ShapeError,
2524
)
2625
from tests.plan.utils import assert_expr_ir_equal
@@ -207,33 +206,6 @@ def test_filtration_over() -> None:
207206
nwp.col("a").diff().drop_nulls().over("b", order_by="i")
208207

209208

210-
def test_invalid_binary_expr_multi() -> None:
211-
# TODO @dangotbanned: Move to `expr_expansion_test`
212-
pattern = re.escape(
213-
"ncs.all() + ncs.by_name('b', 'c', require_all=True)\n"
214-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
215-
)
216-
with pytest.raises(MultiOutputExpressionError, match=pattern):
217-
nwp.all() + nwp.col("b", "c")
218-
219-
# TODO @dangotbanned: Use as a positive case (3:3)
220-
pattern = re.escape(
221-
"ncs.by_index([1, 2, 3], require_all=True) * ncs.by_index([4, 5, 6], require_all=True).max()\n"
222-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
223-
)
224-
with pytest.raises(MultiOutputExpressionError, match=pattern):
225-
nwp.nth(1, 2, 3) * nwp.nth(4, 5, 6).max()
226-
pattern = re.escape(
227-
"ncs.by_name('a', 'b', 'c', require_all=True).abs().fill_null([lit(int: 0)]).round() * ncs.by_index([9, 10], require_all=True).cast(Int64).sort(asc)\n"
228-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
229-
)
230-
# TODO @dangotbanned: Move to `expr_expansion_test`
231-
with pytest.raises(MultiOutputExpressionError, match=pattern):
232-
nwp.col("a", "b", "c").abs().fill_null(0).round(2) * nwp.nth(9, 10).cast(
233-
nw.Int64()
234-
).sort()
235-
236-
237209
def test_invalid_binary_expr_length_changing() -> None:
238210
a = nwp.col("a")
239211
b = nwp.col("b")

0 commit comments

Comments
 (0)