diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index ccca2213..0019284f 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -224,7 +224,7 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]): | slice | _IndexSliceTuple | Callable, - MaskType | list[HashableT] | slice | Callable, + MaskType | list[HashableT] | IndexType | Callable, ] ), ) -> _T: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 90d5626e..ae5520c4 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -3737,13 +3737,25 @@ def test_xs_key() -> None: def test_loc_slice() -> None: - # GH 277 + """Test DataFrame.loc with a slice, Index, Series.""" + # GH277 df1 = pd.DataFrame( {"x": [1, 2, 3, 4]}, index=pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=["num", "let"]), ) check(assert_type(df1.loc[1, :], Union[pd.Series, pd.DataFrame]), pd.DataFrame) + # GH1299 + ind = pd.Index(["a", "b"]) + mask = pd.Series([True, False]) + mask_col = pd.Series([True, False], index=pd.Index(["a", "b"])) + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + + # loc with index for columns + check(assert_type(df.loc[mask, ind], pd.DataFrame), pd.DataFrame) + # loc with index for columns + check(assert_type(df.loc[mask, mask_col], pd.DataFrame), pd.DataFrame) + def test_where() -> None: df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})