diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 1b31c8795..478f60da0 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -10,6 +10,7 @@ from collections.abc import ( import datetime from datetime import tzinfo from os import PathLike +from re import Pattern import sys from typing import ( Any, @@ -36,6 +37,7 @@ from typing_extensions import ( ) from pandas._libs.interval import Interval +from pandas._libs.missing import NAType from pandas._libs.tslibs import ( BaseOffset, Period, @@ -731,7 +733,17 @@ InterpolateOptions: TypeAlias = Literal[ "cubicspline", "from_derivatives", ] -ReplaceMethod: TypeAlias = Literal["pad", "ffill", "bfill"] +# Can be passed to `to_replace`, `value`, or `regex` in `Series.replace`. +# `DataFrame.replace` also accepts mappings of these. +ReplaceValue: TypeAlias = ( + Scalar + | Pattern + | NAType + | Sequence[Scalar | Pattern] + | Mapping[Hashable, Scalar] + | Series[Any] + | None +) SortKind: TypeAlias = Literal["quicksort", "mergesort", "heapsort", "stable"] NaPosition: TypeAlias = Literal["first", "last"] JoinHow: TypeAlias = Literal["left", "right", "outer", "inner"] diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 37a2d13ca..344dd3a39 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -8,7 +8,6 @@ from collections.abc import ( Sequence, ) import datetime as dt -from re import Pattern import sys from typing import ( Any, @@ -113,7 +112,7 @@ from pandas._typing import ( RandomState, ReadBuffer, Renamer, - ReplaceMethod, + ReplaceValue, Scalar, ScalarT, SequenceNotStr, @@ -799,24 +798,20 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): @overload def replace( self, - to_replace=..., - value: Scalar | NAType | Sequence | Mapping | Pattern | None = ..., + to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., *, inplace: Literal[True], - limit: int | None = ..., - regex=..., - method: ReplaceMethod = ..., + regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., ) -> None: ... @overload def replace( self, - to_replace=..., - value: Scalar | NAType | Sequence | Mapping | Pattern | None = ..., + to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., *, inplace: Literal[False] = ..., - limit: int | None = ..., - regex=..., - method: ReplaceMethod = ..., + regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., ) -> Self: ... def shift( self, diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index ecf65ab51..2ade583c7 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -24,7 +24,10 @@ from typing import ( overload, ) -from _typing import TimeZones +from _typing import ( + ReplaceValue, + TimeZones, +) from matplotlib.axes import ( Axes as PlotAxes, SubplotBase, @@ -141,7 +144,6 @@ from pandas._typing import ( QuantileInterpolation, RandomState, Renamer, - ReplaceMethod, Scalar, ScalarT, SequenceNotStr, @@ -1089,24 +1091,20 @@ class Series(IndexOpsMixin[S1], NDFrame): @overload def replace( self, - to_replace: _str | list | dict | Series[S1] | float | None = ..., - value: Scalar | NAType | dict | list | _str | None = ..., + to_replace: ReplaceValue = ..., + value: ReplaceValue = ..., *, - limit: int | None = ..., - regex=..., - method: ReplaceMethod = ..., + regex: ReplaceValue = ..., inplace: Literal[True], ) -> None: ... @overload def replace( self, - to_replace: _str | list | dict | Series[S1] | float | None = ..., - value: Scalar | NAType | dict | list | _str | None = ..., + to_replace: ReplaceValue = ..., + value: ReplaceValue = ..., *, + regex: ReplaceValue = ..., inplace: Literal[False] = ..., - limit: int | None = ..., - regex=..., - method: ReplaceMethod = ..., ) -> Series[S1]: ... def shift( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index b74118b34..7696008e3 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -14,6 +14,7 @@ import io import itertools from pathlib import Path +import re import string import sys from typing import ( @@ -2570,6 +2571,121 @@ def test_types_replace() -> None: assert assert_type(df.replace(1, 2, inplace=True), None) is None +def test_dataframe_replace() -> None: + df = pd.DataFrame({"col1": ["a", "ab", "ba"]}) + pattern = re.compile(r"^a.*") + check(assert_type(df.replace("a", "x"), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace("a", "x", regex=True), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(regex="a", value="x"), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(regex=pattern, value="x"), pd.DataFrame), pd.DataFrame) + + check(assert_type(df.replace(["a"], ["x"]), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace([pattern], ["x"]), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(regex=["a"], value=["x"]), pd.DataFrame), pd.DataFrame) + check( + assert_type(df.replace(regex=[pattern], value=["x"]), pd.DataFrame), + pd.DataFrame, + ) + + check(assert_type(df.replace({"a": "x"}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace({pattern: "x"}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(regex={"a": "x"}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(regex={pattern: "x"}), pd.DataFrame), pd.DataFrame) + check( + assert_type(df.replace(regex=pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame + ) + + check( + assert_type(df.replace({"col1": "a"}, {"col1": "x"}), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(df.replace({"col1": pattern}, {"col1": "x"}), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + df.replace(pd.Series({"col1": "a"}), pd.Series({"col1": "x"})), pd.DataFrame + ), + pd.DataFrame, + ) + check( + assert_type(df.replace(regex={"col1": "a"}, value={"col1": "x"}), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + df.replace(regex={"col1": pattern}, value={"col1": "x"}), pd.DataFrame + ), + pd.DataFrame, + ) + check( + assert_type( + df.replace(regex=pd.Series({"col1": "a"}), value=pd.Series({"col1": "x"})), + pd.DataFrame, + ), + pd.DataFrame, + ) + + check( + assert_type(df.replace({"col1": ["a"]}, {"col1": ["x"]}), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(df.replace({"col1": [pattern]}, {"col1": ["x"]}), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + df.replace(pd.Series({"col1": ["a"]}), pd.Series({"col1": ["x"]})), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.replace(regex={"col1": ["a"]}, value={"col1": ["x"]}), pd.DataFrame + ), + pd.DataFrame, + ) + check( + assert_type( + df.replace(regex={"col1": [pattern]}, value={"col1": ["x"]}), pd.DataFrame + ), + pd.DataFrame, + ) + check( + assert_type( + df.replace( + regex=pd.Series({"col1": ["a"]}), value=pd.Series({"col1": ["x"]}) + ), + pd.DataFrame, + ), + pd.DataFrame, + ) + + check(assert_type(df.replace({"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace({"col1": {pattern: "x"}}), pd.DataFrame), pd.DataFrame) + check( + assert_type(df.replace({"col1": pd.Series({"a": "x"})}), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(df.replace(regex={"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(df.replace(regex={"col1": {pattern: "x"}}), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(df.replace(regex={"col1": pd.Series({"a": "x"})}), pd.DataFrame), + pd.DataFrame, + ) + + def test_loop_dataframe() -> None: # GH 70 df = pd.DataFrame({"x": [1, 2, 3]}) diff --git a/tests/test_series.py b/tests/test_series.py index 7d0537ed6..0c700fcbd 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1410,6 +1410,40 @@ def test_types_replace() -> None: assert assert_type(s.replace(1, 2, inplace=True), None) is None +def test_series_replace() -> None: + s: pd.Series[str] = pd.DataFrame({"col1": ["a", "ab", "ba"]})["col1"] + pattern = re.compile(r"^a.*") + check(assert_type(s.replace("a", "x"), "pd.Series[str]"), pd.Series) + check(assert_type(s.replace(pattern, "x"), "pd.Series[str]"), pd.Series) + check( + assert_type(s.replace({"a": "z"}), "pd.Series[str]"), + pd.Series, + ) + check( + assert_type(s.replace(pd.Series({"a": "z"})), "pd.Series[str]"), + pd.Series, + ) + check( + assert_type(s.replace({pattern: "z"}), "pd.Series[str]"), + pd.Series, + ) + check(assert_type(s.replace(["a"], ["x"]), "pd.Series[str]"), pd.Series) + check(assert_type(s.replace([pattern], ["x"]), "pd.Series[str]"), pd.Series) + check(assert_type(s.replace(r"^a.*", "x", regex=True), "pd.Series[str]"), pd.Series) + check(assert_type(s.replace(value="x", regex=r"^a.*"), "pd.Series[str]"), pd.Series) + check( + assert_type(s.replace(value="x", regex=[r"^a.*"]), "pd.Series[str]"), pd.Series + ) + check(assert_type(s.replace(value="x", regex=pattern), "pd.Series[str]"), pd.Series) + check( + assert_type(s.replace(value="x", regex=[pattern]), "pd.Series[str]"), pd.Series + ) + check(assert_type(s.replace(regex={"a": "x"}), "pd.Series[str]"), pd.Series) + check( + assert_type(s.replace(regex=pd.Series({"a": "x"})), "pd.Series[str]"), pd.Series + ) + + def test_cat_accessor() -> None: # GH 43 s: pd.Series[str] = pd.Series(