Skip to content

Commit 3165da4

Browse files
committed
fix: Handle lit broadcasting
1 parent 0d1394f commit 3165da4

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

narwhals/_plan/impl_arrow.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,13 @@ def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated:
3838

3939

4040
@evaluate.register(expr.Literal)
41-
def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated: # noqa: ARG001
41+
def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated:
4242
import pyarrow as pa
4343

4444
if is_scalar_literal(node.value):
45-
scalar = node.value.unwrap()
46-
return [pa.chunked_array([[scalar]])]
45+
lit: t.Any = pa.scalar
46+
array = pa.repeat(lit(node.value.unwrap()), len(frame))
47+
return [pa.chunked_array([array])]
4748
elif is_series_literal(node.value):
4849
ca = node.value.unwrap().to_native()
4950
return [t.cast("NativeSeries", ca)]

tests/plan/to_compliant_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None:
4646
[
4747
(nwd.col("a"), ["A", "B", "A"]),
4848
(nwd.col("a", "b"), [["A", "B", "A"], [1, 2, 3]]),
49-
(nwd.lit(1), [1]),
50-
(nwd.lit(2.0), [2.0]),
51-
(nwd.lit(None, nw.String()), [None]),
49+
(nwd.lit(1), [1, 1, 1]),
50+
(nwd.lit(2.0), [2.0, 2.0, 2.0]),
51+
(nwd.lit(None, nw.String()), [None, None, None]),
5252
],
5353
ids=_ids_ir,
5454
)

0 commit comments

Comments
 (0)