Skip to content

Commit 7e0d4e0

Browse files
committed
handle list with nan
1 parent e557c05 commit 7e0d4e0

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

pandas/core/generic.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
from pandas._libs import lib
3030
from pandas._libs.lib import is_range_indexer
31-
from pandas._libs.missing import NA
3231
from pandas._libs.tslibs import (
3332
Period,
3433
Timestamp,
@@ -114,6 +113,7 @@
114113
is_bool_dtype,
115114
is_dict_like,
116115
is_extension_array_dtype,
116+
is_float_dtype,
117117
is_list_like,
118118
is_number,
119119
is_numeric_dtype,
@@ -9713,6 +9713,7 @@ def _where(
97139713
if axis is not None:
97149714
axis = self._get_axis_number(axis)
97159715

9716+
has_nan: bool = False
97169717
# align the cond to same shape as myself
97179718
cond = common.apply_if_callable(cond, self)
97189719
if isinstance(cond, NDFrame):
@@ -9728,9 +9729,13 @@ def _where(
97289729
else:
97299730
if not hasattr(cond, "shape"):
97309731
cond = np.asanyarray(cond)
9732+
if is_float_dtype(cond) and np.isnan(cond).any():
9733+
has_nan = True
97319734
if cond.shape != self.shape:
97329735
raise ValueError("Array conditional must be same shape as self")
97339736
cond = self._constructor(cond, **self._construct_axes_dict(), copy=False)
9737+
if has_nan:
9738+
cond = cond.replace({0.0: False, 1.0: True})
97349739
cond = cond.fillna(True)
97359740

97369741
# make sure we are boolean
@@ -10113,13 +10118,13 @@ def mask(
1011310118
# see gh-21891
1011410119
if not hasattr(cond, "__invert__"):
1011510120
cond = np.array(cond)
10121+
if is_float_dtype(cond) and np.isnan(cond).any():
10122+
cond = cond.astype(bool)
1011610123

1011710124
if isinstance(cond, np.ndarray):
10118-
if all(x is NA or lib.is_bool(x) or x is np.nan for x in cond.flatten()):
10119-
if not cond.flags.writeable:
10120-
cond.setflags(write=True)
10121-
cond[isna(cond)] = False
10122-
cond = cond.astype(bool)
10125+
if not cond.flags.writeable:
10126+
cond.setflags(write=True)
10127+
cond[isna(cond)] = False
1012310128

1012410129
return self._where(
1012510130
~cond,

0 commit comments

Comments
 (0)