diff --git a/plotly/matplotlylib/renderer.py b/plotly/matplotlylib/renderer.py index c95de52247..d55f86ef3c 100644 --- a/plotly/matplotlylib/renderer.py +++ b/plotly/matplotlylib/renderer.py @@ -60,6 +60,8 @@ def __init__(self): self.mpl_x_bounds = (0, 1) self.mpl_y_bounds = (0, 1) self.msg = "Initialized PlotlyRenderer\n" + self._processing_legend = False + self._legend_visible = False def open_figure(self, fig, props): """Creates a new figure by beginning to fill out layout dict. @@ -108,7 +110,6 @@ def close_figure(self, fig): fig -- a matplotlib.figure.Figure object. """ - self.plotly_fig["layout"]["showlegend"] = False self.msg += "Closing figure\n" def open_axes(self, ax, props): @@ -198,6 +199,37 @@ def close_axes(self, ax): self.msg += " Closing axes\n" self.x_is_mpl_date = False + def open_legend(self, legend, props): + """Enable Plotly's native legend when matplotlib legend is detected. + + This method is called when a matplotlib legend is found. It enables + Plotly's showlegend only if the matplotlib legend is visible. + + Positional arguments: + legend -- matplotlib.legend.Legend object + props -- legend properties dictionary + """ + self.msg += " Opening legend\n" + self._processing_legend = True + self._legend_visible = props.get("visible", True) + if self._legend_visible: + self.msg += ( + " Enabling native plotly legend (matplotlib legend is visible)\n" + ) + self.plotly_fig["layout"]["showlegend"] = True + else: + self.msg += " Not enabling legend (matplotlib legend is not visible)\n" + + def close_legend(self, legend): + """Finalize legend processing. + + Positional arguments: + legend -- matplotlib.legend.Legend object + """ + self.msg += " Closing legend\n" + self._processing_legend = False + self._legend_visible = False + def draw_bars(self, bars): # sort bars according to bar containers mpl_traces = [] @@ -299,7 +331,7 @@ def draw_bar(self, coll): ) # TODO ditto if len(bar["x"]) > 1: self.msg += " Heck yeah, I drew that bar chart\n" - (self.plotly_fig.add_trace(bar),) + self.plotly_fig.add_trace(bar) if bar_gap is not None: self.plotly_fig["layout"]["bargap"] = bar_gap else: @@ -309,83 +341,6 @@ def draw_bar(self, coll): "assuming data redundancy, not plotting." ) - def draw_legend_shapes(self, mode, shape, **props): - """Create a shape that matches lines or markers in legends. - - Main issue is that path for circles do not render, so we have to use 'circle' - instead of 'path'. - """ - for single_mode in mode.split("+"): - x = props["data"][0][0] - y = props["data"][0][1] - if single_mode == "markers" and props.get("markerstyle"): - size = shape.pop("size", 6) - symbol = shape.pop("symbol") - # aligning to "center" - x0 = 0 - y0 = 0 - x1 = size - y1 = size - markerpath = props["markerstyle"].get("markerpath") - if markerpath is None and symbol != "circle": - self.msg += ( - "not sure how to handle this marker without a valid path\n" - ) - return - # marker path to SVG path conversion - path = " ".join( - [f"{a} {t[0]},{t[1]}" for a, t in zip(markerpath[1], markerpath[0])] - ) - - if symbol == "circle": - # symbols like . and o in matplotlib, use circle - # plotly also maps many other markers to circle, such as 1,8 and p - path = None - shape_type = "circle" - x0 = -size / 2 - y0 = size / 2 - x1 = size / 2 - y1 = size + size / 2 - else: - # triangles, star etc - shape_type = "path" - legend_shape = go.layout.Shape( - type=shape_type, - xref="paper", - yref="paper", - x0=x0, - y0=y0, - x1=x1, - y1=y1, - xsizemode="pixel", - ysizemode="pixel", - xanchor=x, - yanchor=y, - path=path, - **shape, - ) - - elif single_mode == "lines": - mode = "line" - x1 = props["data"][1][0] - y1 = props["data"][1][1] - - legend_shape = go.layout.Shape( - type=mode, - xref="paper", - yref="paper", - x0=x, - y0=y + 0.02, - x1=x1, - y1=y1 + 0.02, - **shape, - ) - else: - self.msg += "not sure how to handle this element\n" - return - self.plotly_fig.add_shape(legend_shape) - self.msg += " Heck yeah, I drew that shape\n" - def draw_marked_line(self, **props): """Create a data dict for a line obj. @@ -497,11 +452,11 @@ def draw_marked_line(self, **props): marked_line["x"] = mpltools.mpl_dates_to_datestrings( marked_line["x"], formatter ) - (self.plotly_fig.add_trace(marked_line),) + self.plotly_fig.add_trace(marked_line) self.msg += " Heck yeah, I drew that line\n" elif props["coordinates"] == "axes": # dealing with legend graphical elements - self.draw_legend_shapes(mode=mode, shape=shape, **props) + self.msg += " Using native legend\n" else: self.msg += " Line didn't have 'data' coordinates, not drawing\n" warnings.warn( @@ -667,6 +622,16 @@ def draw_text(self, **props): self.draw_title(**props) else: # just a regular text annotation... self.msg += " Text object is a normal annotation\n" + # Skip creating annotations for legend text when using native legend + if ( + self._processing_legend + and self._legend_visible + and props["coordinates"] == "axes" + ): + self.msg += ( + " Skipping legend text annotation (using native legend)\n" + ) + return if props["coordinates"] != "data": self.msg += " Text object isn't linked to 'data' coordinates\n" x_px, y_px = ( diff --git a/plotly/matplotlylib/tests/__init__.py b/plotly/matplotlylib/tests/__init__.py new file mode 100644 index 0000000000..c29e9896d6 --- /dev/null +++ b/plotly/matplotlylib/tests/__init__.py @@ -0,0 +1,4 @@ +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt diff --git a/plotly/matplotlylib/tests/test_renderer.py b/plotly/matplotlylib/tests/test_renderer.py new file mode 100644 index 0000000000..72116813ac --- /dev/null +++ b/plotly/matplotlylib/tests/test_renderer.py @@ -0,0 +1,87 @@ +import plotly.tools as tls + +from . import plt + + +def test_native_legend_enabled_when_matplotlib_legend_present(): + """Test that when matplotlib legend is present, Plotly uses native legend.""" + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], label="Line 1") + ax.plot([0, 1], [1, 0], label="Line 2") + ax.legend() + + plotly_fig = tls.mpl_to_plotly(fig) + + # Should enable native legend + assert plotly_fig.layout.showlegend == True + # Should have 2 traces with names + assert len(plotly_fig.data) == 2 + assert plotly_fig.data[0].name == "Line 1" + assert plotly_fig.data[1].name == "Line 2" + + +def test_no_fake_legend_shapes_with_native_legend(): + """Test that fake legend shapes are not created when using native legend.""" + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], "o-", label="Data with markers") + ax.legend() + + plotly_fig = tls.mpl_to_plotly(fig) + + # Should use native legend + assert plotly_fig.layout.showlegend == True + # Should not create fake legend elements + assert len(plotly_fig.layout.shapes) == 0 + assert len(plotly_fig.layout.annotations) == 0 + + +def test_legend_disabled_when_no_matplotlib_legend(): + """Test that legend is not enabled when no matplotlib legend is present.""" + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], label="Line 1") # Has label but no legend() call + + plotly_fig = tls.mpl_to_plotly(fig) + + # Should not have showlegend explicitly set to True + # (Plotly's default behavior when no legend elements exist) + assert ( + not hasattr(plotly_fig.layout, "showlegend") + or plotly_fig.layout.showlegend != True + ) + + +def test_legend_disabled_when_matplotlib_legend_not_visible(): + """Test that legend is not enabled when no matplotlib legend is not visible.""" + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], label="Line 1") + legend = ax.legend() + legend.set_visible(False) # Hide the legend + + plotly_fig = tls.mpl_to_plotly(fig) + + # Should not enable legend when matplotlib legend is hidden + assert ( + not hasattr(plotly_fig.layout, "showlegend") + or plotly_fig.layout.showlegend != True + ) + + +def test_multiple_traces_native_legend(): + """Test native legend works with multiple traces of different types.""" + fig, ax = plt.subplots() + ax.plot([0, 1, 2], [0, 1, 0], "-", label="Line") + ax.plot([0, 1, 2], [1, 0, 1], "o", label="Markers") + ax.plot([0, 1, 2], [0.5, 0.5, 0.5], "s-", label="Line+Markers") + ax.legend() + + plotly_fig = tls.mpl_to_plotly(fig) + + assert plotly_fig.layout.showlegend == True + assert len(plotly_fig.data) == 3 + assert plotly_fig.data[0].name == "Line" + assert plotly_fig.data[1].name == "Markers" + assert plotly_fig.data[2].name == "Line+Markers" + # Verify modes are correct + assert plotly_fig.data[0].mode == "lines" + assert plotly_fig.data[1].mode == "markers" + assert plotly_fig.data[2].mode == "lines+markers"