Skip to content

Commit b80e0e1

Browse files
committed
feat: Infer DType in lit
1 parent 563076d commit b80e0e1

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

narwhals/_plan/common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from narwhals._plan.lists import IRListNamespace
2323
from narwhals._plan.meta import IRMetaNamespace
2424
from narwhals._plan.options import FunctionOptions
25+
from narwhals.dtypes import DType
2526
from narwhals.typing import NonNestedLiteral
2627

2728
else:
@@ -315,3 +316,21 @@ def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | DummySeries]:
315316
from narwhals._plan.dummy import DummySeries
316317

317318
return isinstance(obj, (str, bytes, DummySeries))
319+
320+
321+
def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType:
322+
dtypes = version.dtypes
323+
mapping: dict[type[NonNestedLiteral], type[DType]] = {
324+
int: dtypes.Int64,
325+
float: dtypes.Float64,
326+
str: dtypes.String,
327+
bool: dtypes.Boolean,
328+
dt.datetime: dtypes.Datetime,
329+
dt.date: dtypes.Date,
330+
dt.time: dtypes.Time,
331+
dt.timedelta: dtypes.Duration,
332+
bytes: dtypes.Binary,
333+
Decimal: dtypes.Decimal,
334+
type(None): dtypes.Unknown,
335+
}
336+
return mapping.get(type(obj), dtypes.Unknown)()

narwhals/_plan/demo.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
expr_parsing as parse,
1010
functions as F, # noqa: N812
1111
)
12-
from narwhals._plan.common import ExprIR, IntoExpr, is_expr, is_non_nested_literal
12+
from narwhals._plan.common import (
13+
ExprIR,
14+
IntoExpr,
15+
is_expr,
16+
is_non_nested_literal,
17+
py_to_narwhals_dtype,
18+
)
1319
from narwhals._plan.dummy import DummySeries
1420
from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth
1521
from narwhals._plan.literal import ScalarLiteral, SeriesLiteral
@@ -52,11 +58,11 @@ def lit(
5258
) -> DummyExpr:
5359
if isinstance(value, DummySeries):
5460
return SeriesLiteral(value=value).to_literal().to_narwhals()
55-
if dtype is None or not isinstance(dtype, DType):
56-
dtype = Version.MAIN.dtypes.Unknown()
5761
if not is_non_nested_literal(value):
5862
msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}."
5963
raise TypeError(msg)
64+
if dtype is None or not isinstance(dtype, DType):
65+
dtype = py_to_narwhals_dtype(value, Version.MAIN)
6066
return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals()
6167

6268

0 commit comments

Comments
 (0)