Skip to content

Commit c14f43b

Browse files
authored
refactor: Add CompliantSeries.from_numpy (#2196)
1 parent 1b8d548 commit c14f43b

File tree

18 files changed

+161
-144
lines changed

18 files changed

+161
-144
lines changed

narwhals/_arrow/expr.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from narwhals._expression_parsing import ExprKind
1414
from narwhals._expression_parsing import evaluate_output_names_and_aliases
1515
from narwhals._expression_parsing import is_scalar_like
16-
from narwhals.dependencies import get_numpy
17-
from narwhals.dependencies import is_numpy_array
1816
from narwhals.exceptions import ColumnNotFoundError
1917
from narwhals.utils import Implementation
2018
from narwhals.utils import generate_temporary_column_name
@@ -25,7 +23,6 @@
2523

2624
from narwhals._arrow.dataframe import ArrowDataFrame
2725
from narwhals._arrow.namespace import ArrowNamespace
28-
from narwhals.dtypes import DType
2926
from narwhals.utils import Version
3027
from narwhals.utils import _FullContext
3128

@@ -203,44 +200,6 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
203200
version=self._version,
204201
)
205202

206-
def map_batches(
207-
self: Self,
208-
function: Callable[[Any], Any],
209-
return_dtype: DType | type[DType] | None,
210-
) -> Self:
211-
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
212-
input_series_list = self._call(df)
213-
output_names = [input_series.name for input_series in input_series_list]
214-
result = [function(series) for series in input_series_list]
215-
216-
if is_numpy_array(result[0]):
217-
result = [
218-
df.__narwhals_namespace__()
219-
._create_compliant_series(array)
220-
.alias(output_name)
221-
for array, output_name in zip(result, output_names)
222-
]
223-
elif (np := get_numpy()) is not None and np.isscalar(result[0]):
224-
result = [
225-
df.__narwhals_namespace__()
226-
._create_compliant_series([array])
227-
.alias(output_name)
228-
for array, output_name in zip(result, output_names)
229-
]
230-
if return_dtype is not None:
231-
result = [series.cast(return_dtype) for series in result]
232-
return result
233-
234-
return self.__class__(
235-
func,
236-
depth=self._depth + 1,
237-
function_name=self._function_name + "->map_batches",
238-
evaluate_output_names=self._evaluate_output_names,
239-
alias_output_names=self._alias_output_names,
240-
backend_version=self._backend_version,
241-
version=self._version,
242-
)
243-
244203
def cum_count(self: Self, *, reverse: bool) -> Self:
245204
return self._reuse_series("cum_count", reverse=reverse)
246205

narwhals/_arrow/namespace.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def _expr(self) -> type[ArrowExpr]:
5151
def _series(self) -> type[ArrowSeries]:
5252
return ArrowSeries
5353

54-
def _create_compliant_series(self: Self, value: Any) -> ArrowSeries:
55-
return self._series._from_iterable(value, name="", context=self)
56-
5754
# --- not in spec ---
5855
def __init__(
5956
self: Self, *, backend_version: tuple[int, ...], version: Version

narwhals/_arrow/series.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from narwhals._arrow.utils import nulls_like
2828
from narwhals._arrow.utils import pad_series
2929
from narwhals._compliant import EagerSeries
30+
from narwhals.dependencies import is_numpy_array_1d
3031
from narwhals.exceptions import InvalidOperationError
3132
from narwhals.utils import Implementation
3233
from narwhals.utils import generate_temporary_column_name
@@ -52,6 +53,7 @@
5253
from narwhals._arrow.typing import _AsPyType
5354
from narwhals._arrow.typing import _BasicDataType
5455
from narwhals.dtypes import DType
56+
from narwhals.typing import Into1DArray
5557
from narwhals.typing import _1DArray
5658
from narwhals.typing import _2DArray
5759
from narwhals.utils import Version
@@ -156,6 +158,12 @@ def _from_scalar(self, value: Any) -> Self:
156158
value = value.as_py()
157159
return super()._from_scalar(value)
158160

161+
@classmethod
162+
def from_numpy(cls, data: Into1DArray, /, *, context: _FullContext) -> Self:
163+
return cls._from_iterable(
164+
data if is_numpy_array_1d(data) else [data], name="", context=context
165+
)
166+
159167
def __narwhals_namespace__(self: Self) -> ArrowNamespace:
160168
from narwhals._arrow.namespace import ArrowNamespace
161169

@@ -437,7 +445,7 @@ def to_list(self: Self) -> list[Any]:
437445
def __array__(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
438446
return self.native.__array__(dtype=dtype, copy=copy)
439447

440-
def to_numpy(self: Self) -> _1DArray:
448+
def to_numpy(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
441449
return self.native.to_numpy()
442450

443451
def alias(self: Self, name: str) -> Self:

narwhals/_compliant/expr.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from narwhals._compliant.typing import EagerSeriesT
3030
from narwhals._compliant.typing import NativeExprT_co
3131
from narwhals._expression_parsing import evaluate_output_names_and_aliases
32+
from narwhals.dependencies import get_numpy
33+
from narwhals.dependencies import is_numpy_array
3234
from narwhals.dtypes import DType
3335
from narwhals.utils import _ExprNamespace
3436
from narwhals.utils import deprecated
@@ -760,6 +762,38 @@ def rolling_var(
760762
ddof=ddof,
761763
)
762764

765+
def map_batches(
766+
self: Self,
767+
function: Callable[[Any], Any],
768+
return_dtype: DType | type[DType] | None,
769+
) -> Self:
770+
def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT]:
771+
input_series_list = self(df)
772+
output_names = [input_series.name for input_series in input_series_list]
773+
result = [function(series) for series in input_series_list]
774+
if is_numpy_array(result[0]) or (
775+
(np := get_numpy()) is not None and np.isscalar(result[0])
776+
):
777+
from_numpy = partial(
778+
self.__narwhals_namespace__()._series.from_numpy, context=self
779+
)
780+
result = [
781+
from_numpy(array).alias(output_name)
782+
for array, output_name in zip(result, output_names)
783+
]
784+
if return_dtype is not None:
785+
result = [series.cast(return_dtype) for series in result]
786+
return result
787+
788+
return self._from_callable(
789+
func,
790+
depth=self._depth + 1,
791+
function_name=self._function_name + "->map_batches",
792+
evaluate_output_names=self._evaluate_output_names,
793+
alias_output_names=self._alias_output_names,
794+
context=self,
795+
)
796+
763797
@property
764798
def cat(self) -> EagerExprCatNamespace[Self]:
765799
return EagerExprCatNamespace(self)

narwhals/_compliant/namespace.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from narwhals._compliant.typing import EagerDataFrameT
1414
from narwhals._compliant.typing import EagerExprT
1515
from narwhals._compliant.typing import EagerSeriesT_co
16-
from narwhals.utils import deprecated
1716
from narwhals.utils import exclude_column_names
1817
from narwhals.utils import get_column_names
1918
from narwhals.utils import passthrough_column_names
@@ -85,13 +84,3 @@ class EagerNamespace(
8584
):
8685
@property
8786
def _series(self) -> type[EagerSeriesT_co]: ...
88-
89-
@deprecated(
90-
"Internally used for `numpy.ndarray` -> `CompliantSeries`\n"
91-
"Also referenced in untyped `nw.dataframe.DataFrame._extract_compliant`\n"
92-
"See Also:\n"
93-
" - https://github.com/narwhals-dev/narwhals/pull/2149#discussion_r1986283345\n"
94-
" - https://github.com/narwhals-dev/narwhals/issues/2116\n"
95-
" - https://github.com/narwhals-dev/narwhals/pull/2169"
96-
)
97-
def _create_compliant_series(self, value: Any) -> EagerSeriesT_co: ...

narwhals/_compliant/series.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from typing import Protocol
77
from typing import TypeVar
88

9+
from narwhals._translate import NumpyConvertible
10+
911
if TYPE_CHECKING:
1012
from typing_extensions import Self
1113

@@ -14,7 +16,9 @@
1416
from narwhals._compliant.namespace import CompliantNamespace # noqa: F401
1517
from narwhals._compliant.namespace import EagerNamespace
1618
from narwhals.dtypes import DType
19+
from narwhals.typing import Into1DArray
1720
from narwhals.typing import NativeSeries
21+
from narwhals.typing import _1DArray # noqa: F401
1822
from narwhals.utils import Implementation
1923
from narwhals.utils import Version
2024
from narwhals.utils import _FullContext
@@ -24,7 +28,7 @@
2428
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
2529

2630

27-
class CompliantSeries(Protocol):
31+
class CompliantSeries(NumpyConvertible["_1DArray", "Into1DArray"], Protocol):
2832
@property
2933
def dtype(self) -> DType: ...
3034
@property
@@ -36,6 +40,8 @@ def alias(self, name: str) -> Self: ...
3640
def __narwhals_namespace__(self) -> Any: ... # CompliantNamespace[Any, Self]: ...
3741
def _from_native_series(self, series: Any) -> Self: ...
3842
def _to_expr(self) -> Any: ... # CompliantExpr[Any, Self]: ...
43+
@classmethod
44+
def from_numpy(cls, data: Into1DArray, /, *, context: _FullContext) -> Self: ...
3945

4046

4147
class EagerSeries(CompliantSeries, Protocol[NativeSeriesT_co]):
@@ -60,3 +66,5 @@ def __narwhals_namespace__(self) -> EagerNamespace[Any, Self, Any]: ...
6066

6167
def _to_expr(self) -> EagerExpr[Any, Any]:
6268
return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return]
69+
70+
def cast(self, dtype: DType | type[DType]) -> Self: ...

narwhals/_compliant/typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from narwhals._compliant.expr import CompliantExpr
1616
from narwhals._compliant.expr import EagerExpr
1717
from narwhals._compliant.expr import NativeExpr
18+
from narwhals._compliant.namespace import EagerNamespace
1819
from narwhals._compliant.series import CompliantSeries
1920
from narwhals._compliant.series import EagerSeries
2021

@@ -48,5 +49,9 @@
4849
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")
4950
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound="EagerSeries[Any]", covariant=True)
5051
EagerExprT = TypeVar("EagerExprT", bound="EagerExpr[Any, Any]")
52+
EagerNamespaceAny: TypeAlias = (
53+
"EagerNamespace[EagerDataFrame[Any], EagerSeries[Any], EagerExpr[Any, Any]]"
54+
)
55+
5156
AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
5257
AliasName: TypeAlias = Callable[[str], str]

narwhals/_expression_parsing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Literal
1313
from typing import Sequence
1414
from typing import TypeVar
15+
from typing import cast
1516

1617
from narwhals.dependencies import is_narwhals_series
1718
from narwhals.dependencies import is_numpy_array
@@ -27,6 +28,7 @@
2728
from narwhals._compliant import CompliantExprT
2829
from narwhals._compliant import CompliantFrameT
2930
from narwhals._compliant import CompliantNamespace
31+
from narwhals._compliant.typing import EagerNamespaceAny
3032
from narwhals.expr import Expr
3133
from narwhals.typing import CompliantDataFrame
3234
from narwhals.typing import CompliantLazyFrame
@@ -100,7 +102,8 @@ def extract_compliant(
100102
if is_narwhals_series(other):
101103
return other._compliant_series._to_expr()
102104
if is_numpy_array(other):
103-
return plx._create_compliant_series(other)._to_expr() # type: ignore[attr-defined]
105+
ns = cast("EagerNamespaceAny", plx)
106+
return ns._series.from_numpy(other, context=ns)._to_expr()
104107
return other
105108

106109

narwhals/_pandas_like/dataframe.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from narwhals._pandas_like.utils import align_series_full_broadcast
1717
from narwhals._pandas_like.utils import check_column_names_are_unique
1818
from narwhals._pandas_like.utils import convert_str_slice_to_int_slice
19-
from narwhals._pandas_like.utils import create_compliant_series
2019
from narwhals._pandas_like.utils import extract_dataframe_comparand
2120
from narwhals._pandas_like.utils import horizontal_concat
2221
from narwhals._pandas_like.utils import native_to_narwhals_dtype
@@ -433,16 +432,14 @@ def estimated_size(self: Self, unit: SizeUnit) -> int | float:
433432
return scale_bytes(sz, unit=unit)
434433

435434
def with_row_index(self: Self, name: str) -> Self:
436-
row_index = create_compliant_series(
437-
range(len(self._native_frame)),
438-
index=self._native_frame.index,
439-
implementation=self._implementation,
440-
backend_version=self._backend_version,
441-
version=self._version,
435+
frame = self._native_frame
436+
namespace = self.__narwhals_namespace__()
437+
row_index = namespace._series._from_iterable(
438+
range(len(frame)), name="", context=self, index=frame.index
442439
).alias(name)
443440
return self._from_native_frame(
444441
horizontal_concat(
445-
[row_index._native_series, self._native_frame],
442+
[row_index.native, frame],
446443
implementation=self._implementation,
447444
backend_version=self._backend_version,
448445
)

narwhals/_pandas_like/expr.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from narwhals._expression_parsing import is_elementary_expression
1414
from narwhals._pandas_like.group_by import AGGREGATIONS_TO_PANDAS_EQUIVALENT
1515
from narwhals._pandas_like.series import PandasLikeSeries
16-
from narwhals.dependencies import get_numpy
17-
from narwhals.dependencies import is_numpy_array
1816
from narwhals.exceptions import ColumnNotFoundError
1917
from narwhals.utils import generate_temporary_column_name
2018

@@ -23,7 +21,6 @@
2321

2422
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
2523
from narwhals._pandas_like.namespace import PandasLikeNamespace
26-
from narwhals.dtypes import DType
2724
from narwhals.utils import Implementation
2825
from narwhals.utils import Version
2926
from narwhals.utils import _FullContext
@@ -299,39 +296,6 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
299296
version=self._version,
300297
)
301298

302-
def map_batches(
303-
self: Self,
304-
function: Callable[[Any], Any],
305-
return_dtype: DType | type[DType] | None,
306-
) -> Self:
307-
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
308-
input_series_list = self._call(df)
309-
output_names = [input_series.name for input_series in input_series_list]
310-
result = [function(series) for series in input_series_list]
311-
if is_numpy_array(result[0]) or (
312-
(np := get_numpy()) is not None and np.isscalar(result[0])
313-
):
314-
result = [
315-
df.__narwhals_namespace__()
316-
._create_compliant_series(array)
317-
.alias(output_name)
318-
for array, output_name in zip(result, output_names)
319-
]
320-
if return_dtype is not None:
321-
result = [series.cast(return_dtype) for series in result]
322-
return result
323-
324-
return self.__class__(
325-
func,
326-
depth=self._depth + 1,
327-
function_name=self._function_name + "->map_batches",
328-
evaluate_output_names=self._evaluate_output_names,
329-
alias_output_names=self._alias_output_names,
330-
implementation=self._implementation,
331-
backend_version=self._backend_version,
332-
version=self._version,
333-
)
334-
335299
def cum_count(self: Self, *, reverse: bool) -> Self:
336300
return self._reuse_series("cum_count", call_kwargs={"reverse": reverse})
337301

0 commit comments

Comments
 (0)