Skip to content

Commit 9691144

Browse files
committed
#1410 df.loc with iterable
1 parent cfcaf41 commit 9691144

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -225,27 +225,6 @@ class _iLocIndexerFrame(_iLocIndexer, Generic[_T]):
225225
) -> None: ...
226226

227227
class _LocIndexerFrame(_LocIndexer, Generic[_T]):
228-
@overload
229-
def __getitem__(self, idx: Scalar) -> Series | _T: ...
230-
@overload
231-
def __getitem__( # type: ignore[overload-overlap]
232-
self,
233-
idx: (
234-
IndexType
235-
| MaskType
236-
| Callable[[DataFrame], IndexType | MaskType | Sequence[Hashable]]
237-
| list[HashableT]
238-
| tuple[
239-
IndexType
240-
| MaskType
241-
| list[HashableT]
242-
| slice
243-
| _IndexSliceTuple
244-
| Callable,
245-
MaskType | list[HashableT] | IndexType | Callable,
246-
]
247-
),
248-
) -> _T: ...
249228
@overload
250229
def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
251230
self,
@@ -277,7 +256,28 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
277256
),
278257
) -> Series: ...
279258
@overload
259+
def __getitem__(self, idx: Scalar) -> Series | _T: ...
260+
@overload
280261
def __getitem__(self, idx: tuple[Scalar, slice]) -> Series | _T: ...
262+
@overload
263+
def __getitem__(
264+
self,
265+
idx: (
266+
IndexType
267+
| MaskType
268+
| Callable[[DataFrame], IndexType | MaskType | Sequence[Hashable]]
269+
| list[HashableT]
270+
| tuple[
271+
IndexType
272+
| MaskType
273+
| list[HashableT]
274+
| slice
275+
| _IndexSliceTuple
276+
| Callable,
277+
MaskType | Iterable[HashableT] | IndexType | Callable,
278+
]
279+
),
280+
) -> _T: ...
281281

282282
# Keep in sync with `DataFrame.__setitem__`
283283
@overload

tests/test_frame.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from collections import (
44
OrderedDict,
5+
UserList,
56
defaultdict,
7+
deque,
68
)
79
from collections.abc import (
810
Callable,
@@ -3718,12 +3720,28 @@ def test_loc_int_set() -> None:
37183720
df.loc[np.uint64(1)] = [2, 3]
37193721

37203722

3721-
def test_loclist() -> None:
3722-
# GH 189
3723+
@pytest.mark.parametrize("col", [1, None])
3724+
@pytest.mark.parametrize("typ", [list, tuple, deque, UserList, iter])
3725+
def test_loc_iterable(col: Hashable, typ: type) -> None:
3726+
# GH 189, GH 1410
37233727
df = pd.DataFrame({1: [1, 2], None: 5}, columns=pd.Index([1, None], dtype=object))
3728+
check(df.loc[:, typ([col])], pd.DataFrame)
37243729

3725-
check(assert_type(df.loc[:, [None]], pd.DataFrame), pd.DataFrame)
3726-
check(assert_type(df.loc[:, [1]], pd.DataFrame), pd.DataFrame)
3730+
if TYPE_CHECKING:
3731+
assert_type(df.loc[:, [None]], pd.DataFrame)
3732+
assert_type(df.loc[:, [1]], pd.DataFrame)
3733+
3734+
assert_type(df.loc[:, (None,)], pd.DataFrame)
3735+
assert_type(df.loc[:, (1,)], pd.DataFrame)
3736+
3737+
assert_type(df.loc[:, deque([None])], pd.DataFrame)
3738+
assert_type(df.loc[:, deque([1])], pd.DataFrame)
3739+
3740+
assert_type(df.loc[:, UserList([None])], pd.DataFrame)
3741+
assert_type(df.loc[:, UserList([1])], pd.DataFrame)
3742+
3743+
assert_type(df.loc[:, (None for _ in [0])], pd.DataFrame)
3744+
assert_type(df.loc[:, (1 for _ in [0])], pd.DataFrame)
37273745

37283746

37293747
def test_dict_items() -> None:

0 commit comments

Comments
 (0)