Skip to content

Commit 996298b

Browse files
GH984 Add overload for DataFrame.clip and update those for Series.clip
1 parent 9c02c36 commit 996298b

File tree

4 files changed

+213
-25
lines changed

4 files changed

+213
-25
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,23 +1738,73 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
17381738
@overload
17391739
def clip(
17401740
self,
1741-
lower: float | AnyArrayLike | None = ...,
1742-
upper: float | AnyArrayLike | None = ...,
1741+
lower: float | None = ...,
1742+
upper: float | None = ...,
17431743
*,
17441744
axis: Axis | None = ...,
1745-
inplace: Literal[True],
1745+
inplace: Literal[False] = ...,
17461746
**kwargs: Any,
1747-
) -> None: ...
1747+
) -> Self: ...
17481748
@overload
17491749
def clip(
17501750
self,
1751-
lower: float | AnyArrayLike | None = ...,
1752-
upper: float | AnyArrayLike | None = ...,
1751+
lower: AnyArrayLike = ...,
1752+
upper: AnyArrayLike | None = ...,
17531753
*,
1754-
axis: Axis | None = ...,
1754+
axis: Axis = ...,
17551755
inplace: Literal[False] = ...,
17561756
**kwargs: Any,
17571757
) -> Self: ...
1758+
@overload
1759+
def clip(
1760+
self,
1761+
lower: AnyArrayLike | None = ...,
1762+
upper: AnyArrayLike = ...,
1763+
*,
1764+
axis: Axis = ...,
1765+
inplace: Literal[False] = ...,
1766+
**kwargs: Any,
1767+
) -> Self: ...
1768+
@overload
1769+
def clip( # pyright: ignore[reportOverlappingOverload]
1770+
self,
1771+
lower: None = ...,
1772+
upper: None = ...,
1773+
*,
1774+
axis: Axis | None = ...,
1775+
inplace: Literal[True],
1776+
**kwargs: Any,
1777+
) -> Self: ...
1778+
@overload
1779+
def clip(
1780+
self,
1781+
lower: float | None = ...,
1782+
upper: float | None = ...,
1783+
*,
1784+
axis: Axis | None = ...,
1785+
inplace: Literal[True],
1786+
**kwargs: Any,
1787+
) -> None: ...
1788+
@overload
1789+
def clip(
1790+
self,
1791+
lower: AnyArrayLike = ...,
1792+
upper: AnyArrayLike | None = ...,
1793+
*,
1794+
axis: Axis = ...,
1795+
inplace: Literal[True],
1796+
**kwargs: Any,
1797+
) -> None: ...
1798+
@overload
1799+
def clip(
1800+
self,
1801+
lower: AnyArrayLike | None = ...,
1802+
upper: AnyArrayLike = ...,
1803+
*,
1804+
axis: Axis = ...,
1805+
inplace: Literal[True],
1806+
**kwargs: Any,
1807+
) -> None: ...
17581808
def copy(self, deep: _bool = ...) -> Self: ...
17591809
def cummax(
17601810
self, axis: Axis | None = ..., skipna: _bool = ..., *args: Any, **kwargs: Any

pandas-stubs/core/series.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,16 @@ class Series(IndexOpsMixin[S1], NDFrame):
13941394
subset: _str | Sequence[_str] | None = ...,
13951395
) -> Scalar | Series[S1]: ...
13961396
@overload
1397+
def clip( # pyright: ignore[reportOverlappingOverload]
1398+
self,
1399+
lower: None = ...,
1400+
upper: None = ...,
1401+
*,
1402+
axis: AxisIndex | None = ...,
1403+
inplace: Literal[True],
1404+
**kwargs: Any,
1405+
) -> Self: ...
1406+
@overload
13971407
def clip(
13981408
self,
13991409
lower: AnyArrayLike | float | None = ...,

tests/test_frame.py

Lines changed: 117 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -659,23 +659,124 @@ def test_types_quantile() -> None:
659659
df.quantile(np.array([0.25, 0.75]))
660660

661661

662-
@pytest.mark.parametrize("lower", [None, 5, pd.Series([3, 4])])
663-
@pytest.mark.parametrize("upper", [None, 15, pd.Series([12, 13])])
664-
@pytest.mark.parametrize("axis", [None, 0, "index"])
665-
def test_types_clip(lower, upper, axis) -> None:
666-
def is_none_or_numeric(val: Any) -> bool:
667-
return val is None or isinstance(val, int | float)
668-
662+
def test_dataframe_clip() -> None:
663+
"""Test different clipping combinations for dataframe."""
669664
df = pd.DataFrame(data={"col1": [20, 12], "col2": [3, 14]})
670-
uses_array = not (is_none_or_numeric(lower) and is_none_or_numeric(upper))
671-
if uses_array and axis is None:
672-
with pytest.raises(ValueError):
673-
df.clip(lower=lower, upper=upper, axis=axis)
674-
else:
675-
check(
676-
assert_type(df.clip(lower=lower, upper=upper, axis=axis), pd.DataFrame),
677-
pd.DataFrame,
678-
)
665+
if TYPE_CHECKING_INVALID_USAGE:
666+
df.clip(lower=pd.Series([4, 5]), upper=None, axis=None) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
667+
df.clip(lower=None, upper=pd.Series([4, 5]), axis=None) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
668+
df.clip(lower=pd.Series([1, 2]), upper=pd.Series([4, 5]), axis=None) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
669+
df.copy().clip(lower=pd.Series([1, 2]), upper=None, axis=None, inplace=True) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
670+
df.copy().clip(lower=None, upper=pd.Series([1, 2]), axis=None, inplace=True) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
671+
df.copy().clip(lower=pd.Series([4, 5]), upper=pd.Series([1, 2]), axis=None, inplace=True) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
672+
673+
check(
674+
assert_type(df.clip(lower=None, upper=None, axis=None), pd.DataFrame),
675+
pd.DataFrame,
676+
)
677+
check(
678+
assert_type(df.clip(lower=5, upper=None, axis=None), pd.DataFrame), pd.DataFrame
679+
)
680+
check(
681+
assert_type(df.clip(lower=None, upper=15, axis=None), pd.DataFrame),
682+
pd.DataFrame,
683+
)
684+
check(
685+
assert_type(
686+
df.clip(lower=None, upper=None, axis=None, inplace=True), pd.DataFrame
687+
),
688+
pd.DataFrame,
689+
)
690+
check(
691+
assert_type(df.clip(lower=5, upper=None, axis=None, inplace=True), None),
692+
type(None),
693+
)
694+
check(
695+
assert_type(df.clip(lower=None, upper=15, axis=None, inplace=True), None),
696+
type(None),
697+
)
698+
699+
check(
700+
assert_type(df.clip(lower=None, upper=None, axis=0), pd.DataFrame), pd.DataFrame
701+
)
702+
check(assert_type(df.clip(lower=5, upper=None, axis=0), pd.DataFrame), pd.DataFrame)
703+
check(
704+
assert_type(df.clip(lower=None, upper=15, axis=0), pd.DataFrame), pd.DataFrame
705+
)
706+
check(
707+
assert_type(df.clip(lower=pd.Series([1, 2]), upper=None, axis=0), pd.DataFrame),
708+
pd.DataFrame,
709+
)
710+
check(
711+
assert_type(df.clip(lower=None, upper=pd.Series([1, 2]), axis=0), pd.DataFrame),
712+
pd.DataFrame,
713+
)
714+
check(
715+
assert_type(
716+
df.clip(lower=None, upper=None, axis="index", inplace=True), pd.DataFrame
717+
),
718+
pd.DataFrame,
719+
)
720+
check(
721+
assert_type(df.clip(lower=5, upper=None, axis="index", inplace=True), None),
722+
type(None),
723+
)
724+
check(
725+
assert_type(df.clip(lower=None, upper=15, axis="index", inplace=True), None),
726+
type(None),
727+
)
728+
check(
729+
assert_type(
730+
df.clip(lower=pd.Series([1, 2]), upper=None, axis="index", inplace=True),
731+
None,
732+
),
733+
type(None),
734+
)
735+
check(
736+
assert_type(
737+
df.clip(lower=None, upper=pd.Series([1, 2]), axis="index", inplace=True),
738+
None,
739+
),
740+
type(None),
741+
)
742+
check(
743+
assert_type(df.clip(lower=None, upper=None, axis="index"), pd.DataFrame),
744+
pd.DataFrame,
745+
)
746+
check(
747+
assert_type(df.clip(lower=5, upper=None, axis="index"), pd.DataFrame),
748+
pd.DataFrame,
749+
)
750+
check(
751+
assert_type(df.clip(lower=None, upper=15, axis="index"), pd.DataFrame),
752+
pd.DataFrame,
753+
)
754+
check(
755+
assert_type(
756+
df.clip(lower=pd.Series([1, 2]), upper=None, axis="index"), pd.DataFrame
757+
),
758+
pd.DataFrame,
759+
)
760+
check(
761+
assert_type(
762+
df.clip(lower=None, upper=pd.Series([1, 2]), axis="index"), pd.DataFrame
763+
),
764+
pd.DataFrame,
765+
)
766+
check(
767+
assert_type(
768+
df.clip(lower=None, upper=None, axis=0, inplace=True), pd.DataFrame
769+
),
770+
pd.DataFrame,
771+
)
772+
check(
773+
assert_type(df.clip(lower=5, upper=None, axis=0, inplace=True), None),
774+
type(None),
775+
)
776+
check(
777+
assert_type(df.clip(lower=None, upper=15, axis=0, inplace=True), None),
778+
type(None),
779+
)
679780

680781

681782
def test_types_abs() -> None:

tests/test_series.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,35 @@ def test_types_quantile() -> None:
583583

584584
def test_types_clip() -> None:
585585
s = pd.Series([-10, 2, 3, 10])
586-
s.clip(lower=0, upper=5)
587-
s.clip(lower=0, upper=5, inplace=True)
586+
check(
587+
assert_type(s.clip(lower=None, upper=None), "pd.Series[int]"),
588+
pd.Series,
589+
np.integer,
590+
)
591+
check(
592+
assert_type(s.clip(lower=0, upper=5), "pd.Series[int]"), pd.Series, np.integer
593+
)
594+
check(
595+
assert_type(s.clip(lower=0, upper=None), "pd.Series[int]"),
596+
pd.Series,
597+
np.integer,
598+
)
599+
check(
600+
assert_type(s.clip(lower=None, upper=5), "pd.Series[int]"),
601+
pd.Series,
602+
np.integer,
603+
)
604+
check(
605+
assert_type(s.clip(lower=None, upper=None, inplace=True), "pd.Series[int]"),
606+
pd.Series,
607+
np.integer,
608+
)
609+
check(assert_type(s.clip(lower=0, upper=5, inplace=True), None), type(None))
610+
check(assert_type(s.clip(lower=0, upper=None, inplace=True), None), type(None))
611+
check(
612+
assert_type(s.clip(lower=None, upper=5, inplace=True), None),
613+
type(None),
614+
)
588615

589616

590617
def test_types_abs() -> None:

0 commit comments

Comments
 (0)