Skip to content

Commit a770e70

Browse files
authored
feat(typing): Add IntoSchema alias (#2945)
1 parent bfd31c4 commit a770e70

File tree

10 files changed

+61
-38
lines changed

10 files changed

+61
-38
lines changed

docs/api-reference/typing.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ Narwhals comes fully statically typed. In addition to `nw.DataFrame`, `nw.Expr`,
1818
- IntoSeries
1919
- IntoSeriesT
2020
- IntoDType
21+
- IntoSchema
2122
- SizeUnit
2223
- TimeUnit
2324
- AsofJoinStrategy

narwhals/_arrow/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,8 @@
4646
from narwhals._translate import IntoArrowTable
4747
from narwhals._utils import Version, _LimitedContext
4848
from narwhals.dtypes import DType
49-
from narwhals.schema import Schema
5049
from narwhals.typing import (
50+
IntoSchema,
5151
JoinStrategy,
5252
SizedMultiIndexSelector,
5353
SizedMultiNameSelector,
@@ -114,7 +114,7 @@ def from_dict(
114114
/,
115115
*,
116116
context: _LimitedContext,
117-
schema: Mapping[str, DType] | Schema | None,
117+
schema: IntoSchema | None,
118118
) -> Self:
119119
from narwhals.schema import Schema
120120

@@ -140,7 +140,7 @@ def from_numpy(
140140
/,
141141
*,
142142
context: _LimitedContext,
143-
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
143+
schema: IntoSchema | Sequence[str] | None,
144144
) -> Self:
145145
from narwhals.schema import Schema
146146

narwhals/_compliant/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -53,9 +53,9 @@
5353
from narwhals.dataframe import DataFrame
5454
from narwhals.dtypes import DType
5555
from narwhals.exceptions import ColumnNotFoundError
56-
from narwhals.schema import Schema
5756
from narwhals.typing import (
5857
AsofJoinStrategy,
58+
IntoSchema,
5959
JoinStrategy,
6060
LazyUniqueKeepStrategy,
6161
MultiColSelector,
@@ -105,7 +105,7 @@ def from_dict(
105105
/,
106106
*,
107107
context: _LimitedContext,
108-
schema: Mapping[str, DType] | Schema | None,
108+
schema: IntoSchema | None,
109109
) -> Self: ...
110110
@classmethod
111111
def from_native(cls, data: NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
@@ -116,7 +116,7 @@ def from_numpy(
116116
/,
117117
*,
118118
context: _LimitedContext,
119-
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
119+
schema: IntoSchema | Sequence[str] | None,
120120
) -> Self: ...
121121

122122
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...

narwhals/_compliant/namespace.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -24,19 +24,18 @@
2424
from narwhals.dependencies import is_numpy_array_2d
2525

2626
if TYPE_CHECKING:
27-
from collections.abc import Container, Iterable, Mapping, Sequence
27+
from collections.abc import Container, Iterable, Sequence
2828

2929
from typing_extensions import TypeAlias
3030

3131
from narwhals._compliant.selectors import CompliantSelectorNamespace
3232
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
3333
from narwhals._utils import Implementation, Version
34-
from narwhals.dtypes import DType
35-
from narwhals.schema import Schema
3634
from narwhals.typing import (
3735
ConcatMethod,
3836
Into1DArray,
3937
IntoDType,
38+
IntoSchema,
4039
NonNestedLiteral,
4140
_2DArray,
4241
)
@@ -174,17 +173,14 @@ def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT:
174173

175174
@overload
176175
def from_numpy(
177-
self,
178-
data: _2DArray,
179-
/,
180-
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
176+
self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
181177
) -> EagerDataFrameT: ...
182178

183179
def from_numpy(
184180
self,
185181
data: Into1DArray | _2DArray,
186182
/,
187-
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
183+
schema: IntoSchema | Sequence[str] | None = None,
188184
) -> EagerDataFrameT | EagerSeriesT:
189185
if is_numpy_array_2d(data):
190186
return self._dataframe.from_numpy(data, schema=schema, context=self)

narwhals/_pandas_like/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,10 @@
4949
from narwhals._translate import IntoArrowTable
5050
from narwhals._utils import Version, _LimitedContext
5151
from narwhals.dtypes import DType
52-
from narwhals.schema import Schema
5352
from narwhals.typing import (
5453
AsofJoinStrategy,
5554
DTypeBackend,
55+
IntoSchema,
5656
JoinStrategy,
5757
PivotAgg,
5858
SizedMultiIndexSelector,
@@ -144,7 +144,7 @@ def from_dict(
144144
/,
145145
*,
146146
context: _LimitedContext,
147-
schema: Mapping[str, DType] | Schema | None,
147+
schema: IntoSchema | None,
148148
) -> Self:
149149
from narwhals.schema import Schema
150150

@@ -195,7 +195,7 @@ def from_numpy(
195195
/,
196196
*,
197197
context: _LimitedContext,
198-
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
198+
schema: IntoSchema | Sequence[str] | None,
199199
) -> Self:
200200
from narwhals.schema import Schema
201201

narwhals/_polars/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,8 +43,8 @@
4343
from narwhals._utils import Version, _LimitedContext
4444
from narwhals.dataframe import DataFrame, LazyFrame
4545
from narwhals.dtypes import DType
46-
from narwhals.schema import Schema
4746
from narwhals.typing import (
47+
IntoSchema,
4848
JoinStrategy,
4949
MultiColSelector,
5050
MultiIndexSelector,
@@ -272,7 +272,7 @@ def from_dict(
272272
/,
273273
*,
274274
context: _LimitedContext,
275-
schema: Mapping[str, DType] | Schema | None,
275+
schema: IntoSchema | None,
276276
) -> Self:
277277
from narwhals.schema import Schema
278278

@@ -290,7 +290,7 @@ def from_numpy(
290290
/,
291291
*,
292292
context: _LimitedContext, # NOTE: Maybe only `Implementation`?
293-
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
293+
schema: IntoSchema | Sequence[str] | None,
294294
) -> Self:
295295
from narwhals.schema import Schema
296296

narwhals/_polars/namespace.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,14 @@
1313
from narwhals.dtypes import DType
1414

1515
if TYPE_CHECKING:
16-
from collections.abc import Iterable, Mapping, Sequence
16+
from collections.abc import Iterable, Sequence
1717
from datetime import timezone
1818

1919
from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
2020
from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
2121
from narwhals._polars.typing import FrameT
2222
from narwhals._utils import Version, _LimitedContext
23-
from narwhals.schema import Schema
24-
from narwhals.typing import Into1DArray, IntoDType, TimeUnit, _2DArray
23+
from narwhals.typing import Into1DArray, IntoDType, IntoSchema, TimeUnit, _2DArray
2524

2625

2726
class PolarsNamespace:
@@ -94,17 +93,14 @@ def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries:
9493

9594
@overload
9695
def from_numpy(
97-
self,
98-
data: _2DArray,
99-
/,
100-
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
96+
self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
10197
) -> PolarsDataFrame: ...
10298

10399
def from_numpy(
104100
self,
105101
data: Into1DArray | _2DArray,
106102
/,
107-
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
103+
schema: IntoSchema | Sequence[str] | None = None,
108104
) -> PolarsDataFrame | PolarsSeries:
109105
if is_numpy_array_2d(data):
110106
return self._dataframe.from_numpy(data, schema=schema, context=self)

narwhals/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,13 +67,13 @@
6767
from narwhals._compliant import CompliantDataFrame, CompliantLazyFrame
6868
from narwhals._compliant.typing import CompliantExprAny, EagerNamespaceAny
6969
from narwhals._translate import IntoArrowTable
70-
from narwhals.dtypes import DType
7170
from narwhals.group_by import GroupBy, LazyGroupBy
7271
from narwhals.typing import (
7372
AsofJoinStrategy,
7473
IntoDataFrame,
7574
IntoExpr,
7675
IntoFrame,
76+
IntoSchema,
7777
JoinStrategy,
7878
LazyUniqueKeepStrategy,
7979
MultiColSelector as _MultiColSelector,
@@ -531,7 +531,7 @@ def from_arrow(
531531
def from_dict(
532532
cls,
533533
data: Mapping[str, Any],
534-
schema: Mapping[str, DType] | Schema | None = None,
534+
schema: IntoSchema | None = None,
535535
*,
536536
backend: ModuleType | Implementation | str | None = None,
537537
) -> DataFrame[Any]:
@@ -593,7 +593,7 @@ def from_dict(
593593
def from_numpy(
594594
cls,
595595
data: _2DArray,
596-
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
596+
schema: IntoSchema | Sequence[str] | None = None,
597597
*,
598598
backend: ModuleType | Implementation | str,
599599
) -> DataFrame[Any]:

narwhals/functions.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -44,13 +44,12 @@
4444
from narwhals._compliant import CompliantExpr, CompliantNamespace
4545
from narwhals._translate import IntoArrowTable
4646
from narwhals.dataframe import DataFrame, LazyFrame
47-
from narwhals.dtypes import DType
48-
from narwhals.schema import Schema
4947
from narwhals.typing import (
5048
ConcatMethod,
5149
FrameT,
5250
IntoDType,
5351
IntoExpr,
52+
IntoSchema,
5453
NativeFrame,
5554
NativeLazyFrame,
5655
NativeSeries,
@@ -59,7 +58,7 @@
5958
_2DArray,
6059
)
6160

62-
_IntoSchema: TypeAlias = "Mapping[str, DType] | Schema | Sequence[str] | None"
61+
_IntoSchema: TypeAlias = "IntoSchema | Sequence[str] | None"
6362

6463

6564
def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT:
@@ -245,7 +244,7 @@ def _new_series_impl(
245244
@deprecate_native_namespace(warn_version="1.26.0")
246245
def from_dict(
247246
data: Mapping[str, Any],
248-
schema: Mapping[str, DType] | Schema | None = None,
247+
schema: IntoSchema | None = None,
249248
*,
250249
backend: ModuleType | Implementation | str | None = None,
251250
native_namespace: ModuleType | None = None, # noqa: ARG001
@@ -330,7 +329,7 @@ def _from_dict_no_backend(
330329

331330
def from_numpy(
332331
data: _2DArray,
333-
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
332+
schema: IntoSchema | Sequence[str] | None = None,
334333
*,
335334
backend: ModuleType | Implementation | str,
336335
) -> DataFrame[Any]:
@@ -384,7 +383,7 @@ def from_numpy(
384383
if not _is_into_schema(schema):
385384
msg = (
386385
"`schema` is expected to be one of the following types: "
387-
"Mapping[str, DType] | Schema | Sequence[str]. "
386+
"IntoSchema | Sequence[str]. "
388387
f"Got {type(schema)}."
389388
)
390389
raise TypeError(msg)

narwhals/typing.py

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
if TYPE_CHECKING:
88
import datetime as dt
9-
from collections.abc import Iterable, Sequence, Sized
9+
from collections.abc import Iterable, Mapping, Sequence, Sized
1010
from decimal import Decimal
1111
from types import ModuleType
1212

@@ -16,6 +16,7 @@
1616
from narwhals import dtypes
1717
from narwhals.dataframe import DataFrame, LazyFrame
1818
from narwhals.expr import Expr
19+
from narwhals.schema import Schema
1920
from narwhals.series import Series
2021

2122
# All dataframes supported by Narwhals have a
@@ -390,6 +391,36 @@ def Binary(self) -> type[dtypes.Binary]: ...
390391
└──────────────────┘
391392
"""
392393

394+
# TODO @dangotbanned: fix this?
395+
# Constructor allows tuples, but we don't support that *everywhere* yet
396+
IntoSchema: TypeAlias = "Mapping[str, dtypes.DType] | Schema"
397+
"""Anything that can be converted into a Narwhals Schema.
398+
399+
Defined by column names and their associated *instantiated* Narwhals DType.
400+
401+
Examples:
402+
>>> import narwhals as nw
403+
>>> import pyarrow as pa
404+
>>> data = {"a": [1, 2, 3], "b": [None, "hi", "howdy"], "c": [2.1, 2.0, None]}
405+
>>> nw.DataFrame.from_dict(
406+
... data,
407+
... schema={"a": nw.UInt8(), "b": nw.String(), "c": nw.Float32()},
408+
... backend="pyarrow",
409+
... )
410+
┌────────────────────────┐
411+
| Narwhals DataFrame |
412+
|------------------------|
413+
|pyarrow.Table |
414+
|a: uint8 |
415+
|b: string |
416+
|c: float |
417+
|---- |
418+
|a: [[1,2,3]] |
419+
|b: [[null,"hi","howdy"]]|
420+
|c: [[2.1,2,null]] |
421+
└────────────────────────┘
422+
"""
423+
393424

394425
# Annotations for `__getitem__` methods
395426
_T = TypeVar("_T")

0 commit comments

Comments
 (0)