Skip to content

Commit 79cc01d

Browse files
jbrockmendel, mroeschke, pre-commit-ci[bot]
authored
API: improve dtype in df.where with EA other (#62038)
Co-authored-by: Matthew Roeschke <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2ad2abd commit 79cc01d

File tree

3 files changed

+51
-15
lines changed

3 files changed

+51
-15
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,14 @@ Other enhancements
9090
- Added support to read and write from and to Apache Iceberg tables with the new :func:`read_iceberg` and :meth:`DataFrame.to_iceberg` functions (:issue:`61383`)
9191
- Errors occurring during SQL I/O will now throw a generic :class:`.DatabaseError` instead of the raw Exception type from the underlying driver manager library (:issue:`60748`)
9292
- Implemented :meth:`Series.str.isascii` and :meth:`Index.str.isascii` (:issue:`59091`)
93+
- Improve the resulting dtypes in :meth:`DataFrame.where` and :meth:`DataFrame.mask` with :class:`ExtensionDtype` ``other`` (:issue:`62038`)
9394
- Improved deprecation message for offset aliases (:issue:`60820`)
9495
- Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`)
9596
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
9697
- Support passing an :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`)
9798
- Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
9899
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
100+
-
99101

100102
.. ---------------------------------------------------------------------------
101103
.. _whatsnew_300.notable_bug_fixes:

pandas/core/generic.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9788,14 +9788,40 @@ def _where(
97889788
raise InvalidIndexError
97899789

97909790
if other.ndim < self.ndim:
9791-
# TODO(EA2D): avoid object-dtype cast in EA case GH#38729
97929791
other = other._values
9793-
if axis == 0:
9794-
other = np.reshape(other, (-1, 1))
9795-
elif axis == 1:
9796-
other = np.reshape(other, (1, -1))
9797-
9798-
other = np.broadcast_to(other, self.shape)
9792+
if isinstance(other, np.ndarray):
9793+
# TODO(EA2D): could also do this for NDArrayBackedEA cases?
9794+
if axis == 0:
9795+
other = np.reshape(other, (-1, 1))
9796+
elif axis == 1:
9797+
other = np.reshape(other, (1, -1))
9798+
9799+
other = np.broadcast_to(other, self.shape)
9800+
else:
9801+
# GH#38729, GH#62038 avoid lossy casting or object-casting
9802+
if axis == 0:
9803+
res_cols = [
9804+
self.iloc[:, i]._where(
9805+
cond.iloc[:, i],
9806+
other,
9807+
)
9808+
for i in range(self.shape[1])
9809+
]
9810+
elif axis == 1:
9811+
# TODO: can we use a zero-copy alternative to "repeat"?
9812+
res_cols = [
9813+
self.iloc[:, i]._where(
9814+
cond.iloc[:, i],
9815+
other[i : i + 1].repeat(len(self)),
9816+
)
9817+
for i in range(self.shape[1])
9818+
]
9819+
res = self._constructor(dict(enumerate(res_cols)))
9820+
res.index = self.index
9821+
res.columns = self.columns
9822+
if inplace:
9823+
return self._update_inplace(res)
9824+
return res.__finalize__(self)
97999825

98009826
# slice me out of the other
98019827
else:

pandas/tests/frame/indexing/test_where.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -698,22 +698,30 @@ def test_where_categorical_filtering(self):
698698
tm.assert_equal(result, expected)
699699

700700
def test_where_ea_other(self):
701-
# GH#38729/GH#38742
701+
# GH#38729/GH#38742, GH#62038
702702
df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
703703
arr = pd.array([7, pd.NA, 9])
704704
ser = Series(arr)
705705
mask = np.ones(df.shape, dtype=bool)
706706
mask[1, :] = False
707707

708-
# TODO: ideally we would get Int64 instead of object
709-
result = df.where(mask, ser, axis=0)
710-
expected = DataFrame({"A": [1, np.nan, 3], "B": [4, np.nan, 6]})
711-
tm.assert_frame_equal(result, expected)
708+
result1 = df.where(mask, ser, axis=0)
709+
expected1 = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}, dtype="Int64")
710+
tm.assert_frame_equal(result1, expected1)
712711

713712
ser2 = Series(arr[:2], index=["A", "B"])
714-
expected = DataFrame({"A": [1, 7, 3], "B": [4, np.nan, 6]})
715-
result = df.where(mask, ser2, axis=1)
716-
tm.assert_frame_equal(result, expected)
713+
expected2 = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]})
714+
expected2["B"] = expected2["B"].astype("Int64")
715+
result2 = df.where(mask, ser2, axis=1)
716+
tm.assert_frame_equal(result2, expected2)
717+
718+
result3 = df.copy()
719+
result3.mask(mask, ser, axis=0, inplace=True)
720+
tm.assert_frame_equal(result3, expected1)
721+
722+
result4 = df.copy()
723+
result4.mask(mask, ser2, axis=1, inplace=True)
724+
tm.assert_frame_equal(result4, expected2)
717725

718726
def test_where_interval_noop(self):
719727
# GH#44181

0 commit comments

Comments
 (0)