diff --git a/tests/baseline_images/test_draw/test_manual_legend.png b/tests/baseline_images/test_draw/test_manual_legend.png index e64d63a9e..de60242b5 100644 Binary files a/tests/baseline_images/test_draw/test_manual_legend.png and b/tests/baseline_images/test_draw/test_manual_legend.png differ diff --git a/tests/baseline_images/test_draw/test_manual_legend_styles.png b/tests/baseline_images/test_draw/test_manual_legend_styles.png new file mode 100644 index 000000000..4b7d49c4f Binary files /dev/null and b/tests/baseline_images/test_draw/test_manual_legend_styles.png differ diff --git a/tests/baseline_images/test_gridsearch/test_base/test_gridsearchcolorplot.png b/tests/baseline_images/test_gridsearch/test_base/test_gridsearchcolorplot.png new file mode 100644 index 000000000..9c948574a Binary files /dev/null and b/tests/baseline_images/test_gridsearch/test_base/test_gridsearchcolorplot.png differ diff --git a/tests/baseline_images/test_gridsearch/test_base/test_numpy_integration.png b/tests/baseline_images/test_gridsearch/test_base/test_numpy_integration.png new file mode 100644 index 000000000..6322640bf Binary files /dev/null and b/tests/baseline_images/test_gridsearch/test_base/test_numpy_integration.png differ diff --git a/tests/baseline_images/test_gridsearch/test_base/test_pandas_integration.png b/tests/baseline_images/test_gridsearch/test_base/test_pandas_integration.png new file mode 100644 index 000000000..6322640bf Binary files /dev/null and b/tests/baseline_images/test_gridsearch/test_base/test_pandas_integration.png differ diff --git a/tests/baseline_images/test_gridsearch/test_base/test_quick_method.png b/tests/baseline_images/test_gridsearch/test_base/test_quick_method.png new file mode 100644 index 000000000..f5b333634 Binary files /dev/null and b/tests/baseline_images/test_gridsearch/test_base/test_quick_method.png differ diff --git a/tests/test_draw.py b/tests/test_draw.py index 6a2a9dd70..e9372bbec 100644 --- a/tests/test_draw.py +++ 
b/tests/test_draw.py @@ -31,9 +31,31 @@ def test_manual_legend_uneven_colors(): """ Raise exception when colors and labels are mismatched in manual_legend """ - with pytest.raises(YellowbrickValueError, match="same number of colors as labels"): + with pytest.raises(YellowbrickValueError, + match="list of length equal to the number of labels"): manual_legend(None, ("a", "b", "c"), ("r", "g")) +def test_manual_legend_styles_malformed_input(): + """ + Raise exception when styles and/or colors are not lists of same length + as labels + """ + + # styles should be a list of strings + with pytest.raises(YellowbrickValueError, + match="Please specify the styles parameter as a list of strings"): + manual_legend(None, ("a", "b", "c"), styles="ro") + + # styles should be a list of same len() as labels + with pytest.raises(YellowbrickValueError, + match="list of length equal to the number of labels"): + manual_legend(None, ("a", "b", "c"), styles=("ro", "--")) + + # if colors is passed in alongside styles, it should be of same length + with pytest.raises(YellowbrickValueError, + match="list of length equal to the number of labels"): + manual_legend(None, ("a", "b", "c"), ("r", "g"), styles=("ro", "b--", "--")) + @pytest.fixture(scope="class") def data(request): @@ -83,7 +105,45 @@ def test_manual_legend(self): ) # Assert image similarity - self.assert_images_similar(ax=ax, tol=0.5) + self.assert_images_similar(ax=ax, tol=0.5, remove_legend=False) + + def test_manual_legend_styles(self): + """ + Check that the styles argument to manual_legend is correctly + processed, including its being overridden by the colors argument + """ + + # Draw a random scatter plot + random = np.random.RandomState(42) + + Ax, Ay = random.normal(50, 2, 100), random.normal(50, 3, 100) + Bx, By = random.normal(42, 3, 100), random.normal(44, 1, 100) + Cx, Cy = random.normal(20, 10, 100), random.normal(30, 1, 100) + Dx, Dy = random.normal(33, 5, 100), random.normal(22, 2, 100) + + _, ax = plt.subplots() + 
ax.scatter(Ax, Ay, c="r", alpha=0.35, label="a") + ax.scatter(Bx, By, c="g", alpha=0.35, label="b") + ax.scatter(Cx, Cy, c="b", alpha=0.35, label="c") + ax.scatter(Dx, Dy, c="y", alpha=0.35, label="d") + + # Four style/color combinations are tested here: + # (1) "blue" color should override the "r" of "ro" style + # (2) blank color should, of course, be overridden by the "g" of "-g" + # (3) None color should also be overridden by the third style, but + # since a color is not specified there either, the entry should + # default to black. + # (4) Linestyle, marker, and color are all unspecified. The entry should + # default to a solid black line. + styles = ["ro", "-g", "--", ""] + labels = ("a", "b", "c", "d") + colors = ("blue", "", None, None) + manual_legend( + ax, labels, colors, styles=styles, frameon=True, loc="upper left" + ) + + # Assert image similarity + self.assert_images_similar(ax=ax, tol=0.5, remove_legend=False) def test_vertical_bar_stack(self): """ diff --git a/tests/test_gridsearch/test_base.py b/tests/test_gridsearch/test_base.py new file mode 100644 index 000000000..f1e49c06b --- /dev/null +++ b/tests/test_gridsearch/test_base.py @@ -0,0 +1,133 @@ +# # tests.test_gridsearch.test_base.py +# # Test the GridSearchColorPlot (standard and quick visualizers). +# # +# # Author: Tan Tran +# # Created: Sat Aug 29 12:00:00 2020 -0400 +# # +# # Copyright (C) 2020 The scikit-yb developers +# # For license information, see LICENSE.txt +# # + +""" +Test the GridSearchColorPlot visualizer. 
+""" + +# ########################################################################## +# ## Imports +# ########################################################################## + +import pytest + +from tests.base import VisualTestCase +from tests.fixtures import Dataset, Split + +from yellowbrick.datasets import load_occupancy +from yellowbrick.gridsearch import GridSearchColorPlot, gridsearch_color_plot + +from sklearn.datasets import make_classification +from sklearn.svm import SVC +from sklearn.model_selection import GridSearchCV + +import pandas as pd + +# ########################################################################## +# ## Test fixtures +# ########################################################################## + +@pytest.fixture(scope="class") +def binary(request): + """ + Creates a random binary classification dataset fixture + """ + X, y = make_classification( + n_samples=1000, + n_features=4, + n_informative=2, + n_redundant=2, + n_classes=2, + n_clusters_per_class=2, + random_state=1234, + ) + + request.cls.binary = Dataset(X, y) + +@pytest.fixture(scope="class") +def gridsearchcv(request): + """ + Creates an sklearn SVC, a GridSearchCV for testing through the SVC's kernel, + gamma, and C parameters, and returns the GridSearchCV. 
+ """ + + svc = SVC() + grid = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [0.01, 0.1, 1, 10]}, + {'kernel': ['linear'], 'C': [0.01, 0.1, 1, 10]}] + gridsearchcv = GridSearchCV(svc, grid, n_jobs=4) + + request.cls.gridsearchcv = gridsearchcv + +@pytest.mark.usefixtures("binary", "gridsearchcv") +class TestGridSearchColorPlot(VisualTestCase): + """ + Tests of basic GridSearchColorPlot functionality + """ + + # ########################################################################## + # ## GridSearchColorPlot Base Test Cases + # ########################################################################## + + def test_gridsearchcolorplot(self): + """ + Test GridSearchColorPlot drawing + """ + + gs_viz = GridSearchColorPlot(self.gridsearchcv, 'C', 'kernel') + gs_viz.fit(self.binary.X, self.binary.y) + self.assert_images_similar(gs_viz) + + def test_quick_method(self): + """ + Test gridsearch_color_plot quick method + """ + + gs = self.gridsearchcv + + # If no X data is passed to quick method, model is assumed to be fit + # already + gs.fit(self.binary.X, self.binary.y) + + gs_viz = gridsearch_color_plot(gs, 'gamma', 'C') + assert isinstance(gs_viz, GridSearchColorPlot) + self.assert_images_similar(gs_viz) + + # ########################################################################## + # ## Integration Tests + # ########################################################################## + + @pytest.mark.skipif(pd is None, reason="test requires pandas") + def test_pandas_integration(self): + """ + Test GridSearchColorPlot on sklearn occupancy data set (as pandas df) + """ + + X, y = load_occupancy(return_dataset=True).to_pandas() + X, y = X.head(1000), y.head(1000) + + gs = self.gridsearchcv + gs_viz = GridSearchColorPlot(self.gridsearchcv, 'C', 'kernel') + gs_viz.fit(X, y) + + self.assert_images_similar(gs_viz) + + def test_numpy_integration(self): + """ + Test GridSearchColorPlot on sklearn occupancy data set (as numpy df) + """ + + X, y = 
load_occupancy(return_dataset=True).to_numpy() + X, y = X[:1000], y[:1000] + + gs = self.gridsearchcv + gs_viz = GridSearchColorPlot(self.gridsearchcv, 'C', 'kernel') + gs_viz.fit(X, y) + + self.assert_images_similar(gs_viz) \ No newline at end of file diff --git a/yellowbrick/draw.py b/yellowbrick/draw.py index a3fcca4a7..9dc3261e9 100644 --- a/yellowbrick/draw.py +++ b/yellowbrick/draw.py @@ -21,7 +21,7 @@ from .exceptions import YellowbrickValueError from .style.colors import resolve_colors -from matplotlib import patches +from matplotlib import axes, patches, lines import matplotlib.pyplot as plt import numpy as np @@ -30,15 +30,17 @@ ## Legend Drawing Utilities ########################################################################## - -def manual_legend(g, labels, colors, **legend_kwargs): +def manual_legend(g, labels, colors=None, styles=None, **legend_kwargs): """ - Adds a manual legend for a scatter plot to the visualizer where the labels - and associated colors are drawn with circle patches instead of determining - them from the labels of the artist objects on the axes. This helper is - used either when there are a lot of duplicate labels, no labeled artists, - or when the color of the legend doesn't exactly match the color in the - figure (e.g. because of the use of transparency). + Adds a manual legend for a scatter plot to the visualizer. The legend + entries are drawn according to the ``styles`` parameter if specified, and + with circle patches (colored according to ``colors``) if not specified. + Calling this function overrides the default behavior of drawing the legend + from the labels of the artist objects on the axes. + + This helper is used either when there are a lot of duplicate labels, + no labeled artists, or when the color of the legend doesn't exactly + match the color in the figure (e.g. because of the use of transparency). 
Parameters ---------- @@ -51,10 +53,21 @@ def manual_legend(g, labels, colors, **legend_kwargs): The text labels to associate with the legend. Note that the labels will be added to the legend in the order specified. - colors : list of colors - A list of any valid matplotlib color reference. The number of colors - specified must be equal to the number of labels. - + colors : list of colors, default: None + A list of any valid matplotlib color references. If ``styles`` + is provided, colors must be either ``None`` or a list of equal length to + ``labels``; in the latter case, this parameter takes precedence over any + colors specified in ``styles``. To skip specifying a color for a + particular entry, use an empty string, None, or 'None'. + + styles : list of str, default: None + A list of matplotlib-style format strings, each corresponding to a label + and describing its graphical appearance in the legend, e.g., 'ro' for a + red circle. The number of styles specified must be equal to the number + of labels. Either one or both of ``colors`` and ``styles`` must be + specified. Consistent with matplotlib, blank style entries default to + solid, unmarked, black lines. + legend_kwargs : dict Any additional keyword arguments to pass to the legend. @@ -64,36 +77,78 @@ def manual_legend(g, labels, colors, **legend_kwargs): The artist created by the ax.legend() call, returned for further manipulation if required by the caller. - Notes - ----- - Right now this method simply draws the patches as rectangles and cannot - take into account the line or scatter plot properties (e.g. line style or - marker style). It is possible to add Line2D patches to the artist that do - add manual styles like this, which we can explore in the future. - .. seealso:: https://matplotlib.org/gallery/text_labels_and_annotations/custom_legends.html + + .. 
seealso:: https://matplotlib.org/3.3.0/api/_as_gen/matplotlib.pyplot.plot.html """ + # Get access to the matplotlib Axes if isinstance(g, Visualizer): g = g.ax elif g is None: g = plt.gca() - # Ensure that labels and colors are the same length to prevent odd behavior. - if len(colors) != len(labels): - raise YellowbrickValueError( - "please specify the same number of colors as labels!" - ) - - # Create the legend handles with the associated colors and labels - handles = [ - patches.Patch(color=color, label=label) for color, label in zip(colors, labels) - ] + if styles: + # Documented the `styles` parameter as being a list when really + # it makes sense to accept it as a list or a tuple + if type(styles) not in (list, tuple): + raise YellowbrickValueError( + "Please specify the styles parameter as a list of strings!" + ) + + if len(styles) != len(labels): + raise YellowbrickValueError( + "Please specify the styles parameter as a list of length " + "equal to the number of labels!" + ) + + if colors is not None and len(colors) != len(labels): + raise YellowbrickValueError( + "Please specify the colors parameter either as colors=None or " + "a list of length equal to the number of labels. You can use " + "an empty string or None as a placeholder for colors that " + "are already specified in the corresponding styles entry." + ) + else: + if colors is None or len(colors) != len(labels): + raise YellowbrickValueError( + "Please specify the colors parameter as a list of length equal " + "to the number of labels!" 
+ ) + + # Set legend's artist handles to: + # linestyles/markers/colors specified by `styles` if passed in, or + # patches according to `colors` if it is not + if styles: + if colors is None: + colors = [None] * len(styles) + else: + colors = [None if color in ("", " ", None, 'None') else color + for color in colors] + + handles = list() + for style, color, label in zip(styles, colors, labels): + linestyle, marker, style_color = \ + axes._base._process_plot_format(style) + + # colors parameter should take precedence over styles, + # consistent with matplotlib + color = color or style_color or 'black' + # _process_plot_format() above will have already set linestyle to + # '-' and marker to 'None' if they weren't specified + + line_2d = lines.Line2D([0], [0], linestyle=linestyle, marker=marker, + color=color, label=label) + handles.append(line_2d) + else: + handles = [ + patches.Patch(color=color, label=label) for + color, label in zip(colors, labels) + ] # Return the Legend artist return g.legend(handles=handles, **legend_kwargs) - def bar_stack( data, ax=None, @@ -192,3 +247,4 @@ def bar_stack( legend_kws = legend_kws or {} manual_legend(ax, labels=labels, colors=colors, **legend_kws) return ax + \ No newline at end of file diff --git a/yellowbrick/gridsearch/pcolor.py b/yellowbrick/gridsearch/pcolor.py index eabd1639d..890bce53a 100644 --- a/yellowbrick/gridsearch/pcolor.py +++ b/yellowbrick/gridsearch/pcolor.py @@ -68,8 +68,8 @@ def gridsearch_color_plot(model, x_param, y_param, X=None, y=None, ax=None, **kw Returns ------- - ax : matplotlib axes - Returns the axes that the classification report was drawn on. 
+ visualizer : GridSearchColorPlot + Returns visualizer """ # Instantiate the visualizer visualizer = GridSearchColorPlot(model, x_param, y_param, ax=ax, **kwargs) @@ -80,8 +80,8 @@ def gridsearch_color_plot(model, x_param, y_param, X=None, y=None, ax=None, **kw else: visualizer.draw() - # Return the axes object on the visualizer - return visualizer.ax + # Return the visualizer + return visualizer class GridSearchColorPlot(GridSearchVisualizer):