Skip to content

Commit b6737d5

Browse files
authored
Avoid error when determining type of variable containing pd.NA (#3391)
* Avoid error when determining type of variable containing pd.NA * Remove vestigial tests for non-Series inputs to variable_type
1 parent 824c102 commit b6737d5

File tree

3 files changed

+4
-9
lines changed

3 files changed

+4
-9
lines changed

seaborn/_core/rules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def variable_type(
9494
boolean_dtypes = ["bool"]
9595
boolean_vector = vector.dtype in boolean_dtypes
9696
else:
97-
boolean_vector = bool(np.isin(vector, [0, 1, np.nan]).all())
97+
boolean_vector = bool(np.isin(vector.dropna(), [0, 1]).all())
9898
if boolean_vector:
9999
return VarType(boolean_type)
100100

seaborn/_oldcore.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,9 +1495,10 @@ def variable_type(vector, boolean_type="numeric"):
14951495
var_type : 'numeric', 'categorical', or 'datetime'
14961496
Name identifying the type of data in the vector.
14971497
"""
1498+
vector = pd.Series(vector)
14981499

14991500
# If a categorical dtype is set, infer categorical
1500-
if isinstance(getattr(vector, 'dtype', None), pd.CategoricalDtype):
1501+
if isinstance(vector.dtype, pd.CategoricalDtype):
15011502
return VariableType("categorical")
15021503

15031504
# Special-case all-na data, which is always "numeric"
@@ -1516,7 +1517,7 @@ def variable_type(vector, boolean_type="numeric"):
15161517
warnings.simplefilter(
15171518
action='ignore', category=(FutureWarning, DeprecationWarning)
15181519
)
1519-
if np.isin(vector, [0, 1, np.nan]).all():
1520+
if np.isin(vector.dropna(), [0, 1]).all():
15201521
return VariableType(boolean_type)
15211522

15221523
# Defer to positive pandas tests

tests/_core/test_rules.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ def test_variable_type():
2828
assert variable_type(s) == "numeric"
2929
assert variable_type(s.astype(int)) == "numeric"
3030
assert variable_type(s.astype(object)) == "numeric"
31-
assert variable_type(s.to_numpy()) == "numeric"
32-
assert variable_type(s.to_list()) == "numeric"
3331

3432
s = pd.Series([1, 2, 3, np.nan], dtype=object)
3533
assert variable_type(s) == "numeric"
@@ -42,8 +40,6 @@ def test_variable_type():
4240

4341
s = pd.Series(["1", "2", "3"])
4442
assert variable_type(s) == "categorical"
45-
assert variable_type(s.to_numpy()) == "categorical"
46-
assert variable_type(s.to_list()) == "categorical"
4743

4844
s = pd.Series([True, False, False])
4945
assert variable_type(s) == "numeric"
@@ -62,8 +58,6 @@ def test_variable_type():
6258
s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)])
6359
assert variable_type(s) == "datetime"
6460
assert variable_type(s.astype(object)) == "datetime"
65-
assert variable_type(s.to_numpy()) == "datetime"
66-
assert variable_type(s.to_list()) == "datetime"
6761

6862

6963
def test_categorical_order():

0 commit comments

Comments
 (0)