Skip to content

Commit 78a5511

Browse files
chore: Add Compliant*.from_native (#2315)
* feat(typing): Add `FromNative` protocol * chore(typing): Add `FromNative` to `CompliantSeries` - Adding `._is_native` made `TypeVar` invariant - Realistically, it always was, but underspecified * feat: Implement for `(Arrow|PandasLike)Series` * feat: Implement for `PolarsSeries` + get some coverage * chore: `ArrowSeries` coverage * chore: `PandasLikeSeries` partial coverage * ignore coverage for now ... * feat(typing): Add `CompliantDataFrame.from_native` * feat: Implement for `ArrowDataFrame` Also coverage for `ArrowSeries` * feat: Implement for `PandasLikeDataFrame` Loads of coverage for both `PandasLike` * feat: Implement for `PolarsDataFrame` * refactor: Found one more * chore(typing): Fix missing `SQLExpression` ignore Was missed in (#2310) * feat: Implement `EagerNamespace.from_native` - `Polars*` will also need to handle `LazyFrame` - `Lazy*` has other constraints * feat: Add `Polars(Namespace|LazyFrame).from_native` Probably need to add a `LazyNamespace` protocol for `LazyOnly` * chore: Ignore coverage `PandasLikeDataFrame._is_native` Nowhere to use it yet, current stuff uses the more precise `self.native.__class__` * feat: Add all `CompliantLazyFrame.from_native` * feat: Add `LazyNamespace.from_native` * refactor: Get some lazy coverage https://github.com/narwhals-dev/narwhals/actions/runs/14157697512/job/39659084059?pr=2315 * refactor: More `polars` coverage https://github.com/narwhals-dev/narwhals/actions/runs/14158424987/job/39660662342?pr=2315 * refactor: reuse `is_spark_like_dataframe` * Update narwhals/_compliant/namespace.py Co-authored-by: Dan Redding <[email protected]> --------- Co-authored-by: Marco Edward Gorelli <[email protected]>
1 parent f932965 commit 78a5511

File tree

25 files changed

+414
-424
lines changed

25 files changed

+414
-424
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 27 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@
4545
import polars as pl
4646
from typing_extensions import Self
4747
from typing_extensions import TypeAlias
48+
from typing_extensions import TypeIs
4849

4950
from narwhals._arrow.expr import ArrowExpr
5051
from narwhals._arrow.group_by import ArrowGroupBy
5152
from narwhals._arrow.namespace import ArrowNamespace
52-
from narwhals._arrow.series import ArrowSeries
5353
from narwhals._arrow.typing import ArrowChunkedArray
5454
from narwhals._arrow.typing import Indices # type: ignore[attr-defined]
5555
from narwhals._arrow.typing import Mask # type: ignore[attr-defined]
@@ -99,7 +99,7 @@ def __init__(
9999
@classmethod
100100
def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
101101
backend_version = context._backend_version
102-
if isinstance(data, pa.Table):
102+
if cls._is_native(data):
103103
native = data
104104
elif backend_version >= (14,) or isinstance(data, Collection):
105105
native = pa.table(data)
@@ -109,12 +109,7 @@ def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
109109
else: # pragma: no cover
110110
msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
111111
raise TypeError(msg)
112-
return cls(
113-
native,
114-
backend_version=backend_version,
115-
version=context._version,
116-
validate_column_names=True,
117-
)
112+
return cls.from_native(native, context=context)
118113

119114
@classmethod
120115
def from_dict(
@@ -129,8 +124,16 @@ def from_dict(
129124

130125
pa_schema = Schema(schema).to_arrow() if schema is not None else schema
131126
native = pa.Table.from_pydict(data, schema=pa_schema)
127+
return cls.from_native(native, context=context)
128+
129+
@staticmethod
130+
def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
131+
return isinstance(obj, pa.Table)
132+
133+
@classmethod
134+
def from_native(cls, data: pa.Table, /, *, context: _FullContext) -> Self:
132135
return cls(
133-
native,
136+
data,
134137
backend_version=context._backend_version,
135138
version=context._version,
136139
validate_column_names=True,
@@ -152,12 +155,7 @@ def from_numpy(
152155
native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
153156
else:
154157
native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
155-
return cls(
156-
native,
157-
backend_version=context._backend_version,
158-
version=context._version,
159-
validate_column_names=True,
160-
)
158+
return cls.from_native(native, context=context)
161159

162160
def __narwhals_namespace__(self: Self) -> ArrowNamespace:
163161
from narwhals._arrow.namespace import ArrowNamespace
@@ -224,15 +222,8 @@ def rows(self: Self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, A
224222
return self.native.to_pylist()
225223

226224
def iter_columns(self) -> Iterator[ArrowSeries]:
227-
from narwhals._arrow.series import ArrowSeries
228-
229225
for name, series in zip(self.columns, self.native.itercolumns()):
230-
yield ArrowSeries(
231-
series,
232-
name=name,
233-
backend_version=self._backend_version,
234-
version=self._version,
235-
)
226+
yield ArrowSeries.from_native(series, context=self, name=name)
236227

237228
_iter_columns = iter_columns
238229

@@ -251,18 +242,10 @@ def iter_rows(
251242
yield from df[i : i + buffer_size].to_pylist()
252243

253244
def get_column(self: Self, name: str) -> ArrowSeries:
254-
from narwhals._arrow.series import ArrowSeries
255-
256245
if not isinstance(name, str):
257246
msg = f"Expected str, got: {type(name)}"
258247
raise TypeError(msg)
259-
260-
return ArrowSeries(
261-
self.native[name],
262-
name=name,
263-
backend_version=self._backend_version,
264-
version=self._version,
265-
)
248+
return ArrowSeries.from_native(self.native[name], context=self, name=name)
266249

267250
def __array__(self: Self, dtype: Any, *, copy: bool | None) -> _2DArray:
268251
return self.native.__array__(dtype, copy=copy)
@@ -304,14 +287,7 @@ def __getitem__(
304287
item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) # pyright: ignore[reportAssignmentType]
305288

306289
if isinstance(item, str):
307-
from narwhals._arrow.series import ArrowSeries
308-
309-
return ArrowSeries(
310-
self.native[item],
311-
name=item,
312-
backend_version=self._backend_version,
313-
version=self._version,
314-
)
290+
return ArrowSeries.from_native(self.native[item], context=self, name=item)
315291
elif (
316292
isinstance(item, tuple)
317293
and len(item) == 2
@@ -345,7 +321,6 @@ def __getitem__(
345321
)
346322
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
347323
raise TypeError(msg) # pragma: no cover
348-
from narwhals._arrow.series import ArrowSeries
349324

350325
# PyArrow columns are always strings
351326
col_name = (
@@ -357,18 +332,12 @@ def __getitem__(
357332
msg = "Can not slice with tuple with the first element as a str"
358333
raise TypeError(msg)
359334
if (isinstance(item[0], slice)) and (item[0] == slice(None)):
360-
return ArrowSeries(
361-
self.native[col_name],
362-
name=col_name,
363-
backend_version=self._backend_version,
364-
version=self._version,
335+
return ArrowSeries.from_native(
336+
self.native[col_name], context=self, name=col_name
365337
)
366338
selected_rows = select_rows(self.native, item[0])
367-
return ArrowSeries(
368-
selected_rows[col_name],
369-
name=col_name,
370-
backend_version=self._backend_version,
371-
version=self._version,
339+
return ArrowSeries.from_native(
340+
selected_rows[col_name], context=self, name=col_name
372341
)
373342

374343
elif isinstance(item, slice):
@@ -589,18 +558,10 @@ def to_dict(
589558
self: Self, *, as_series: bool
590559
) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
591560
df = self.native
592-
593561
names_and_values = zip(df.column_names, df.columns)
594562
if as_series:
595-
from narwhals._arrow.series import ArrowSeries
596-
597563
return {
598-
name: ArrowSeries(
599-
col,
600-
name=name,
601-
backend_version=self._backend_version,
602-
version=self._version,
603-
)
564+
name: ArrowSeries.from_native(col, context=self, name=name)
604565
for name, col in names_and_values
605566
}
606567
else:
@@ -778,26 +739,20 @@ def write_csv(self: Self, file: str | Path | BytesIO | None) -> str | None:
778739
return None
779740

780741
def is_unique(self: Self) -> ArrowSeries:
781-
from narwhals._arrow.series import ArrowSeries
782-
783742
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
784743
row_index = pa.array(range(len(self)))
785744
keep_idx = (
786745
self.native.append_column(col_token, row_index)
787746
.group_by(self.columns)
788747
.aggregate([(col_token, "min"), (col_token, "max")])
789748
)
790-
return ArrowSeries(
791-
pa.chunked_array(
792-
pc.and_(
793-
pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
794-
pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
795-
)
796-
),
797-
name="",
798-
backend_version=self._backend_version,
799-
version=self._version,
749+
native = pa.chunked_array(
750+
pc.and_(
751+
pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
752+
pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
753+
)
800754
)
755+
return ArrowSeries.from_native(native, context=self)
801756

802757
def unique(
803758
self: ArrowDataFrame,

narwhals/_arrow/namespace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@
3737
from narwhals.utils import Version
3838

3939

40-
class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]):
40+
class ArrowNamespace(
41+
EagerNamespace[
42+
ArrowDataFrame, ArrowSeries, ArrowExpr, "pa.Table", "ArrowChunkedArray"
43+
]
44+
):
4145
@property
4246
def _dataframe(self) -> type[ArrowDataFrame]:
4347
return ArrowDataFrame

narwhals/_arrow/series.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import pandas as pd
4444
import polars as pl
4545
from typing_extensions import Self
46+
from typing_extensions import TypeIs
4647

4748
from narwhals._arrow.dataframe import ArrowDataFrame
4849
from narwhals._arrow.namespace import ArrowNamespace
@@ -135,12 +136,7 @@ def _with_native(
135136
*,
136137
preserve_broadcast: bool = False,
137138
) -> Self:
138-
result = self.__class__(
139-
chunked_array(series),
140-
name=self._name,
141-
backend_version=self._backend_version,
142-
version=self._version,
143-
)
139+
result = self.from_native(chunked_array(series), name=self.name, context=self)
144140
if preserve_broadcast:
145141
result._broadcast = self._broadcast
146142
return result
@@ -156,18 +152,30 @@ def from_iterable(
156152
) -> Self:
157153
version = context._version
158154
dtype_pa = narwhals_to_native_dtype(dtype, version) if dtype else None
159-
return cls(
160-
chunked_array([data], dtype_pa),
161-
name=name,
162-
backend_version=context._backend_version,
163-
version=version,
155+
return cls.from_native(
156+
chunked_array([data], dtype_pa), name=name, context=context
164157
)
165158

166159
def _from_scalar(self, value: Any) -> Self:
167160
if self._backend_version < (13,) and hasattr(value, "as_py"):
168161
value = value.as_py()
169162
return super()._from_scalar(value)
170163

164+
@staticmethod
165+
def _is_native(obj: ArrowChunkedArray | Any) -> TypeIs[ArrowChunkedArray]:
166+
return isinstance(obj, pa.ChunkedArray)
167+
168+
@classmethod
169+
def from_native(
170+
cls, data: ArrowChunkedArray, /, *, context: _FullContext, name: str = ""
171+
) -> Self:
172+
return cls(
173+
data,
174+
backend_version=context._backend_version,
175+
version=context._version,
176+
name=name,
177+
)
178+
171179
@classmethod
172180
def from_numpy(cls, data: Into1DArray, /, *, context: _FullContext) -> Self:
173181
return cls.from_iterable(
@@ -546,7 +554,7 @@ def tail(self: Self, n: int) -> Self:
546554
return self._with_native(self.native.slice(abs(n)))
547555

548556
def is_in(self: Self, other: Any) -> Self:
549-
if isinstance(other, pa.ChunkedArray):
557+
if self._is_native(other):
550558
value_set: ArrowChunkedArray | ArrowArray = other
551559
else:
552560
value_set = pa.array(other)

narwhals/_compliant/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from narwhals._compliant.group_by import LazyGroupBy
1313
from narwhals._compliant.namespace import CompliantNamespace
1414
from narwhals._compliant.namespace import EagerNamespace
15+
from narwhals._compliant.namespace import LazyNamespace
1516
from narwhals._compliant.selectors import CompliantSelector
1617
from narwhals._compliant.selectors import CompliantSelectorNamespace
1718
from narwhals._compliant.selectors import EagerSelectorNamespace
@@ -64,6 +65,7 @@
6465
"IntoCompliantExpr",
6566
"LazyExpr",
6667
"LazyGroupBy",
68+
"LazyNamespace",
6769
"LazySelectorNamespace",
6870
"LazyWhen",
6971
"NativeFrameT_co",

narwhals/_compliant/dataframe.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
from narwhals._compliant.typing import CompliantSeriesT
1717
from narwhals._compliant.typing import EagerExprT_contra
1818
from narwhals._compliant.typing import EagerSeriesT
19-
from narwhals._compliant.typing import NativeFrameT_co
19+
from narwhals._compliant.typing import NativeFrameT
2020
from narwhals._expression_parsing import evaluate_output_names_and_aliases
2121
from narwhals._translate import ArrowConvertible
2222
from narwhals._translate import DictConvertible
23+
from narwhals._translate import FromNative
2324
from narwhals._translate import NumpyConvertible
2425
from narwhals.utils import Version
2526
from narwhals.utils import _StoresNative
@@ -57,11 +58,12 @@ class CompliantDataFrame(
5758
NumpyConvertible["_2DArray", "_2DArray"],
5859
DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
5960
ArrowConvertible["pa.Table", "IntoArrowTable"],
60-
_StoresNative[NativeFrameT_co],
61+
_StoresNative[NativeFrameT],
62+
FromNative[NativeFrameT],
6163
Sized,
62-
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT_co],
64+
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT],
6365
):
64-
_native_frame: Any
66+
_native_frame: NativeFrameT
6567
_implementation: Implementation
6668
_backend_version: tuple[int, ...]
6769
_version: Version
@@ -80,6 +82,8 @@ def from_dict(
8082
schema: Mapping[str, DType] | Schema | None,
8183
) -> Self: ...
8284
@classmethod
85+
def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
86+
@classmethod
8387
def from_numpy(
8488
cls,
8589
data: _2DArray,
@@ -105,8 +109,8 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
105109
def _with_version(self, version: Version) -> Self: ...
106110

107111
@property
108-
def native(self) -> NativeFrameT_co:
109-
return self._native_frame # type: ignore[no-any-return]
112+
def native(self) -> NativeFrameT:
113+
return self._native_frame
110114

111115
@property
112116
def columns(self) -> Sequence[str]: ...
@@ -210,16 +214,21 @@ def write_parquet(self, file: str | Path | BytesIO) -> None: ...
210214

211215

212216
class CompliantLazyFrame(
213-
_StoresNative[NativeFrameT_co], Protocol[CompliantExprT_contra, NativeFrameT_co]
217+
_StoresNative[NativeFrameT],
218+
FromNative[NativeFrameT],
219+
Protocol[CompliantExprT_contra, NativeFrameT],
214220
):
215-
_native_frame: Any
221+
_native_frame: NativeFrameT
216222
_implementation: Implementation
217223
_backend_version: tuple[int, ...]
218224
_version: Version
219225

220226
def __narwhals_lazyframe__(self) -> Self: ...
221227
def __narwhals_namespace__(self) -> Any: ...
222228

229+
@classmethod
230+
def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
231+
223232
def simple_select(self, *column_names: str) -> Self:
224233
"""`select` where all args are column names."""
225234
...
@@ -234,8 +243,8 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
234243
def _with_version(self, version: Version) -> Self: ...
235244

236245
@property
237-
def native(self) -> NativeFrameT_co:
238-
return self._native_frame # type: ignore[no-any-return]
246+
def native(self) -> NativeFrameT:
247+
return self._native_frame
239248

240249
@property
241250
def columns(self) -> Sequence[str]: ...
@@ -307,9 +316,9 @@ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
307316

308317

309318
class EagerDataFrame(
310-
CompliantDataFrame[EagerSeriesT, EagerExprT_contra, NativeFrameT_co],
311-
CompliantLazyFrame[EagerExprT_contra, NativeFrameT_co],
312-
Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT_co],
319+
CompliantDataFrame[EagerSeriesT, EagerExprT_contra, NativeFrameT],
320+
CompliantLazyFrame[EagerExprT_contra, NativeFrameT],
321+
Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT],
313322
):
314323
def _evaluate_expr(self, expr: EagerExprT_contra, /) -> EagerSeriesT:
315324
"""Evaluate `expr` and ensure it has a **single** output."""

0 commit comments

Comments
 (0)