Skip to content

Commit c056b44

Browse files
authored
refactor: *(Namespace|DataFrame).from_numpy (#2283)
1 parent 36c6d57 commit c056b44

File tree

13 files changed

+364
-112
lines changed

13 files changed

+364
-112
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@
5353
from narwhals._arrow.typing import Mask # type: ignore[attr-defined]
5454
from narwhals._arrow.typing import Order # type: ignore[attr-defined]
5555
from narwhals.dtypes import DType
56+
from narwhals.schema import Schema
5657
from narwhals.typing import CompliantDataFrame
5758
from narwhals.typing import CompliantLazyFrame
5859
from narwhals.typing import SizeUnit
5960
from narwhals.typing import _1DArray
6061
from narwhals.typing import _2DArray
6162
from narwhals.utils import Version
63+
from narwhals.utils import _FullContext
6264

6365
JoinType: TypeAlias = Literal[
6466
"left semi",
@@ -91,6 +93,29 @@ def __init__(
9193
self._version = version
9294
validate_backend_version(self._implementation, self._backend_version)
9395

96+
@classmethod
def from_numpy(
    cls,
    data: _2DArray,
    /,
    *,
    context: _FullContext,
    schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self:
    """Construct a frame backed by a `pa.Table` from a 2D NumPy array.

    Arguments:
        data: Two-dimensional array; every row of `data.T` (i.e. every column
            of `data`) is converted into one pyarrow array.
        context: Object providing the `_backend_version` and `_version` that
            are forwarded to the constructor.
        schema: Either a mapping / `Schema`, which is converted with
            `Schema(schema).to_arrow()` and used as the table schema, or a
            sequence of names / `None`, which is resolved to column names by
            `cls._numpy_column_names`.

    Returns:
        A new instance wrapping the table, created with
        `validate_column_names=True`.
    """
    from narwhals.schema import Schema

    columns = [pa.array(column) for column in data.T]
    if not isinstance(schema, (Mapping, Schema)):
        # Only names (or nothing) were given: dtypes are left to pyarrow.
        names = cls._numpy_column_names(data, schema)
        table = pa.Table.from_arrays(columns, names)
    else:
        arrow_schema = Schema(schema).to_arrow()
        table = pa.Table.from_arrays(columns, schema=arrow_schema)
    return cls(
        table,
        backend_version=context._backend_version,
        version=context._version,
        validate_column_names=True,
    )
118+
94119
def __narwhals_namespace__(self: Self) -> ArrowNamespace:
95120
from narwhals._arrow.namespace import ArrowNamespace
96121

@@ -511,7 +536,7 @@ def to_polars(self: Self) -> pl.DataFrame:
511536

512537
return pl.from_arrow(self.native) # type: ignore[return-value]
513538

514-
def to_numpy(self: Self) -> _2DArray:
539+
def to_numpy(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
515540
import numpy as np # ignore-banned-import
516541

517542
arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])

narwhals/_arrow/namespace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838

3939

4040
class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]):
41+
@property
def _dataframe(self) -> type[ArrowDataFrame]:
    """Return the dataframe class used by this namespace (`ArrowDataFrame`)."""
    return ArrowDataFrame
44+
4145
@property
def _expr(self) -> type[ArrowExpr]:
    """Return the expression class used by this namespace (`ArrowExpr`)."""
    return ArrowExpr

narwhals/_compliant/dataframe.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from narwhals._compliant.typing import EagerSeriesT
1919
from narwhals._compliant.typing import NativeFrameT_co
2020
from narwhals._expression_parsing import evaluate_output_names_and_aliases
21+
from narwhals._translate import NumpyConvertible
2122
from narwhals.utils import Version
2223
from narwhals.utils import _StoresNative
2324
from narwhals.utils import deprecated
@@ -34,9 +35,11 @@
3435

3536
from narwhals._compliant.group_by import CompliantGroupBy
3637
from narwhals.dtypes import DType
38+
from narwhals.schema import Schema
3739
from narwhals.typing import SizeUnit
3840
from narwhals.typing import _2DArray
3941
from narwhals.utils import Implementation
42+
from narwhals.utils import _FullContext
4043

4144
Incomplete: TypeAlias = Any
4245

@@ -46,6 +49,7 @@
4649

4750

4851
class CompliantDataFrame(
52+
NumpyConvertible["_2DArray", "_2DArray"],
4953
_StoresNative[NativeFrameT_co],
5054
Sized,
5155
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT_co],
@@ -57,6 +61,15 @@ class CompliantDataFrame(
5761

5862
def __narwhals_dataframe__(self) -> Self: ...
5963
def __narwhals_namespace__(self) -> Any: ...
64+
@classmethod
def from_numpy(
    cls,
    data: _2DArray,
    /,
    *,
    context: _FullContext,
    schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self:
    """Alternate constructor: build a compliant dataframe from a 2D NumPy array.

    Protocol declaration only; implementers supply the body.

    Arguments:
        data: Two-dimensional array (positional-only).
        context: Keyword-only object describing the calling context.
        schema: Keyword-only; a name-to-dtype mapping, a `Schema`, a sequence
            of column names, or `None`.
    """
    ...
6073
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
6174
def __getitem__(self, item: Any) -> CompliantSeriesT | Self: ...
6275
def simple_select(self, *column_names: str) -> Self:
@@ -143,7 +156,6 @@ def sort(
143156
) -> Self: ...
144157
def tail(self, n: int) -> Self: ...
145158
def to_arrow(self) -> pa.Table: ...
146-
def to_numpy(self) -> _2DArray: ...
147159
def to_pandas(self) -> pd.DataFrame: ...
148160
def to_polars(self) -> pl.DataFrame: ...
149161
@overload
@@ -286,7 +298,8 @@ def _evaluate_expr(self, expr: EagerExprT_contra, /) -> EagerSeriesT:
286298
return result[0]
287299

288300
def _evaluate_into_exprs(self, *exprs: EagerExprT_contra) -> Sequence[EagerSeriesT]:
289-
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))
301+
# NOTE: Ignore is to avoid an intermittent false positive
302+
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
290303

291304
def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSeriesT]:
292305
"""Return list of raw columns.
@@ -308,3 +321,9 @@ def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSerie
308321
def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
309322
"""Extract native Series, broadcasting to `len(self)` if necessary."""
310323
...
324+
325+
@staticmethod
def _numpy_column_names(
    data: _2DArray, columns: Sequence[str] | None, /
) -> list[str]:
    """Resolve the column names to use for a 2D NumPy array.

    Arguments:
        data: Two-dimensional array; only `data.shape[1]` (the number of
            columns) is read.
        columns: Explicit column names, or `None`.

    Returns:
        `list(columns)` when `columns` is given and non-empty; otherwise the
        generated names `column_0`, `column_1`, ... - one per column of `data`.
    """
    if columns is not None:
        # Materialise before testing emptiness: evaluating the truth value of
        # an array-like of names (e.g. a NumPy array or pandas Index) raises
        # `ValueError`, whereas a plain list is always safe to test.
        names = list(columns)
        if names:
            return names
    return [f"column_{x}" for x in range(data.shape[1])]

narwhals/_compliant/namespace.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
from typing import Container
77
from typing import Iterable
88
from typing import Literal
9+
from typing import Mapping
910
from typing import Protocol
11+
from typing import Sequence
12+
from typing import overload
1013

1114
from narwhals._compliant.typing import CompliantExprT
1215
from narwhals._compliant.typing import CompliantFrameT
1316
from narwhals._compliant.typing import DepthTrackingExprT
1417
from narwhals._compliant.typing import EagerDataFrameT
1518
from narwhals._compliant.typing import EagerExprT
1619
from narwhals._compliant.typing import EagerSeriesT
20+
from narwhals.dependencies import is_numpy_array_2d
1721
from narwhals.utils import exclude_column_names
1822
from narwhals.utils import get_column_names
1923
from narwhals.utils import passthrough_column_names
@@ -25,6 +29,9 @@
2529
from narwhals._compliant.when_then import CompliantWhen
2630
from narwhals._compliant.when_then import EagerWhen
2731
from narwhals.dtypes import DType
32+
from narwhals.schema import Schema
33+
from narwhals.typing import Into1DArray
34+
from narwhals.typing import _2DArray
2835
from narwhals.utils import Implementation
2936
from narwhals.utils import Version
3037

@@ -109,8 +116,36 @@ class EagerNamespace(
109116
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
110117
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT],
111118
):
119+
@property
120+
def _dataframe(self) -> type[EagerDataFrameT]: ...
112121
@property
113122
def _series(self) -> type[EagerSeriesT]: ...
114123
def when(
115124
self, predicate: EagerExprT
116125
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, Incomplete]: ...
126+
127+
@overload
def from_numpy(
    self,
    data: Into1DArray,
    /,
    schema: None = ...,
) -> EagerSeriesT: ...

@overload
def from_numpy(
    self,
    data: _2DArray,
    /,
    schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> EagerDataFrameT: ...

def from_numpy(
    self,
    data: Into1DArray | _2DArray,
    /,
    schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
) -> EagerDataFrameT | EagerSeriesT:
    """Build a series or a dataframe from a NumPy array, depending on its rank.

    Arguments:
        data: Input array (positional-only).
        schema: Forwarded to the dataframe constructor when `data` is 2D; it is
            not passed along on the series path.

    Returns:
        `self._dataframe.from_numpy(...)` when `is_numpy_array_2d(data)` holds,
        otherwise `self._series.from_numpy(...)`. In both cases this namespace
        is passed as `context`.
    """
    if not is_numpy_array_2d(data):
        return self._series.from_numpy(data, context=self)
    return self._dataframe.from_numpy(data, schema=schema, context=self)

narwhals/_pandas_like/dataframe.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Callable
6+
from typing import Iterable
57
from typing import Iterator
68
from typing import Literal
79
from typing import Mapping
@@ -17,6 +19,7 @@
1719
from narwhals._pandas_like.utils import align_series_full_broadcast
1820
from narwhals._pandas_like.utils import check_column_names_are_unique
1921
from narwhals._pandas_like.utils import convert_str_slice_to_int_slice
22+
from narwhals._pandas_like.utils import get_dtype_backend
2023
from narwhals._pandas_like.utils import horizontal_concat
2124
from narwhals._pandas_like.utils import native_to_narwhals_dtype
2225
from narwhals._pandas_like.utils import object_native_to_narwhals_dtype
@@ -46,17 +49,23 @@
4649
import pandas as pd
4750
import polars as pl
4851
from typing_extensions import Self
52+
from typing_extensions import TypeAlias
4953

5054
from narwhals._pandas_like.expr import PandasLikeExpr
5155
from narwhals._pandas_like.group_by import PandasLikeGroupBy
5256
from narwhals._pandas_like.namespace import PandasLikeNamespace
5357
from narwhals.dtypes import DType
58+
from narwhals.schema import Schema
5459
from narwhals.typing import CompliantDataFrame
5560
from narwhals.typing import CompliantLazyFrame
61+
from narwhals.typing import DTypeBackend
5662
from narwhals.typing import SizeUnit
5763
from narwhals.typing import _1DArray
5864
from narwhals.typing import _2DArray
5965
from narwhals.utils import Version
66+
from narwhals.utils import _FullContext
67+
68+
Constructor: TypeAlias = Callable[..., pd.DataFrame]
6069

6170

6271
CLASSICAL_NUMPY_DTYPES: frozenset[np.dtype[Any]] = frozenset(
@@ -104,6 +113,37 @@ def __init__(
104113
if validate_column_names:
105114
check_column_names_are_unique(native_dataframe.columns)
106115

116+
@classmethod
def from_numpy(
    cls,
    data: _2DArray,
    /,
    *,
    context: _FullContext,
    schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self:
    """Construct a pandas-like frame from a 2D NumPy array.

    Arguments:
        data: Two-dimensional array passed straight to the native `DataFrame`
            constructor.
        context: Object providing `_implementation`, `_backend_version` and
            `_version`.
        schema: Either a mapping / `Schema` - its keys become the column names
            and the frame is cast with `Schema(schema).to_pandas(...)` - or a
            sequence of names / `None`, which is resolved to column names by
            `cls._numpy_column_names` with no cast applied.

    Returns:
        A new instance wrapping the native frame, created with
        `validate_column_names=True`.
    """
    from narwhals.schema import Schema

    implementation = context._implementation
    # `DataFrame` attribute of whatever native namespace the implementation resolves to.
    make_frame: Constructor = implementation.to_native_namespace().DataFrame
    if not isinstance(schema, (Mapping, Schema)):
        column_names = cls._numpy_column_names(data, schema)
        native = make_frame(data, columns=column_names)
    else:
        # Lazily computed dtype backend for each schema value; consumed by `to_pandas`.
        backends: Iterable[DTypeBackend] = (
            get_dtype_backend(native_type, implementation)
            for native_type in schema.values()
        )
        untyped = make_frame(data, columns=schema.keys())
        native = untyped.astype(Schema(schema).to_pandas(backends))
    return cls(
        native,
        implementation=implementation,
        backend_version=context._backend_version,
        version=context._version,
        validate_column_names=True,
    )
146+
107147
def __narwhals_dataframe__(self: Self) -> Self:
108148
return self
109149

narwhals/_pandas_like/namespace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
class PandasLikeNamespace(
3535
EagerNamespace[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr]
3636
):
37+
@property
def _dataframe(self) -> type[PandasLikeDataFrame]:
    """Return the dataframe class used by this namespace (`PandasLikeDataFrame`)."""
    return PandasLikeDataFrame
40+
3741
@property
def _expr(self) -> type[PandasLikeExpr]:
    """Return the expression class used by this namespace (`PandasLikeExpr`)."""
    return PandasLikeExpr

narwhals/_polars/dataframe.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@
3636
from narwhals._polars.group_by import PolarsLazyGroupBy
3737
from narwhals._polars.series import PolarsSeries
3838
from narwhals.dtypes import DType
39+
from narwhals.schema import Schema
3940
from narwhals.typing import CompliantDataFrame
4041
from narwhals.typing import CompliantLazyFrame
4142
from narwhals.typing import _2DArray
4243
from narwhals.utils import Version
44+
from narwhals.utils import _FullContext
4345

4446
T = TypeVar("T")
4547
R = TypeVar("R")
@@ -92,6 +94,27 @@ def __init__(
9294
self._version = version
9395
validate_backend_version(self._implementation, self._backend_version)
9496

97+
@classmethod
def from_numpy(
    cls,
    data: _2DArray,
    /,
    *,
    context: _FullContext,  # NOTE: Maybe only `Implementation`?
    schema: Mapping[str, DType] | Schema | Sequence[str] | None,
) -> Self:
    """Construct a frame from a 2D NumPy array via `pl.from_numpy`.

    Arguments:
        data: Two-dimensional array handed to `pl.from_numpy`.
        context: Object providing the `_backend_version` and `_version` that
            are forwarded to the constructor.
        schema: A mapping / `Schema` is first converted with
            `Schema(schema).to_polars()`; a sequence of names or `None` is
            passed to `pl.from_numpy` unchanged.

    Returns:
        A new instance wrapping the resulting native frame.
    """
    from narwhals.schema import Schema

    if isinstance(schema, (Mapping, Schema)):
        pl_schema = Schema(schema).to_polars()
    else:
        pl_schema = schema
    native = pl.from_numpy(data, pl_schema)
    return cls(
        native,
        backend_version=context._backend_version,
        version=context._version,
    )
117+
95118
@property
def native(self) -> pl.DataFrame:
    """Return the object stored in `self._native_frame`."""
    return self._native_frame

0 commit comments

Comments
 (0)