Skip to content

Commit 7f9fb6b

Browse files
committed
feat: Implement EagerNamespace.from_numpy
Addresses #2196 (comment)
1 parent 91bd274 commit 7f9fb6b

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

narwhals/_compliant/namespace.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
from typing import Container
77
from typing import Iterable
88
from typing import Literal
9+
from typing import Mapping
910
from typing import Protocol
11+
from typing import Sequence
12+
from typing import overload
1013

1114
from narwhals._compliant.typing import CompliantExprT
1215
from narwhals._compliant.typing import CompliantFrameT
1316
from narwhals._compliant.typing import DepthTrackingExprT
1417
from narwhals._compliant.typing import EagerDataFrameT
1518
from narwhals._compliant.typing import EagerExprT
1619
from narwhals._compliant.typing import EagerSeriesT
20+
from narwhals.dependencies import is_numpy_array_2d
1721
from narwhals.utils import exclude_column_names
1822
from narwhals.utils import get_column_names
1923
from narwhals.utils import passthrough_column_names
@@ -25,6 +29,9 @@
2529
from narwhals._compliant.when_then import CompliantWhen
2630
from narwhals._compliant.when_then import EagerWhen
2731
from narwhals.dtypes import DType
32+
from narwhals.schema import Schema
33+
from narwhals.typing import Into1DArray
34+
from narwhals.typing import _2DArray
2835
from narwhals.utils import Implementation
2936
from narwhals.utils import Version
3037

@@ -116,3 +123,29 @@ def _series(self) -> type[EagerSeriesT]: ...
116123
def when(
117124
self, predicate: EagerExprT
118125
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, Incomplete]: ...
126+
127+
@overload
128+
def from_numpy(
129+
self,
130+
data: Into1DArray,
131+
/,
132+
schema: None = ...,
133+
) -> EagerSeriesT: ...
134+
135+
@overload
136+
def from_numpy(
137+
self,
138+
data: _2DArray,
139+
/,
140+
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
141+
) -> EagerDataFrameT: ...
142+
143+
def from_numpy(
144+
self,
145+
data: Into1DArray | _2DArray,
146+
/,
147+
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
148+
) -> EagerDataFrameT | EagerSeriesT:
149+
if is_numpy_array_2d(data):
150+
return self._dataframe.from_numpy(data, schema=schema, context=self)
151+
return self._series.from_numpy(data, context=self)

0 commit comments

Comments
 (0)