Skip to content

Commit 0d1394f

Browse files
committed
test: Fix lit and add some tests
Forgot one level of nesting on the lists
1 parent 7570585 commit 0d1394f

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

narwhals/_plan/impl_arrow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated: # noqa: ARG001
4242
import pyarrow as pa
4343

4444
if is_scalar_literal(node.value):
45-
return [pa.chunked_array([node.value.unwrap()])]
45+
scalar = node.value.unwrap()
46+
return [pa.chunked_array([[scalar]])]
4647
elif is_series_literal(node.value):
4748
ca = node.value.unwrap().to_native()
4849
return [t.cast("NativeSeries", ca)]

tests/plan/to_compliant_test.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Any
44

55
import pytest
66

77
import narwhals as nw
88
import narwhals._plan.demo as nwd
9+
from narwhals._plan.common import is_expr
10+
from narwhals._plan.impl_arrow import evaluate as evaluate_pyarrow
911
from narwhals.utils import Version
1012
from tests.namespace_test import backends
1113

@@ -14,8 +16,10 @@
1416
from narwhals._plan.dummy import DummyExpr
1517

1618

17-
def _ids_ir(expr: DummyExpr) -> str:
18-
return repr(expr._ir)
19+
def _ids_ir(expr: DummyExpr | Any) -> str:
20+
if is_expr(expr):
21+
return repr(expr._ir)
22+
return repr(expr)
1923

2024

2125
@pytest.mark.parametrize(
@@ -35,3 +39,33 @@ def test_to_compliant(backend: BackendName, expr: DummyExpr) -> None:
3539
namespace = Version.MAIN.namespace.from_backend(backend).compliant
3640
compliant_expr = expr._to_compliant(namespace)
3741
assert isinstance(compliant_expr, namespace._expr)
42+
43+
44+
@pytest.mark.parametrize(
45+
("expr", "expected"),
46+
[
47+
(nwd.col("a"), ["A", "B", "A"]),
48+
(nwd.col("a", "b"), [["A", "B", "A"], [1, 2, 3]]),
49+
(nwd.lit(1), [1]),
50+
(nwd.lit(2.0), [2.0]),
51+
(nwd.lit(None, nw.String()), [None]),
52+
],
53+
ids=_ids_ir,
54+
)
55+
def test_evaluate_pyarrow(expr: DummyExpr, expected: Any) -> None:
56+
pytest.importorskip("pyarrow")
57+
import pyarrow as pa
58+
59+
data: dict[str, Any] = {
60+
"a": ["A", "B", "A"],
61+
"b": [1, 2, 3],
62+
"c": [9, 2, 4],
63+
"d": [8, 7, 8],
64+
}
65+
frame = pa.table(data)
66+
result = evaluate_pyarrow(expr._ir, frame)
67+
if len(result) == 1:
68+
assert result[0].to_pylist() == expected
69+
else:
70+
results = [col.to_pylist() for col in result]
71+
assert results == expected

0 commit comments

Comments
 (0)