Skip to content

Commit d60a32e

Browse files
committed
feat: Re-introduce error, display shape mismatch
Resolves #3233 (comment)
1 parent fbb1c8e commit d60a32e

File tree

3 files changed

+61
-33
lines changed

3 files changed

+61
-33
lines changed

narwhals/_plan/_expansion.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@
4242
from typing import TYPE_CHECKING, Any, Union
4343

4444
from narwhals._plan import common, expressions as ir, meta
45-
from narwhals._plan.exceptions import column_not_found_error, duplicate_error
45+
from narwhals._plan.exceptions import (
46+
binary_expr_multi_output_error,
47+
column_not_found_error,
48+
duplicate_error,
49+
)
4650
from narwhals._plan.expressions import (
4751
Alias,
4852
ExprIR,
@@ -315,8 +319,7 @@ def _expand_binary_expr(self, origin: ir.BinaryExpr, /) -> Iterator[ir.BinaryExp
315319
binary = origin.__replace__(right=rights[0])
316320
yield from (binary.__replace__(left=left) for left in lefts)
317321
else:
318-
msg = "TODO: `binary_expr_multi_output_error`"
319-
raise MultiOutputExpressionError(msg)
322+
raise binary_expr_multi_output_error(origin, lefts, rights)
320323

321324
def _expand_function_expr(
322325
self, origin: ir.FunctionExpr, /

narwhals/_plan/exceptions.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,35 +54,41 @@ def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001
5454
return ComputeError(msg)
5555

5656

57-
# NOTE: Always underlining `right`, since the message refers to both types of exprs
58-
# Assuming the most recent as the issue
57+
def _binary_underline(
58+
left: ir.ExprIR,
59+
operator: Operator,
60+
right: ir.ExprIR,
61+
/,
62+
*,
63+
underline_right: bool = True,
64+
) -> str:
65+
lhs, op, rhs = repr(left), repr(operator), repr(right)
66+
if underline_right:
67+
indent = (len(lhs) + len(op) + 2) * " "
68+
underline = len(rhs) * "^"
69+
else:
70+
indent = ""
71+
underline = len(lhs) * "^"
72+
return f"{lhs} {op} {rhs}\n{indent}{underline}"
73+
74+
5975
def binary_expr_shape_error(
6076
left: ir.ExprIR, op: Operator, right: ir.ExprIR
6177
) -> ShapeError:
62-
lhs_op = f"{left!r} {op!r} "
63-
rhs = repr(right)
64-
indent = len(lhs_op) * " "
65-
underline = len(rhs) * "^"
78+
expr = _binary_underline(left, op, right, underline_right=True)
6679
msg = (
67-
f"Cannot combine length-changing expressions with length-preserving ones.\n"
68-
f"{lhs_op}{rhs}\n{indent}{underline}"
80+
f"Cannot combine length-changing expressions with length-preserving ones.\n{expr}"
6981
)
7082
return ShapeError(msg)
7183

7284

73-
# TODO @dangotbanned: Share the right underline code w/ `binary_expr_shape_error`
7485
def binary_expr_multi_output_error(
75-
left: ir.ExprIR, op: Operator, right: ir.ExprIR
86+
origin: ir.BinaryExpr, left_expand: Seq[ir.ExprIR], right_expand: Seq[ir.ExprIR]
7687
) -> MultiOutputExpressionError:
77-
lhs_op = f"{left!r} {op!r} "
78-
rhs = repr(right)
79-
indent = len(lhs_op) * " "
80-
underline = len(rhs) * "^"
81-
msg = (
82-
"Multi-output expressions are only supported on the "
83-
f"left-hand side of a binary operation.\n"
84-
f"{lhs_op}{rhs}\n{indent}{underline}"
85-
)
88+
len_left, len_right = len(left_expand), len(right_expand)
89+
lhs, op, rhs = origin.left, origin.op, origin.right
90+
expr = _binary_underline(lhs, op, rhs, underline_right=len_left < len_right)
91+
msg = f"Cannot combine selectors that produce a different number of columns ({len_left} != {len_right}).\n{expr}"
8692
return MultiOutputExpressionError(msg)
8793

8894

tests/plan/expr_expansion_test.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -678,23 +678,42 @@ def test_expand_binary_expr_combination(
678678
assert_expr_ir_equal(actual, expect)
679679

680680

681-
@pytest.mark.xfail(reason="TODO: Move fancy error message", raises=AssertionError)
682-
def test_expand_binary_expr_combination_invalid(df_1: Frame) -> None: # pragma: no cover
683-
pattern = re.escape(
681+
def test_expand_binary_expr_combination_invalid(df_1: Frame) -> None:
682+
# fmt: off
683+
expr = re.escape(
684684
"ncs.all() + ncs.by_name('b', 'c', require_all=True)\n"
685-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
685+
"^^^^^^^^^"
686686
)
687+
# fmt: on
688+
shapes = "(20 != 2)"
689+
pattern = rf"{shapes}.+\n{expr}"
687690
all_to_two = nwp.all() + nwp.col("b", "c")
688691
with pytest.raises(MultiOutputExpressionError, match=pattern):
689692
df_1.project(all_to_two)
690693

691-
pattern = re.escape(
692-
"ncs.by_name('a', 'b', 'c', require_all=True).abs().fill_null([lit(int: 0)]).round() * ncs.by_index([9, 10], require_all=True).cast(Int64).sort(asc)\n"
693-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
694+
expr = re.escape(
695+
"ncs.by_name('a', 'b', require_all=True).abs().fill_null([lit(int: 0)]).round() * ncs.by_index([9, 10, 11], require_all=True).cast(Int64).sort(asc)\n"
696+
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
694697
)
695-
three_to_two = (
696-
nwp.col("a", "b", "c").abs().fill_null(0).round(2)
697-
* nwp.nth(9, 10).cast(nw.Int64).sort()
698+
shapes = "(2 != 3)"
699+
pattern = rf"{shapes}.+\n{expr}"
700+
two_to_three = (
701+
nwp.col("a", "b").abs().fill_null(0).round(2)
702+
* nwp.nth(9, 10, 11).cast(nw.Int64).sort()
698703
)
699704
with pytest.raises(MultiOutputExpressionError, match=pattern):
700-
df_1.project(three_to_two)
705+
df_1.project(two_to_three)
706+
707+
# fmt: off
708+
expr = re.escape(
709+
"ncs.numeric() / [(ncs.numeric()) - (ncs.by_dtype([Int64]))]\n"
710+
"^^^^^^^^^^^^^"
711+
)
712+
# fmt: on
713+
shapes = "(10 != 9)"
714+
pattern = rf"{shapes}.+\n{expr}"
715+
ten_to_nine = (
716+
ncs.numeric().as_expr() / (ncs.numeric() - ncs.by_dtype(nw.Int64)).as_expr()
717+
)
718+
with pytest.raises(MultiOutputExpressionError, match=pattern):
719+
df_1.project(ten_to_nine)

0 commit comments

Comments
 (0)