Skip to content

Commit 1528322

Browse files
Use native legends when converting from matplotlib
1 parent 694b036 commit 1528322

File tree

3 files changed

+121
-79
lines changed

3 files changed

+121
-79
lines changed

plotly/matplotlylib/renderer.py

Lines changed: 36 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self):
6060
self.mpl_x_bounds = (0, 1)
6161
self.mpl_y_bounds = (0, 1)
6262
self.msg = "Initialized PlotlyRenderer\n"
63+
self._processing_legend = False
64+
self._legend_visible = False
6365

6466
def open_figure(self, fig, props):
6567
"""Creates a new figure by beginning to fill out layout dict.
@@ -108,7 +110,6 @@ def close_figure(self, fig):
108110
fig -- a matplotlib.figure.Figure object.
109111
110112
"""
111-
self.plotly_fig["layout"]["showlegend"] = False
112113
self.msg += "Closing figure\n"
113114

114115
def open_axes(self, ax, props):
@@ -198,6 +199,35 @@ def close_axes(self, ax):
198199
self.msg += " Closing axes\n"
199200
self.x_is_mpl_date = False
200201

202+
def open_legend(self, legend, props):
203+
"""Enable Plotly's native legend when matplotlib legend is detected.
204+
205+
This method is called when a matplotlib legend is found. It enables
206+
Plotly's showlegend only if the matplotlib legend is visible.
207+
208+
Positional arguments:
209+
legend -- matplotlib.legend.Legend object
210+
props -- legend properties dictionary
211+
"""
212+
self.msg += " Opening legend\n"
213+
self._processing_legend = True
214+
self._legend_visible = props.get("visible", True)
215+
if self._legend_visible:
216+
self.msg += " Enabling native plotly legend (matplotlib legend is visible)\n"
217+
self.plotly_fig["layout"]["showlegend"] = True
218+
else:
219+
self.msg += " Not enabling legend (matplotlib legend is not visible)\n"
220+
221+
def close_legend(self, legend):
222+
"""Finalize legend processing.
223+
224+
Positional arguments:
225+
legend -- matplotlib.legend.Legend object
226+
"""
227+
self.msg += " Closing legend\n"
228+
self._processing_legend = False
229+
self._legend_visible = False
230+
201231
def draw_bars(self, bars):
202232
# sort bars according to bar containers
203233
mpl_traces = []
@@ -309,83 +339,6 @@ def draw_bar(self, coll):
309339
"assuming data redundancy, not plotting."
310340
)
311341

312-
def draw_legend_shapes(self, mode, shape, **props):
313-
"""Create a shape that matches lines or markers in legends.
314-
315-
Main issue is that path for circles do not render, so we have to use 'circle'
316-
instead of 'path'.
317-
"""
318-
for single_mode in mode.split("+"):
319-
x = props["data"][0][0]
320-
y = props["data"][0][1]
321-
if single_mode == "markers" and props.get("markerstyle"):
322-
size = shape.pop("size", 6)
323-
symbol = shape.pop("symbol")
324-
# aligning to "center"
325-
x0 = 0
326-
y0 = 0
327-
x1 = size
328-
y1 = size
329-
markerpath = props["markerstyle"].get("markerpath")
330-
if markerpath is None and symbol != "circle":
331-
self.msg += (
332-
"not sure how to handle this marker without a valid path\n"
333-
)
334-
return
335-
# marker path to SVG path conversion
336-
path = " ".join(
337-
[f"{a} {t[0]},{t[1]}" for a, t in zip(markerpath[1], markerpath[0])]
338-
)
339-
340-
if symbol == "circle":
341-
# symbols like . and o in matplotlib, use circle
342-
# plotly also maps many other markers to circle, such as 1,8 and p
343-
path = None
344-
shape_type = "circle"
345-
x0 = -size / 2
346-
y0 = size / 2
347-
x1 = size / 2
348-
y1 = size + size / 2
349-
else:
350-
# triangles, star etc
351-
shape_type = "path"
352-
legend_shape = go.layout.Shape(
353-
type=shape_type,
354-
xref="paper",
355-
yref="paper",
356-
x0=x0,
357-
y0=y0,
358-
x1=x1,
359-
y1=y1,
360-
xsizemode="pixel",
361-
ysizemode="pixel",
362-
xanchor=x,
363-
yanchor=y,
364-
path=path,
365-
**shape,
366-
)
367-
368-
elif single_mode == "lines":
369-
mode = "line"
370-
x1 = props["data"][1][0]
371-
y1 = props["data"][1][1]
372-
373-
legend_shape = go.layout.Shape(
374-
type=mode,
375-
xref="paper",
376-
yref="paper",
377-
x0=x,
378-
y0=y + 0.02,
379-
x1=x1,
380-
y1=y1 + 0.02,
381-
**shape,
382-
)
383-
else:
384-
self.msg += "not sure how to handle this element\n"
385-
return
386-
self.plotly_fig.add_shape(legend_shape)
387-
self.msg += " Heck yeah, I drew that shape\n"
388-
389342
def draw_marked_line(self, **props):
390343
"""Create a data dict for a line obj.
391344
@@ -501,7 +454,7 @@ def draw_marked_line(self, **props):
501454
self.msg += " Heck yeah, I drew that line\n"
502455
elif props["coordinates"] == "axes":
503456
# dealing with legend graphical elements
504-
self.draw_legend_shapes(mode=mode, shape=shape, **props)
457+
self.msg += " Using native legend\n"
505458
else:
506459
self.msg += " Line didn't have 'data' coordinates, not drawing\n"
507460
warnings.warn(
@@ -667,6 +620,10 @@ def draw_text(self, **props):
667620
self.draw_title(**props)
668621
else: # just a regular text annotation...
669622
self.msg += " Text object is a normal annotation\n"
623+
# Skip creating annotations for legend text when using native legend
624+
if self._processing_legend and self._legend_visible and props["coordinates"] == "axes":
625+
self.msg += " Skipping legend text annotation (using native legend)\n"
626+
return
670627
if props["coordinates"] != "data":
671628
self.msg += " Text object isn't linked to 'data' coordinates\n"
672629
x_px, y_px = (

plotly/matplotlylib/tests/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import matplotlib
2+
3+
matplotlib.use("Agg")
4+
import matplotlib.pyplot as plt
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import plotly.tools as tls
2+
3+
from . import plt
4+
5+
6+
def test_native_legend_enabled_when_matplotlib_legend_present():
7+
"""Test that when matplotlib legend is present, Plotly uses native legend."""
8+
fig, ax = plt.subplots()
9+
ax.plot([0, 1], [0, 1], label="Line 1")
10+
ax.plot([0, 1], [1, 0], label="Line 2")
11+
ax.legend()
12+
13+
plotly_fig = tls.mpl_to_plotly(fig)
14+
15+
# Should enable native legend
16+
assert plotly_fig.layout.showlegend == True
17+
# Should have 2 traces with names
18+
assert len(plotly_fig.data) == 2
19+
assert plotly_fig.data[0].name == "Line 1"
20+
assert plotly_fig.data[1].name == "Line 2"
21+
22+
23+
def test_no_fake_legend_shapes_with_native_legend():
24+
"""Test that fake legend shapes are not created when using native legend."""
25+
fig, ax = plt.subplots()
26+
ax.plot([0, 1], [0, 1], "o-", label="Data with markers")
27+
ax.legend()
28+
29+
plotly_fig = tls.mpl_to_plotly(fig)
30+
31+
# Should use native legend
32+
assert plotly_fig.layout.showlegend == True
33+
# Should not create fake legend elements
34+
assert len(plotly_fig.layout.shapes) == 0
35+
assert len(plotly_fig.layout.annotations) == 0
36+
37+
38+
def test_legend_disabled_when_no_matplotlib_legend():
39+
"""Test that legend is not enabled when no matplotlib legend is present."""
40+
fig, ax = plt.subplots()
41+
ax.plot([0, 1], [0, 1], label="Line 1") # Has label but no legend() call
42+
43+
plotly_fig = tls.mpl_to_plotly(fig)
44+
45+
# Should not have showlegend explicitly set to True
46+
# (Plotly's default behavior when no legend elements exist)
47+
assert not hasattr(plotly_fig.layout, 'showlegend') or plotly_fig.layout.showlegend != True
48+
49+
50+
def test_legend_disabled_when_matplotlib_legend_not_visible():
51+
"""Test that legend is not enabled when no matplotlib legend is not visible."""
52+
fig, ax = plt.subplots()
53+
ax.plot([0, 1], [0, 1], label="Line 1")
54+
legend = ax.legend()
55+
legend.set_visible(False) # Hide the legend
56+
57+
plotly_fig = tls.mpl_to_plotly(fig)
58+
59+
# Should not enable legend when matplotlib legend is hidden
60+
assert not hasattr(plotly_fig.layout, 'showlegend') or plotly_fig.layout.showlegend != True
61+
62+
63+
def test_multiple_traces_native_legend():
64+
"""Test native legend works with multiple traces of different types."""
65+
fig, ax = plt.subplots()
66+
ax.plot([0, 1, 2], [0, 1, 0], '-', label="Line")
67+
ax.plot([0, 1, 2], [1, 0, 1], 'o', label="Markers")
68+
ax.plot([0, 1, 2], [0.5, 0.5, 0.5], 's-', label="Line+Markers")
69+
ax.legend()
70+
71+
plotly_fig = tls.mpl_to_plotly(fig)
72+
73+
assert plotly_fig.layout.showlegend == True
74+
assert len(plotly_fig.data) == 3
75+
assert plotly_fig.data[0].name == "Line"
76+
assert plotly_fig.data[1].name == "Markers"
77+
assert plotly_fig.data[2].name == "Line+Markers"
78+
# Verify modes are correct
79+
assert plotly_fig.data[0].mode == "lines"
80+
assert plotly_fig.data[1].mode == "markers"
81+
assert plotly_fig.data[2].mode == "lines+markers"

0 commit comments

Comments
 (0)