Skip to content

Commit fe54aac

Browse files
committed
fix type signatures of Dataframe.squeeze and Series.squeeze
1 parent 18df89e commit fe54aac

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2191,7 +2191,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
21912191
numeric_only: _bool = ...,
21922192
**kwargs: Any,
21932193
) -> Series: ...
2194-
def squeeze(self, axis: Axis | None = ...): ...
2194+
def squeeze(self, axis: Axis | None = ...) -> DataFrame | Series | Scalar: ...
21952195
def std(
21962196
self,
21972197
axis: Axis = ...,

pandas-stubs/core/series.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
12161216
) -> Series[S1]: ...
12171217
def droplevel(self, level: Level | list[Level], axis: AxisIndex = ...) -> Self: ...
12181218
def pop(self, item: Hashable) -> S1: ...
1219-
def squeeze(self, axis: AxisIndex | None = ...) -> Scalar: ...
1219+
def squeeze(self) -> Series[S1] | Scalar: ...
12201220
def __abs__(self) -> Series[S1]: ...
12211221
def add_prefix(self, prefix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
12221222
def add_suffix(self, suffix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...

tests/test_frame.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3231,6 +3231,17 @@ def test_resample() -> None:
32313231
check(assert_type(df.resample("2min").ohlc(), pd.DataFrame), pd.DataFrame)
32323232

32333233

3234+
def test_squeeze() -> None:
3235+
df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
3236+
check(df1.squeeze(), pd.DataFrame)
3237+
df2 = pd.DataFrame({"a": [1, 2]})
3238+
check(df2.squeeze(), pd.Series)
3239+
df3 = pd.DataFrame({"a": [1], "b": [2]})
3240+
check(df3.squeeze(), pd.Series, np.integer)
3241+
df4 = pd.DataFrame({"a": [1]})
3242+
check(df4.squeeze(), np.integer)
3243+
3244+
32343245
def test_loc_set() -> None:
32353246
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
32363247
df.loc["a"] = [3, 4]

tests/test_series.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,6 +1730,13 @@ def test_resample() -> None:
17301730
check(assert_type(s.resample("2min").ohlc(), pd.DataFrame), pd.DataFrame)
17311731

17321732

1733+
def test_squeeze() -> None:
1734+
s1 = pd.Series([1, 2, 3])
1735+
check(s1.squeeze(), pd.Series, np.integer)
1736+
s2 = pd.Series([1])
1737+
check(s2.squeeze(), np.integer)
1738+
1739+
17331740
def test_to_xarray():
17341741
s = pd.Series([1, 2])
17351742
check(assert_type(s.to_xarray(), xr.DataArray), xr.DataArray)

0 commit comments

Comments
 (0)