diff --git a/seaborn/_compat.py b/seaborn/_compat.py index bd2f0c12d3..29e2348635 100644 --- a/seaborn/_compat.py +++ b/seaborn/_compat.py @@ -56,6 +56,18 @@ def __call__(self, value, clip=None): return new_norm +def is_registered_colormap(name): + """Handle changes to matplotlib colormap interface in 3.5.""" + if _version_predates(mpl, "3.5"): + try: + mpl.cm.get_cmap(name) + return True + except ValueError: + return False + else: + return name in mpl.colormaps + + def get_colormap(name): """Handle changes to matplotlib colormap interface in 3.6.""" try: diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 4e2df91b49..61892db6bb 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -14,6 +14,7 @@ from seaborn._core.rules import categorical_order, variable_type from seaborn.palettes import QUAL_PALETTES, color_palette, blend_palette from seaborn.utils import get_color_cycle +from seaborn._compat import is_registered_colormap from typing import Any, Callable, Tuple, List, Union, Optional @@ -43,6 +44,7 @@ class Property: """Base class for visual properties that can be set directly or be data scaling.""" + _TRANS_ARGS = ["log", "symlog", "logit", "pow", "sqrt"] # When True, scales for this property will populate the legend by default legend = False @@ -76,9 +78,8 @@ def infer_scale(self, arg: Any, data: Series) -> Scale: # (e.g. color). How best to handle that? One option is to call super after # handling property-specific possibilities (e.g. for color check that the # arg is not a valid palette name) but that could get tricky. - trans_args = ["log", "symlog", "logit", "pow", "sqrt"] if isinstance(arg, str): - if any(arg.startswith(k) for k in trans_args): + if any(arg.startswith(k) for k in self._TRANS_ARGS): # TODO validate numeric type? That should happen centrally somewhere return Continuous(trans=arg) else: @@ -183,6 +184,8 @@ def infer_scale(self, arg: Any, data: Series) -> Scale: return Nominal(arg) elif var_type == "datetime": return Temporal(arg) + elif isinstance(arg, str) and any(arg.startswith(k) for k in self._TRANS_ARGS): + return Continuous(trans=arg) # TODO other variable types else: return Continuous(arg) @@ -607,8 +610,6 @@ def infer_scale(self, arg: Any, data: Series) -> Scale: if callable(arg): return Continuous(arg) - # TODO Do we accept str like "log", "pow", etc. for semantics? - if not isinstance(arg, str): msg = " ".join([ f"A single scale argument for {self.variable} variables must be", @@ -619,7 +620,14 @@ def infer_scale(self, arg: Any, data: Series) -> Scale: if arg in QUAL_PALETTES: return Nominal(arg) elif var_type == "numeric": - return Continuous(arg) + # Prioritize actual colormaps, e.g. if a colormap named "pow" exists + if is_registered_colormap(arg): + return Continuous(arg) + elif any(arg.startswith(k) for k in self._TRANS_ARGS): + return Continuous(trans=arg) + else: + return Continuous(arg) + # TODO implement scales for date variables and any others. else: return Nominal(arg) diff --git a/tests/_core/test_properties.py b/tests/_core/test_properties.py index c87dd918d0..ed445f7bf6 100644 --- a/tests/_core/test_properties.py +++ b/tests/_core/test_properties.py @@ -238,6 +238,17 @@ def test_inference(self, values, data_type, scale_class, vectors): assert isinstance(scale, scale_class) assert scale.values == values + @pytest.mark.parametrize( + "trans", + ["pow", "sqrt", "log", "symlog", "logit", "log2", "symlog100"] + ) + def test_inference_magic_args(self, trans, num_vector): + + scale = Color().infer_scale(trans, num_vector) + assert isinstance(scale, Continuous) + assert scale.trans == trans + assert scale.values is None + def test_standardization(self): f = Color().standardize @@ -531,6 +542,17 @@ def test_mapped_interval_categorical(self, cat_vector): n = cat_vector.nunique() assert_array_equal(mapping([n - 1, 0]), self.prop().default_range) + @pytest.mark.parametrize( + "trans", + ["pow", "sqrt", "log", "symlog", "log13", "logit", "symlog37"] + ) + def test_inference_magic_args(self, trans, num_vector): + + scale = self.prop().infer_scale(trans, num_vector) + assert isinstance(scale, Continuous) + assert scale.trans == trans + assert scale.values is None + def test_bad_scale_values_numeric_data(self, num_vector): prop_name = self.prop.__name__.lower()