diff --git a/pandas/tests/frame/methods/test_replace.py b/pandas/tests/frame/methods/test_replace.py index b2320798ea9a2..e5bd8a9c45b55 100644 --- a/pandas/tests/frame/methods/test_replace.py +++ b/pandas/tests/frame/methods/test_replace.py @@ -334,7 +334,6 @@ def test_regex_replace_str_to_numeric(self, mix_abc): return_value = res3.replace(regex=r"\s*\.\s*", value=0, inplace=True) assert return_value is None expec = DataFrame({"a": mix_abc["a"], "b": ["a", "b", 0, 0], "c": mix_abc["c"]}) - # TODO(infer_string) expec["c"] = expec["c"].astype(object) tm.assert_frame_equal(res, expec) tm.assert_frame_equal(res2, expec) @@ -1469,21 +1468,24 @@ def test_regex_replace_scalar( tm.assert_frame_equal(result, expected) @pytest.mark.parametrize("regex", [False, True]) - def test_replace_regex_dtype_frame(self, regex): + @pytest.mark.parametrize("value", [1, "1"]) + def test_replace_regex_dtype_frame(self, regex, value): # GH-48644 df1 = DataFrame({"A": ["0"], "B": ["0"]}) - expected_df1 = DataFrame({"A": [1], "B": [1]}, dtype=object) - result_df1 = df1.replace(to_replace="0", value=1, regex=regex) + # When value is an integer, coerce result to object. + # When value is a string, infer the correct string dtype. + dtype = object if value == 1 else None + + expected_df1 = DataFrame({"A": [value], "B": [value]}, dtype=dtype) + result_df1 = df1.replace(to_replace="0", value=value, regex=regex) tm.assert_frame_equal(result_df1, expected_df1) df2 = DataFrame({"A": ["0"], "B": ["1"]}) if regex: - # TODO(infer_string): both string columns get cast to object, - # while only needed for column A - expected_df2 = DataFrame({"A": [1], "B": ["1"]}, dtype=object) + expected_df2 = DataFrame({"A": [value], "B": ["1"]}, dtype=dtype) else: - expected_df2 = DataFrame({"A": Series([1], dtype=object), "B": ["1"]}) - result_df2 = df2.replace(to_replace="0", value=1, regex=regex) + expected_df2 = DataFrame({"A": Series([value], dtype=dtype), "B": ["1"]}) + result_df2 = df2.replace(to_replace="0", value=value, regex=regex) tm.assert_frame_equal(result_df2, expected_df2) def test_replace_with_value_also_being_replaced(self):