diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index abebc547c..ffc9790e1 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -2191,7 +2191,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): numeric_only: _bool = ..., **kwargs: Any, ) -> Series: ... - def squeeze(self, axis: Axis | None = ...): ... + def squeeze(self, axis: Axis | None = ...) -> DataFrame | Series | Scalar: ... def std( self, axis: Axis = ..., diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index fc45ee5aa..8020ed7a0 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1216,7 +1216,7 @@ class Series(IndexOpsMixin[S1], NDFrame): ) -> Series[S1]: ... def droplevel(self, level: Level | list[Level], axis: AxisIndex = ...) -> Self: ... def pop(self, item: Hashable) -> S1: ... - def squeeze(self, axis: AxisIndex | None = ...) -> Scalar: ... + def squeeze(self) -> Series[S1] | Scalar: ... def __abs__(self) -> Series[S1]: ... def add_prefix(self, prefix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ... def add_suffix(self, suffix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index acb96183d..6fd089faf 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -3231,6 +3231,25 @@ def test_resample() -> None: check(assert_type(df.resample("2min").ohlc(), pd.DataFrame), pd.DataFrame) +def test_squeeze() -> None: + df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + check( + assert_type(df1.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]), pd.DataFrame + ) + df2 = pd.DataFrame({"a": [1, 2]}) + check(assert_type(df2.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]), pd.Series) + df3 = pd.DataFrame({"a": [1], "b": [2]}) + check( + assert_type(df3.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]), + pd.Series, + np.integer, + ) + df4 = pd.DataFrame({"a": [1]}) + check( + assert_type(df4.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]), np.integer + ) + + def test_loc_set() -> None: df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) df.loc["a"] = [3, 4] diff --git a/tests/test_series.py b/tests/test_series.py index ae1b7f3f9..ce4e64bf7 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1730,6 +1730,17 @@ def test_resample() -> None: check(assert_type(s.resample("2min").ohlc(), pd.DataFrame), pd.DataFrame) +def test_squeeze() -> None: + s1 = pd.Series([1, 2, 3]) + check( + assert_type(s1.squeeze(), Union["pd.Series[int]", Scalar]), + pd.Series, + np.integer, + ) + s2 = pd.Series([1]) + check(assert_type(s2.squeeze(), Union["pd.Series[int]", Scalar]), np.integer) + + def test_to_xarray(): s = pd.Series([1, 2]) check(assert_type(s.to_xarray(), xr.DataArray), xr.DataArray)