Skip to content

Commit 927cf69

Browse files
committed
refactor: Make ArrowDataFrame compliant
1 parent deb77c2 commit 927cf69

File tree

1 file changed

+13
-28
lines changed

1 file changed

+13
-28
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any
66
from typing import Iterator
77
from typing import Literal
8+
from typing import Mapping
89
from typing import Sequence
910
from typing import cast
1011
from typing import overload
@@ -26,6 +27,7 @@
2627
from narwhals.utils import check_column_names_are_unique
2728
from narwhals.utils import generate_temporary_column_name
2829
from narwhals.utils import is_sequence_but_not_str
30+
from narwhals.utils import not_implemented
2931
from narwhals.utils import parse_columns_to_drop
3032
from narwhals.utils import parse_version
3133
from narwhals.utils import scale_bytes
@@ -449,29 +451,17 @@ def join(
449451
),
450452
)
451453

452-
def join_asof(
453-
self: Self,
454-
other: Self,
455-
*,
456-
left_on: str | None,
457-
right_on: str | None,
458-
by_left: list[str] | None,
459-
by_right: list[str] | None,
460-
strategy: Literal["backward", "forward", "nearest"],
461-
suffix: str,
462-
) -> Self:
463-
msg = "join_asof is not yet supported on PyArrow tables" # pragma: no cover
464-
raise NotImplementedError(msg)
454+
join_asof = not_implemented()
465455

466-
def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
456+
def drop(self: Self, columns: Sequence[str], strict: bool) -> Self: # noqa: FBT001
467457
to_drop = parse_columns_to_drop(
468458
compliant_frame=self, columns=columns, strict=strict
469459
)
470460
return self._from_native_frame(
471461
self._native_frame.drop(to_drop), validate_column_names=False
472462
)
473463

474-
def drop_nulls(self: ArrowDataFrame, subset: list[str] | None) -> ArrowDataFrame:
464+
def drop_nulls(self: ArrowDataFrame, subset: Sequence[str] | None) -> ArrowDataFrame:
475465
if subset is None:
476466
return self._from_native_frame(
477467
self._native_frame.drop_null(), validate_column_names=False
@@ -667,9 +657,7 @@ def collect(
667657
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
668658
raise AssertionError(msg) # pragma: no cover
669659

670-
def clone(self: Self) -> Self:
671-
msg = "clone is not yet supported on PyArrow tables"
672-
raise NotImplementedError(msg)
660+
clone = not_implemented()
673661

674662
def item(self: Self, row: int | None, column: int | str | None) -> Any:
675663
from narwhals._arrow.series import maybe_extract_py_scalar
@@ -695,7 +683,7 @@ def item(self: Self, row: int | None, column: int | str | None) -> Any:
695683
self._native_frame[_col][row], return_py_scalar=True
696684
)
697685

698-
def rename(self: Self, mapping: dict[str, str]) -> Self:
686+
def rename(self: Self, mapping: Mapping[str, str]) -> Self:
699687
df = self._native_frame
700688
new_cols = [mapping.get(c, c) for c in df.column_names]
701689
return self._from_native_frame(df.rename_columns(new_cols))
@@ -746,7 +734,7 @@ def is_unique(self: Self) -> ArrowSeries:
746734

747735
def unique(
748736
self: ArrowDataFrame,
749-
subset: list[str] | None,
737+
subset: Sequence[str] | None,
750738
*,
751739
keep: Literal["any", "first", "last", "none"],
752740
maintain_order: bool | None = None,
@@ -757,7 +745,7 @@ def unique(
757745

758746
df = self._native_frame
759747
check_column_exists(self.columns, subset)
760-
subset = subset or self.columns
748+
subset = list(subset or self.columns)
761749

762750
if keep in {"any", "first", "last"}:
763751
agg_func_map = {"any": "min", "first": "min", "last": "max"}
@@ -808,18 +796,15 @@ def sample(
808796

809797
def unpivot(
810798
self: Self,
811-
on: list[str] | None,
812-
index: list[str] | None,
799+
on: Sequence[str] | None,
800+
index: Sequence[str] | None,
813801
variable_name: str,
814802
value_name: str,
815803
) -> Self:
816804
native_frame = self._native_frame
817805
n_rows = len(self)
818-
819-
index_: list[str] = [] if index is None else index
820-
on_: list[str] = (
821-
[c for c in self.columns if c not in index_] if on is None else on
822-
)
806+
index_ = [] if index is None else index
807+
on_ = [c for c in self.columns if c not in index_] if on is None else on
823808
concat = (
824809
partial(pa.concat_tables, promote_options="permissive")
825810
if self._backend_version >= (14, 0, 0)

0 commit comments

Comments (0)