Skip to content

Commit ca1f822

Browse files
committed
feat: Improve Literal type unwrapping
- 4th iteration of trying to get this working
- A nice property of this is that the unimplemented `RangeLiteral` case is (statically) unreachable in `impl_arrow`
  - Without the need to name *any* of the `LiteralValue` classes
- Also, the traversal is now hidden behind `unwrap`, which preserves the type
1 parent 3165da4 commit ca1f822

File tree

4 files changed

+52
-22
lines changed

4 files changed

+52
-22
lines changed

narwhals/_plan/expr.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
import typing as t
1414

1515
from narwhals._plan.aggregation import Agg, OrderableAgg
16-
from narwhals._plan.common import ExprIR, SelectorIR, _field_str
16+
from narwhals._plan.common import ExprIR, SelectorIR, _field_str, is_non_nested_literal
1717
from narwhals._plan.name import KeepName, RenameAlias
1818
from narwhals._plan.typing import (
1919
ExprT,
2020
FunctionT,
2121
LeftSelectorT,
2222
LeftT,
23+
LiteralT,
2324
Ns,
2425
OperatorT,
2526
RightSelectorT,
@@ -115,12 +116,12 @@ def to_compliant(self, plx: Ns[ExprT], /) -> ExprT:
115116
return plx.col(*self.names)
116117

117118

118-
class Literal(ExprIR):
119+
class Literal(ExprIR, t.Generic[LiteralT]):
119120
"""https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81."""
120121

121122
__slots__ = ("value",)
122123

123-
value: LiteralValue
124+
value: LiteralValue[LiteralT]
124125

125126
@property
126127
def is_scalar(self) -> bool:
@@ -138,7 +139,13 @@ def __repr__(self) -> str:
138139
return f"lit({self.value!r})"
139140

140141
def to_compliant(self, plx: Ns[ExprT], /) -> ExprT:
141-
return plx.lit(self.value.unwrap(), self.dtype)
142+
value = self.unwrap()
143+
if is_non_nested_literal(value):
144+
return plx.lit(value, self.dtype)
145+
raise NotImplementedError(type(self.value))
146+
147+
def unwrap(self) -> LiteralT:
148+
return self.value.unwrap()
142149

143150

144151
class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]):

narwhals/_plan/impl_arrow.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
from functools import singledispatch
1010

1111
from narwhals._plan import expr
12-
from narwhals._plan.literal import is_scalar_literal, is_series_literal
12+
from narwhals._plan.literal import is_literal_scalar, is_literal_series
1313

1414
if t.TYPE_CHECKING:
1515
import pyarrow as pa
1616
from typing_extensions import TypeAlias
1717

1818
from narwhals._plan.common import ExprIR
19+
from narwhals._plan.dummy import DummySeries
20+
from narwhals.typing import NonNestedLiteral
1921

2022
NativeFrame: TypeAlias = pa.Table
2123
NativeSeries: TypeAlias = pa.ChunkedArray[t.Any]
@@ -38,15 +40,17 @@ def cols(node: expr.Columns, frame: NativeFrame) -> Evaluated:
3840

3941

4042
@evaluate.register(expr.Literal)
41-
def lit(node: expr.Literal, frame: NativeFrame) -> Evaluated:
43+
def lit(
44+
node: expr.Literal[NonNestedLiteral] | expr.Literal[DummySeries], frame: NativeFrame
45+
) -> Evaluated:
4246
import pyarrow as pa
4347

44-
if is_scalar_literal(node.value):
48+
if is_literal_scalar(node):
4549
lit: t.Any = pa.scalar
46-
array = pa.repeat(lit(node.value.unwrap()), len(frame))
50+
array = pa.repeat(lit(node.unwrap()), len(frame))
4751
return [pa.chunked_array([array])]
48-
elif is_series_literal(node.value):
49-
ca = node.value.unwrap().to_native()
52+
elif is_literal_series(node):
53+
ca = node.unwrap().to_native()
5054
return [t.cast("NativeSeries", ca)]
5155
else:
5256
raise NotImplementedError(type(node.value))

narwhals/_plan/literal.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,17 @@
33
from typing import TYPE_CHECKING, Any, Generic
44

55
from narwhals._plan.common import Immutable
6+
from narwhals._plan.typing import LiteralT, NonNestedLiteralT
67

78
if TYPE_CHECKING:
89
from typing_extensions import TypeIs
910

1011
from narwhals._plan.dummy import DummySeries
1112
from narwhals._plan.expr import Literal
1213
from narwhals.dtypes import DType
13-
from narwhals.typing import NonNestedLiteral
1414

15-
from narwhals._typing_compat import TypeVar
1615

17-
T = TypeVar("T", default=Any)
18-
NonNestedLiteralT = TypeVar(
19-
"NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral"
20-
)
21-
22-
23-
class LiteralValue(Immutable, Generic[T]):
16+
class LiteralValue(Immutable, Generic[LiteralT]):
2417
"""https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/lit.rs#L67-L73."""
2518

2619
@property
@@ -40,7 +33,7 @@ def to_literal(self) -> Literal:
4033

4134
return Literal(value=self)
4235

43-
def unwrap(self) -> T:
36+
def unwrap(self) -> LiteralT:
4437
raise NotImplementedError
4538

4639

@@ -102,9 +95,27 @@ class RangeLiteral(LiteralValue):
10295
dtype: DType
10396

10497

105-
def is_scalar_literal(obj: Any) -> TypeIs[ScalarLiteral]:
98+
def _is_scalar(
99+
obj: ScalarLiteral[NonNestedLiteralT] | Any,
100+
) -> TypeIs[ScalarLiteral[NonNestedLiteralT]]:
106101
return isinstance(obj, ScalarLiteral)
107102

108103

109-
def is_series_literal(obj: Any) -> TypeIs[SeriesLiteral]:
104+
def _is_series(obj: Any) -> TypeIs[SeriesLiteral]:
110105
return isinstance(obj, SeriesLiteral)
106+
107+
108+
def is_literal(obj: Literal[LiteralT] | Any) -> TypeIs[Literal[LiteralT]]:
109+
from narwhals._plan.expr import Literal
110+
111+
return isinstance(obj, Literal)
112+
113+
114+
def is_literal_scalar(
115+
obj: Literal[NonNestedLiteralT] | Any,
116+
) -> TypeIs[Literal[NonNestedLiteralT]]:
117+
return is_literal(obj) and _is_scalar(obj.value)
118+
119+
120+
def is_literal_series(obj: Literal[DummySeries] | Any) -> TypeIs[Literal[DummySeries]]:
121+
return is_literal(obj) and _is_series(obj.value)

narwhals/_plan/typing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from narwhals._compliant.typing import CompliantExprAny
1212
from narwhals._plan import operators as ops
1313
from narwhals._plan.common import ExprIR, Function, IRNamespace, SelectorIR
14+
from narwhals._plan.dummy import DummySeries
1415
from narwhals._plan.functions import RollingWindow
16+
from narwhals.typing import NonNestedLiteral
1517

1618
__all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT", "SelectorOperatorT"]
1719

@@ -29,6 +31,12 @@
2931
"SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator"
3032
)
3133
IRNamespaceT = TypeVar("IRNamespaceT", bound="IRNamespace")
34+
35+
NonNestedLiteralT = TypeVar(
36+
"NonNestedLiteralT", bound="NonNestedLiteral", default="NonNestedLiteral"
37+
)
38+
LiteralT = TypeVar("LiteralT", bound="NonNestedLiteral | DummySeries", default=t.Any)
39+
3240
# NOTE: Shorter aliases of `_compliant.typing`
3341
# - Aiming to try and preserve the types as much as possible
3442
# - Recursion between `Expr` and `Frame` is an issue

0 commit comments

Comments
 (0)