Skip to content

Commit 4e29a0d

Browse files
author
Ritsuki Yamada
committed
take method on NDDataFrame
1 parent cbb6723 commit 4e29a0d

File tree

5 files changed

+36
-4
lines changed

5 files changed

+36
-4
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2298,7 +2298,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
22982298
) -> Series: ...
22992299
def swapaxes(self, axis1: Axis, axis2: Axis, copy: _bool = ...) -> Self: ...
23002300
def tail(self, n: int = ...) -> Self: ...
2301-
def take(self, indices: list, axis: Axis = ..., **kwargs: Any) -> Self: ...
23022301
@overload
23032302
def to_json(
23042303
self,

pandas-stubs/core/generic.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ from pandas._typing import (
4646
P,
4747
StorageOptions,
4848
T,
49+
TakeIndexer,
4950
TimedeltaConvertibleTypes,
5051
TimeGrouperOrigin,
5152
TimestampConvention,
@@ -418,3 +419,5 @@ class NDFrame(indexing.IndexingMixin):
418419
offset: TimedeltaConvertibleTypes | None = ...,
419420
group_keys: _bool = ...,
420421
) -> DatetimeIndexResampler[Self]: ...
422+
@final
423+
def take(self, indices: TakeIndexer, axis: Axis = ..., **kwargs: Any) -> Self: ...

pandas-stubs/core/series.pyi

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,6 @@ class Series(IndexOpsMixin[S1], NDFrame):
450450
def __array__(self, dtype=...) -> np.ndarray: ...
451451
@property
452452
def axes(self) -> list: ...
453-
def take(
454-
self, indices: Sequence, axis: AxisIndex = ..., **kwargs: Any
455-
) -> Series[S1]: ...
456453
def __getattr__(self, name: _str) -> S1: ...
457454
@overload
458455
def __getitem__(

tests/test_frame.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2984,6 +2984,21 @@ def test_iloc_tuple() -> None:
29842984
df = df.iloc[0:2,]
29852985

29862986

2987+
def test_take_list() -> None:
2988+
df = pd.DataFrame({"a": [1, 2, 3]})
2989+
check(assert_type(df.take([0, 1]), pd.DataFrame), pd.DataFrame)
2990+
2991+
2992+
def test_take_list_npint() -> None:
2993+
df = pd.DataFrame({"a": [1, 2, 3]})
2994+
check(assert_type(df.take([np.int64(0), np.int64(1)]), pd.DataFrame), pd.DataFrame)
2995+
2996+
2997+
def test_take_ndarray() -> None:
2998+
df = pd.DataFrame({"a": [1, 2, 3]})
2999+
check(assert_type(df.take(np.array([0, 1])), pd.DataFrame), pd.DataFrame)
3000+
3001+
29873002
def test_set_columns() -> None:
29883003
# GH 73
29893004
df = pd.DataFrame({"a": [1, 2, 3], "b": [0.0, 1, 1]})

tests/test_series.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,6 +1715,24 @@ def test_iloc_getitem_ndarray() -> None:
17151715
check(assert_type(values_s.iloc[indices_u64], pd.Series), pd.Series)
17161716

17171717

1718+
def test_take_list() -> None:
1719+
s = pd.Series(np.arange(10), name="a")
1720+
check(assert_type(s.take([0, 1]), pd.Series), pd.Series)
1721+
1722+
1723+
def test_take_list_npint() -> None:
1724+
s = pd.Series(np.arange(10), name="a")
1725+
check(
1726+
assert_type(s.take([np.int64(0), np.int64(1)]), pd.Series),
1727+
pd.Series,
1728+
)
1729+
1730+
1731+
def test_take_ndarray() -> None:
1732+
s = pd.Series(np.arange(10), name="a")
1733+
check(assert_type(s.take(np.array([0, 1])), pd.Series), pd.Series)
1734+
1735+
17181736
def test_iloc_setitem_ndarray() -> None:
17191737
# GH 85
17201738
# GH 86

0 commit comments

Comments
 (0)