diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 3191c077d3c36..58aff1b38238c 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -89,12 +89,14 @@ Other enhancements - Added support to read and write from and to Apache Iceberg tables with the new :func:`read_iceberg` and :meth:`DataFrame.to_iceberg` functions (:issue:`61383`) - Errors occurring during SQL I/O will now throw a generic :class:`.DatabaseError` instead of the raw Exception type from the underlying driver manager library (:issue:`60748`) - Implemented :meth:`Series.str.isascii` and :meth:`Series.str.isascii` (:issue:`59091`) +- Improve the resulting dtypes in :meth:`DataFrame.where` and :meth:`DataFrame.mask` with :class:`ExtensionDtype` ``other`` (:issue:`62038`) - Improved deprecation message for offset aliases (:issue:`60820`) - Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`) - Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`) - Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`) - Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) +- .. --------------------------------------------------------------------------- .. 
_whatsnew_300.notable_bug_fixes:
diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index cbd853886a0f4..7c407b03965df 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -9788,14 +9788,45 @@ def _where(
                     raise InvalidIndexError
 
                 if other.ndim < self.ndim:
-                    # TODO(EA2D): avoid object-dtype cast in EA case GH#38729
                     other = other._values
-                    if axis == 0:
-                        other = np.reshape(other, (-1, 1))
-                    elif axis == 1:
-                        other = np.reshape(other, (1, -1))
-
-                    other = np.broadcast_to(other, self.shape)
+                    if isinstance(other, np.ndarray):
+                        # TODO(EA2D): could also do this for NDArrayBackedEA cases?
+                        if axis == 0:
+                            other = np.reshape(other, (-1, 1))
+                        elif axis == 1:
+                            other = np.reshape(other, (1, -1))
+
+                        other = np.broadcast_to(other, self.shape)
+                    else:
+                        # GH#38729, GH#62038 avoid lossy casting or object-casting
+                        # NOTE(review): the column-wise Series._where calls below
+                        #  are never inplace, so they need where-semantics cond.
+                        #  Assumes _where negated ``cond`` earlier for the inplace
+                        #  (putmask) path; confirm against the code above this hunk.
+                        col_cond = ~cond if inplace else cond
+                        if axis == 0:
+                            res_cols = [
+                                self.iloc[:, i]._where(
+                                    col_cond.iloc[:, i],
+                                    other,
+                                )
+                                for i in range(self.shape[1])
+                            ]
+                        elif axis == 1:
+                            # TODO: can we use a zero-copy alternative to "repeat"?
+                            res_cols = [
+                                self.iloc[:, i]._where(
+                                    col_cond.iloc[:, i],
+                                    other[i : i + 1].repeat(len(self)),
+                                )
+                                for i in range(self.shape[1])
+                            ]
+                        res = self._constructor(dict(enumerate(res_cols)))
+                        res.index = self.index
+                        res.columns = self.columns
+                        if inplace:
+                            return self._update_inplace(res)
+                        return res.__finalize__(self)
 
             # slice me out of the other
             else:
diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py
index d6570fcda2ee8..cb95014d92809 100644
--- a/pandas/tests/frame/indexing/test_where.py
+++ b/pandas/tests/frame/indexing/test_where.py
@@ -698,22 +698,31 @@ def test_where_categorical_filtering(self):
         tm.assert_equal(result, expected)
 
     def test_where_ea_other(self):
-        # GH#38729/GH#38742
+        # GH#38729/GH#38742, GH#62038
         df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
         arr = pd.array([7, pd.NA, 9])
         ser = Series(arr)
         mask = np.ones(df.shape, dtype=bool)
         mask[1, :] = False
 
-        # TODO: ideally we would get Int64 instead of object
-        result = df.where(mask, ser, axis=0)
-        expected = DataFrame({"A": [1, np.nan, 3], "B": [4, np.nan, 6]})
-        tm.assert_frame_equal(result, expected)
+        result1 = df.where(mask, ser, axis=0)
+        expected1 = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}, dtype="Int64")
+        tm.assert_frame_equal(result1, expected1)
 
         ser2 = Series(arr[:2], index=["A", "B"])
-        expected = DataFrame({"A": [1, 7, 3], "B": [4, np.nan, 6]})
-        result = df.where(mask, ser2, axis=1)
-        tm.assert_frame_equal(result, expected)
+        expected2 = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]})
+        expected2["B"] = expected2["B"].astype("Int64")
+        result2 = df.where(mask, ser2, axis=1)
+        tm.assert_frame_equal(result2, expected2)
+
+        # mask(~cond) must agree with where(cond), inplace included
+        result3 = df.copy()
+        result3.mask(~mask, ser, axis=0, inplace=True)
+        tm.assert_frame_equal(result3, expected1)
+
+        result4 = df.copy()
+        result4.mask(~mask, ser2, axis=1, inplace=True)
+        tm.assert_frame_equal(result4, expected2)
 
     def test_where_interval_noop(self):
         # GH#44181