Skip to content

Commit 5706029

Browse files
committed
finish dataframe.replace typing
1 parent b04f5e6 commit 5706029

File tree

2 files changed

+138
-24
lines changed

2 files changed

+138
-24
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -802,23 +802,41 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
802802
Scalar
803803
| NAType
804804
| Sequence[Scalar | Pattern]
805+
| Mapping[Scalar | Pattern, Scalar]
805806
| Mapping[Hashable, Scalar | Pattern]
807+
| Mapping[Hashable, Sequence[Scalar | Pattern]]
808+
| Mapping[Hashable, Mapping[Scalar | Pattern, Scalar]]
809+
| Mapping[Hashable, Series[Any]]
806810
| Series[Any]
807811
| Pattern
808812
| None
809813
) = ...,
810814
value: (
811-
Scalar | NAType | Sequence[Scalar] | Mapping[Hashable, Scalar] | None
815+
Scalar
816+
| NAType
817+
| Sequence[Scalar]
818+
| Mapping[Scalar, Scalar]
819+
| Mapping[Hashable, Scalar]
820+
| Mapping[Hashable, Sequence[Scalar]]
821+
| Mapping[Hashable, Mapping[Scalar, Scalar]]
822+
| Mapping[Hashable, Series[Any]]
823+
| Series[Any]
824+
| None
812825
) = ...,
813826
*,
814827
inplace: Literal[True],
815828
regex: (
816-
bool
817-
| str
818-
| Pattern
819-
| Sequence[str | Pattern]
820-
| Mapping[Hashable, str | Pattern]
829+
Scalar
830+
| NAType
831+
| Sequence[Scalar | Pattern]
832+
| Mapping[Scalar | Pattern, Scalar]
833+
| Mapping[Hashable, Scalar | Pattern]
834+
| Mapping[Hashable, Sequence[Scalar | Pattern]]
835+
| Mapping[Hashable, Mapping[Scalar | Pattern, Scalar]]
836+
| Mapping[Hashable, Series[Any]]
821837
| Series[Any]
838+
| Pattern
839+
| None
822840
) = ...,
823841
) -> None: ...
824842
@overload
@@ -828,23 +846,41 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
828846
Scalar
829847
| NAType
830848
| Sequence[Scalar | Pattern]
849+
| Mapping[Scalar | Pattern, Scalar]
831850
| Mapping[Hashable, Scalar | Pattern]
851+
| Mapping[Hashable, Sequence[Scalar | Pattern]]
852+
| Mapping[Hashable, Mapping[Scalar | Pattern, Scalar]]
853+
| Mapping[Hashable, Series[Any]]
832854
| Series[Any]
833855
| Pattern
834856
| None
835857
) = ...,
836858
value: (
837-
Scalar | NAType | Sequence[Scalar] | Mapping[Hashable, Scalar] | None
859+
Scalar
860+
| NAType
861+
| Sequence[Scalar]
862+
| Mapping[Scalar, Scalar]
863+
| Mapping[Hashable, Scalar]
864+
| Mapping[Hashable, Sequence[Scalar]]
865+
| Mapping[Hashable, Mapping[Scalar, Scalar]]
866+
| Mapping[Hashable, Series[Any]]
867+
| Series[Any]
868+
| None
838869
) = ...,
839870
*,
840871
inplace: Literal[False] = ...,
841872
regex: (
842-
bool
843-
| str
844-
| Pattern
845-
| Sequence[str | Pattern]
846-
| Mapping[Hashable, str | Pattern]
873+
Scalar
874+
| NAType
875+
| Sequence[Scalar | Pattern]
876+
| Mapping[Scalar | Pattern, Scalar]
877+
| Mapping[Hashable, Scalar | Pattern]
878+
| Mapping[Hashable, Sequence[Scalar | Pattern]]
879+
| Mapping[Hashable, Mapping[Scalar | Pattern, Scalar]]
880+
| Mapping[Hashable, Series[Any]]
847881
| Series[Any]
882+
| Pattern
883+
| None
848884
) = ...,
849885
) -> Self: ...
850886
def shift(

tests/test_frame.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,10 +2574,31 @@ def test_types_replace() -> None:
25742574
def test_dataframe_replace() -> None:
25752575
df = pd.DataFrame({"col1": ["a", "ab", "ba"]})
25762576
pattern = re.compile(r"^a.*")
2577+
# global scalar replacement
25772578
check(assert_type(df.replace("a", "x"), pd.DataFrame), pd.DataFrame)
2579+
check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame)
2580+
check(assert_type(df.replace("a", "x", regex=True), pd.DataFrame), pd.DataFrame)
2581+
check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame)
2582+
check(assert_type(df.replace(regex="a", value="x"), pd.DataFrame), pd.DataFrame)
2583+
check(assert_type(df.replace(regex=pattern, value="x"), pd.DataFrame), pd.DataFrame)
2584+
# global sequence replacement
2585+
check(assert_type(df.replace(["a"], ["x"]), pd.DataFrame), pd.DataFrame)
2586+
check(assert_type(df.replace([pattern], ["x"]), pd.DataFrame), pd.DataFrame)
2587+
check(assert_type(df.replace(regex=["a"], value=["x"]), pd.DataFrame), pd.DataFrame)
2588+
check(
2589+
assert_type(df.replace(regex=[pattern], value=["x"]), pd.DataFrame),
2590+
pd.DataFrame,
2591+
)
2592+
# global mapping
25782593
check(assert_type(df.replace({"a": "x"}), pd.DataFrame), pd.DataFrame)
2594+
check(assert_type(df.replace({pattern: "x"}), pd.DataFrame), pd.DataFrame)
25792595
check(assert_type(df.replace(pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame)
2580-
check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame)
2596+
check(assert_type(df.replace(regex={"a": "x"}), pd.DataFrame), pd.DataFrame)
2597+
check(assert_type(df.replace(regex={pattern: "x"}), pd.DataFrame), pd.DataFrame)
2598+
check(
2599+
assert_type(df.replace(regex=pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame
2600+
)
2601+
# per-column scalar-scalar replacement
25812602
check(
25822603
assert_type(df.replace({"col1": "a"}, {"col1": "x"}), pd.DataFrame),
25832604
pd.DataFrame,
@@ -2586,26 +2607,83 @@ def test_dataframe_replace() -> None:
25862607
assert_type(df.replace({"col1": pattern}, {"col1": "x"}), pd.DataFrame),
25872608
pd.DataFrame,
25882609
)
2589-
check(assert_type(df.replace(["a"], ["x"]), pd.DataFrame), pd.DataFrame)
2590-
check(assert_type(df.replace([pattern], ["x"]), pd.DataFrame), pd.DataFrame)
2591-
check(assert_type(df.replace("^a.*", "x", regex=True), pd.DataFrame), pd.DataFrame)
2592-
check(assert_type(df.replace(value="x", regex="^a."), pd.DataFrame), pd.DataFrame)
2593-
check(assert_type(df.replace(value="x", regex=["^a."]), pd.DataFrame), pd.DataFrame)
25942610
check(
2595-
assert_type(df.replace(value="x", regex={"col1": "^a."}), pd.DataFrame),
2611+
assert_type(
2612+
df.replace(pd.Series({"col1": "a"}), pd.Series({"col1": "x"})), pd.DataFrame
2613+
),
25962614
pd.DataFrame,
25972615
)
2598-
check(assert_type(df.replace(value="x", regex=pattern), pd.DataFrame), pd.DataFrame)
25992616
check(
2600-
assert_type(df.replace(value="x", regex=[pattern]), pd.DataFrame), pd.DataFrame
2617+
assert_type(df.replace(regex={"col1": "a"}, value={"col1": "x"}), pd.DataFrame),
2618+
pd.DataFrame,
26012619
)
26022620
check(
2603-
assert_type(df.replace(value="x", regex={"col1": pattern}), pd.DataFrame),
2621+
assert_type(
2622+
df.replace(regex={"col1": pattern}, value={"col1": "x"}), pd.DataFrame
2623+
),
26042624
pd.DataFrame,
26052625
)
2606-
check(assert_type(df.replace(regex={"a": "x"}), pd.DataFrame), pd.DataFrame)
26072626
check(
2608-
assert_type(df.replace(regex=pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame
2627+
assert_type(
2628+
df.replace(regex=pd.Series({"col1": "a"}), value=pd.Series({"col1": "x"})),
2629+
pd.DataFrame,
2630+
),
2631+
pd.DataFrame,
2632+
)
2633+
# per-column sequence replacement
2634+
check(
2635+
assert_type(df.replace({"col1": ["a"]}, {"col1": ["x"]}), pd.DataFrame),
2636+
pd.DataFrame,
2637+
)
2638+
check(
2639+
assert_type(df.replace({"col1": [pattern]}, {"col1": ["x"]}), pd.DataFrame),
2640+
pd.DataFrame,
2641+
)
2642+
check(
2643+
assert_type(
2644+
df.replace(pd.Series({"col1": ["a"]}), pd.Series({"col1": ["x"]})),
2645+
pd.DataFrame,
2646+
),
2647+
pd.DataFrame,
2648+
)
2649+
check(
2650+
assert_type(
2651+
df.replace(regex={"col1": ["a"]}, value={"col1": ["x"]}), pd.DataFrame
2652+
),
2653+
pd.DataFrame,
2654+
)
2655+
check(
2656+
assert_type(
2657+
df.replace(regex={"col1": [pattern]}, value={"col1": ["x"]}), pd.DataFrame
2658+
),
2659+
pd.DataFrame,
2660+
)
2661+
check(
2662+
assert_type(
2663+
df.replace(
2664+
regex=pd.Series({"col1": ["a"]}), value=pd.Series({"col1": ["x"]})
2665+
),
2666+
pd.DataFrame,
2667+
),
2668+
pd.DataFrame,
2669+
)
2670+
# per-column mapping
2671+
check(assert_type(df.replace({"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame)
2672+
check(assert_type(df.replace({"col1": {pattern: "x"}}), pd.DataFrame), pd.DataFrame)
2673+
check(
2674+
assert_type(df.replace({"col1": pd.Series({"a": "x"})}), pd.DataFrame),
2675+
pd.DataFrame,
2676+
)
2677+
check(
2678+
assert_type(df.replace(regex={"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame
2679+
)
2680+
check(
2681+
assert_type(df.replace(regex={"col1": {pattern: "x"}}), pd.DataFrame),
2682+
pd.DataFrame,
2683+
)
2684+
check(
2685+
assert_type(df.replace(regex={"col1": pd.Series({"a": "x"})}), pd.DataFrame),
2686+
pd.DataFrame,
26092687
)
26102688

26112689

0 commit comments

Comments
 (0)