Skip to content

Commit 515286e

Browse files
authored
Improve handling of edge cases with nullable dtypes (#3394)
* Improve handling of numeric type edge cases
* Ensure numeric data after scaling in Plot
* Remove some back-compat flexibility around pd.NA
1 parent 9276e22 commit 515286e

File tree

5 files changed

+37
-22
lines changed

5 files changed

+37
-22
lines changed

seaborn/_core/plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,11 +1392,11 @@ def _setup_scales(
13921392
spec_error = PlotSpecError._during("Scaling operation", var)
13931393
raise spec_error from err
13941394

1395-
# Now the transformed data series are complete, set update the layer data
1395+
# Now the transformed data series are complete, update the layer data
13961396
for layer, new_series in zip(layers, transformed_data):
13971397
layer_df = layer["data"].frame
13981398
if var in layer_df:
1399-
layer_df[var] = new_series
1399+
layer_df[var] = pd.to_numeric(new_series)
14001400

14011401
def _plot_layer(self, p: Plot, layer: Layer) -> None:
14021402

seaborn/_core/rules.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def variable_type(
7474
if pd.isna(vector).all():
7575
return VarType("numeric")
7676

77+
# Now drop nulls to simplify further type inference
78+
vector = vector.dropna()
79+
7780
# Special-case binary/boolean data, allow caller to determine
7881
# This triggers a numpy warning when vector has strings/objects
7982
# https://github.com/numpy/numpy/issues/6784
@@ -94,7 +97,7 @@ def variable_type(
9497
boolean_dtypes = ["bool"]
9598
boolean_vector = vector.dtype in boolean_dtypes
9699
else:
97-
boolean_vector = bool(np.isin(vector.dropna(), [0, 1]).all())
100+
boolean_vector = bool(np.isin(vector, [0, 1]).all())
98101
if boolean_vector:
99102
return VarType(boolean_type)
100103

seaborn/_oldcore.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ def comp_data(self):
11281128
# it is similar to GH2419, but more complicated because
11291129
# supporting `order` in categorical plots is tricky
11301130
orig = orig[orig.isin(self.var_levels[var])]
1131-
comp = pd.to_numeric(converter.convert_units(orig))
1131+
comp = pd.to_numeric(converter.convert_units(orig)).astype(float)
11321132
if converter.get_scale() == "log":
11331133
comp = np.log10(comp)
11341134
parts.append(pd.Series(comp, orig.index, name=orig.name))
@@ -1505,6 +1505,9 @@ def variable_type(vector, boolean_type="numeric"):
15051505
if pd.isna(vector).all():
15061506
return VariableType("numeric")
15071507

1508+
# At this point, drop nans to simplify further type inference
1509+
vector = vector.dropna()
1510+
15081511
# Special-case binary/boolean data, allow caller to determine
15091512
# This triggers a numpy warning when vector has strings/objects
15101513
# https://github.com/numpy/numpy/issues/6784
@@ -1517,7 +1520,7 @@ def variable_type(vector, boolean_type="numeric"):
15171520
warnings.simplefilter(
15181521
action='ignore', category=(FutureWarning, DeprecationWarning)
15191522
)
1520-
if np.isin(vector.dropna(), [0, 1]).all():
1523+
if np.isin(vector, [0, 1]).all():
15211524
return VariableType(boolean_type)
15221525

15231526
# Defer to positive pandas tests

tests/_core/test_rules.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ def test_variable_type():
3838
s = pd.Series([pd.NA, pd.NA])
3939
assert variable_type(s) == "numeric"
4040

41+
s = pd.Series([1, 2, pd.NA], dtype="Int64")
42+
assert variable_type(s) == "numeric"
43+
44+
s = pd.Series([1, 2, pd.NA], dtype=object)
45+
assert variable_type(s) == "numeric"
46+
4147
s = pd.Series(["1", "2", "3"])
4248
assert variable_type(s) == "categorical"
4349

tests/test_core.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,9 @@
2323
categorical_order,
2424
)
2525
from seaborn.utils import desaturate
26-
2726
from seaborn.palettes import color_palette
2827

2928

30-
try:
31-
from pandas import NA as PD_NA
32-
except ImportError:
33-
PD_NA = None
34-
35-
3629
@pytest.fixture(params=[
3730
dict(x="x", y="y"),
3831
dict(x="t", y="y"),
@@ -1302,28 +1295,23 @@ def test_comp_data_category_order(self):
13021295

13031296
@pytest.fixture(
13041297
params=itertools.product(
1305-
[None, np.nan, PD_NA],
1306-
["numeric", "category", "datetime"]
1298+
[None, np.nan, pd.NA],
1299+
["numeric", "category", "datetime"],
13071300
)
13081301
)
1309-
@pytest.mark.parametrize(
1310-
"NA,var_type",
1311-
)
1302+
@pytest.mark.parametrize("NA,var_type")
13121303
def comp_data_missing_fixture(self, request):
13131304

13141305
# This fixture holds the logic for parameterizing
13151306
# the following test (test_comp_data_missing)
13161307

13171308
NA, var_type = request.param
13181309

1319-
if NA is None:
1320-
pytest.skip("No pandas.NA available")
1321-
13221310
comp_data = [0, 1, np.nan, 2, np.nan, 1]
13231311
if var_type == "numeric":
13241312
orig_data = [0, 1, NA, 2, np.inf, 1]
13251313
elif var_type == "category":
1326-
orig_data = ["a", "b", NA, "c", NA, "b"]
1314+
orig_data = ["a", "b", NA, "c", pd.NA, "b"]
13271315
elif var_type == "datetime":
13281316
# Use 1-based numbers to avoid issue on matplotlib<3.2
13291317
# Could simplify the test a bit when we roll off that version
@@ -1343,6 +1331,7 @@ def test_comp_data_missing(self, comp_data_missing_fixture):
13431331
ax = plt.figure().subplots()
13441332
p._attach(ax)
13451333
assert_array_equal(p.comp_data["x"], comp_data)
1334+
assert p.comp_data["x"].dtype == "float"
13461335

13471336
def test_comp_data_duplicate_index(self):
13481337

@@ -1352,6 +1341,15 @@ def test_comp_data_duplicate_index(self):
13521341
p._attach(ax)
13531342
assert_array_equal(p.comp_data["x"], x)
13541343

1344+
def test_comp_data_nullable_dtype(self):
1345+
1346+
x = pd.Series([1, 2, 3, 4], dtype="Int64")
1347+
p = VectorPlotter(variables={"x": x})
1348+
ax = plt.figure().subplots()
1349+
p._attach(ax)
1350+
assert_array_equal(p.comp_data["x"], x)
1351+
assert p.comp_data["x"].dtype == "float"
1352+
13551353
def test_var_order(self, long_df):
13561354

13571355
order = ["c", "b", "a"]
@@ -1456,7 +1454,12 @@ def test_variable_type(self):
14561454
assert variable_type(s) == "numeric"
14571455

14581456
s = pd.Series([np.nan, np.nan])
1459-
# s = pd.Series([pd.NA, pd.NA])
1457+
assert variable_type(s) == "numeric"
1458+
1459+
s = pd.Series([pd.NA, pd.NA])
1460+
assert variable_type(s) == "numeric"
1461+
1462+
s = pd.Series([1, 2, pd.NA], dtype="Int64")
14601463
assert variable_type(s) == "numeric"
14611464

14621465
s = pd.Series(["1", "2", "3"])

0 commit comments

Comments
 (0)