Skip to content

Commit 1d0f020

Browse files
committed
API: improve dtype in df.where with EA other
1 parent 4eef5f6 commit 1d0f020

File tree

3 files changed

+51
-15
lines changed

3 files changed

+51
-15
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ Other enhancements
9595
- Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`)
9696
- Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
9797
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
98-
98+
- Improve the resulting dtypes in :meth:`DataFrame.where` and :meth:`DataFrame.mask` with :class:`ExtensionDtype` ``other`` (:issue:`38729`)
9999
.. ---------------------------------------------------------------------------
100100
.. _whatsnew_300.notable_bug_fixes:
101101

pandas/core/generic.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9788,14 +9788,42 @@ def _where(
97889788
raise InvalidIndexError
97899789

97909790
if other.ndim < self.ndim:
9791-
# TODO(EA2D): avoid object-dtype cast in EA case GH#38729
97929791
other = other._values
9793-
if axis == 0:
9794-
other = np.reshape(other, (-1, 1))
9795-
elif axis == 1:
9796-
other = np.reshape(other, (1, -1))
9797-
9798-
other = np.broadcast_to(other, self.shape)
9792+
if isinstance(other, np.ndarray):
9793+
# TODO(EA2D): could also do this for NDArrayBackedEA cases?
9794+
if axis == 0:
9795+
other = np.reshape(other, (-1, 1))
9796+
elif axis == 1:
9797+
other = np.reshape(other, (1, -1))
9798+
9799+
other = np.broadcast_to(other, self.shape)
9800+
else:
9801+
# GH#38729 avoid lossy casting or object-casting
9802+
if axis == 0:
9803+
res_cols = [
9804+
self.iloc[:, i]._where(
9805+
cond.iloc[:, i],
9806+
other,
9807+
)
9808+
for i in range(self.shape[1])
9809+
]
9810+
elif axis == 1:
9811+
# TODO: can we use a zero-copy alternative to "repeat"?
9812+
res_cols = [
9813+
self.iloc[:, i]._where(
9814+
cond.iloc[:, i],
9815+
other[i : i + 1].repeat(len(self)),
9816+
)
9817+
for i in range(self.shape[1])
9818+
]
9819+
res = self._constructor(
9820+
{i: res_cols[i] for i in range(len(res_cols))}
9821+
)
9822+
res.index = self.index
9823+
res.columns = self.columns
9824+
if inplace:
9825+
return self._update_inplace(res)
9826+
return res.__finalize__(self)
97999827

98009828
# slice me out of the other
98019829
else:

pandas/tests/frame/indexing/test_where.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -705,15 +705,23 @@ def test_where_ea_other(self):
705705
mask = np.ones(df.shape, dtype=bool)
706706
mask[1, :] = False
707707

708-
# TODO: ideally we would get Int64 instead of object
709-
result = df.where(mask, ser, axis=0)
710-
expected = DataFrame({"A": [1, np.nan, 3], "B": [4, np.nan, 6]})
711-
tm.assert_frame_equal(result, expected)
708+
result1 = df.where(mask, ser, axis=0)
709+
expected1 = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}, dtype="Int64")
710+
tm.assert_frame_equal(result1, expected1)
712711

713712
ser2 = Series(arr[:2], index=["A", "B"])
714-
expected = DataFrame({"A": [1, 7, 3], "B": [4, np.nan, 6]})
715-
result = df.where(mask, ser2, axis=1)
716-
tm.assert_frame_equal(result, expected)
713+
expected2 = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]})
714+
expected2["B"] = expected2["B"].astype("Int64")
715+
result2 = df.where(mask, ser2, axis=1)
716+
tm.assert_frame_equal(result2, expected2)
717+
718+
result3 = df.copy()
719+
result3.mask(mask, ser, axis=0, inplace=True)
720+
tm.assert_frame_equal(result3, expected1)
721+
722+
result4 = df.copy()
723+
result4.mask(mask, ser2, axis=1, inplace=True)
724+
tm.assert_frame_equal(result4, expected2)
717725

718726
def test_where_interval_noop(self):
719727
# GH#44181

0 commit comments

Comments
 (0)