Skip to content

Commit 936c78b

Browse files
committed
chore: planning fancy binary combination expansion
1 parent 9b2f617 commit 936c78b

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

narwhals/_plan/_expansion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _expand_only(self, child: ExprIR, /) -> ExprIR:
265265
raise MultiOutputExpressionError(msg)
266266
return first
267267

268-
# TODO @dangotbanned: It works, but all this class-specfic branching belongs in the classes themselves
268+
# TODO @dangotbanned: It works, but all this class-specific branching belongs in the classes themselves
269269
def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]:
270270
changes: dict[str, Any] = {}
271271
if isinstance(origin, (ir.WindowExpr, ir.Filter, ir.SortBy)):
@@ -281,6 +281,10 @@ def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]:
281281
replaced = common.replace(origin, **changes)
282282
for root in self._expand_recursive(replaced.expr):
283283
yield common.replace(replaced, expr=root)
284+
285+
# TODO @dangotbanned: Relax `BinaryExpr.right`
286+
# - https://github.com/narwhals-dev/narwhals/pull/3233#discussion_r2472757798
287+
# - https://github.com/narwhals-dev/narwhals/pull/3233#discussion_r2473810664
284288
elif isinstance(origin, ir.BinaryExpr):
285289
binary = origin.__replace__(right=self._expand_only(origin.right))
286290
for root in self._expand_recursive(binary.left):

narwhals/_plan/expressions/operators.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def to_binary_expr(
4646
) -> BinaryExpr[LeftT, Self, RightT]:
4747
from narwhals._plan.expressions.expr import BinaryExpr
4848

49+
# TODO @dangotbanned: Leave validating this until actually expanding
50+
# Only need to raise if outputs are not:
51+
# - 1:1
52+
# - M:1
53+
# - 1:M
54+
# - N:N
55+
# https://github.com/narwhals-dev/narwhals/pull/3233#discussion_r2473810664
4956
if right.meta.has_multiple_outputs():
5057
raise binary_expr_multi_output_error(left, self, right)
5158
if _is_filtration(left):

tests/plan/expr_parsing_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,15 @@ def test_filtration_over() -> None:
208208

209209

210210
def test_invalid_binary_expr_multi() -> None:
211+
# TODO @dangotbanned: Move to `expr_expansion_test`
211212
pattern = re.escape(
212213
"ncs.all() + ncs.by_name('b', 'c', require_all=True)\n"
213214
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
214215
)
215216
with pytest.raises(MultiOutputExpressionError, match=pattern):
216217
nwp.all() + nwp.col("b", "c")
217218

219+
# TODO @dangotbanned: Use as a positive case (3:3)
218220
pattern = re.escape(
219221
"ncs.by_index([1, 2, 3], require_all=True) * ncs.by_index([4, 5, 6], require_all=True).max()\n"
220222
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
@@ -225,6 +227,7 @@ def test_invalid_binary_expr_multi() -> None:
225227
"ncs.by_name('a', 'b', 'c', require_all=True).abs().fill_null([lit(int: 0)]).round() * ncs.by_index([9, 10], require_all=True).cast(Int64).sort(asc)\n"
226228
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"
227229
)
230+
# TODO @dangotbanned: Move to `expr_expansion_test`
228231
with pytest.raises(MultiOutputExpressionError, match=pattern):
229232
nwp.col("a", "b", "c").abs().fill_null(0).round(2) * nwp.nth(9, 10).cast(
230233
nw.Int64()

0 commit comments

Comments
 (0)