Skip to content

Commit 54283bf

Browse files
Use native legends when converting from matplotlib
1 parent 1ec864b commit 54283bf

File tree

3 files changed

+121
-79
lines changed

3 files changed

+121
-79
lines changed

plotly/matplotlylib/renderer.py

Lines changed: 36 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self):
6060
self.mpl_x_bounds = (0, 1)
6161
self.mpl_y_bounds = (0, 1)
6262
self.msg = "Initialized PlotlyRenderer\n"
63+
self._processing_legend = False
64+
self._legend_visible = False
6365

6466
def open_figure(self, fig, props):
6567
"""Creates a new figure by beginning to fill out layout dict.
@@ -108,7 +110,6 @@ def close_figure(self, fig):
108110
fig -- a matplotlib.figure.Figure object.
109111
110112
"""
111-
self.plotly_fig["layout"]["showlegend"] = False
112113
self.msg += "Closing figure\n"
113114

114115
def open_axes(self, ax, props):
@@ -198,6 +199,35 @@ def close_axes(self, ax):
198199
self.msg += " Closing axes\n"
199200
self.x_is_mpl_date = False
200201

202+
def open_legend(self, legend, props):
203+
"""Enable Plotly's native legend when matplotlib legend is detected.
204+
205+
This method is called when a matplotlib legend is found. It enables
206+
Plotly's showlegend only if the matplotlib legend is visible.
207+
208+
Positional arguments:
209+
legend -- matplotlib.legend.Legend object
210+
props -- legend properties dictionary
211+
"""
212+
self.msg += " Opening legend\n"
213+
self._processing_legend = True
214+
self._legend_visible = props.get("visible", True)
215+
if self._legend_visible:
216+
self.msg += " Enabling native plotly legend (matplotlib legend is visible)\n"
217+
self.plotly_fig["layout"]["showlegend"] = True
218+
else:
219+
self.msg += " Not enabling legend (matplotlib legend is not visible)\n"
220+
221+
def close_legend(self, legend):
222+
"""Finalize legend processing.
223+
224+
Positional arguments:
225+
legend -- matplotlib.legend.Legend object
226+
"""
227+
self.msg += " Closing legend\n"
228+
self._processing_legend = False
229+
self._legend_visible = False
230+
201231
def draw_bars(self, bars):
202232

203233
# sort bars according to bar containers
@@ -310,83 +340,6 @@ def draw_bar(self, coll):
310340
"assuming data redundancy, not plotting."
311341
)
312342

313-
def draw_legend_shapes(self, mode, shape, **props):
314-
"""Create a shape that matches lines or markers in legends.
315-
316-
Main issue is that path for circles do not render, so we have to use 'circle'
317-
instead of 'path'.
318-
"""
319-
for single_mode in mode.split("+"):
320-
x = props["data"][0][0]
321-
y = props["data"][0][1]
322-
if single_mode == "markers" and props.get("markerstyle"):
323-
size = shape.pop("size", 6)
324-
symbol = shape.pop("symbol")
325-
# aligning to "center"
326-
x0 = 0
327-
y0 = 0
328-
x1 = size
329-
y1 = size
330-
markerpath = props["markerstyle"].get("markerpath")
331-
if markerpath is None and symbol != "circle":
332-
self.msg += (
333-
"not sure how to handle this marker without a valid path\n"
334-
)
335-
return
336-
# marker path to SVG path conversion
337-
path = " ".join(
338-
[f"{a} {t[0]},{t[1]}" for a, t in zip(markerpath[1], markerpath[0])]
339-
)
340-
341-
if symbol == "circle":
342-
# symbols like . and o in matplotlib, use circle
343-
# plotly also maps many other markers to circle, such as 1,8 and p
344-
path = None
345-
shape_type = "circle"
346-
x0 = -size / 2
347-
y0 = size / 2
348-
x1 = size / 2
349-
y1 = size + size / 2
350-
else:
351-
# triangles, star etc
352-
shape_type = "path"
353-
legend_shape = go.layout.Shape(
354-
type=shape_type,
355-
xref="paper",
356-
yref="paper",
357-
x0=x0,
358-
y0=y0,
359-
x1=x1,
360-
y1=y1,
361-
xsizemode="pixel",
362-
ysizemode="pixel",
363-
xanchor=x,
364-
yanchor=y,
365-
path=path,
366-
**shape,
367-
)
368-
369-
elif single_mode == "lines":
370-
mode = "line"
371-
x1 = props["data"][1][0]
372-
y1 = props["data"][1][1]
373-
374-
legend_shape = go.layout.Shape(
375-
type=mode,
376-
xref="paper",
377-
yref="paper",
378-
x0=x,
379-
y0=y + 0.02,
380-
x1=x1,
381-
y1=y1 + 0.02,
382-
**shape,
383-
)
384-
else:
385-
self.msg += "not sure how to handle this element\n"
386-
return
387-
self.plotly_fig.add_shape(legend_shape)
388-
self.msg += " Heck yeah, I drew that shape\n"
389-
390343
def draw_marked_line(self, **props):
391344
"""Create a data dict for a line obj.
392345
@@ -502,7 +455,7 @@ def draw_marked_line(self, **props):
502455
self.msg += " Heck yeah, I drew that line\n"
503456
elif props["coordinates"] == "axes":
504457
# dealing with legend graphical elements
505-
self.draw_legend_shapes(mode=mode, shape=shape, **props)
458+
self.msg += " Using native legend\n"
506459
else:
507460
self.msg += " Line didn't have 'data' coordinates, " "not drawing\n"
508461
warnings.warn(
@@ -668,6 +621,10 @@ def draw_text(self, **props):
668621
self.draw_title(**props)
669622
else: # just a regular text annotation...
670623
self.msg += " Text object is a normal annotation\n"
624+
# Skip creating annotations for legend text when using native legend
625+
if self._processing_legend and self._legend_visible and props["coordinates"] == "axes":
626+
self.msg += " Skipping legend text annotation (using native legend)\n"
627+
return
671628
if props["coordinates"] != "data":
672629
self.msg += (
673630
" Text object isn't linked to 'data' " "coordinates\n"

plotly/matplotlylib/tests/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import matplotlib
2+
3+
matplotlib.use("Agg")
4+
import matplotlib.pyplot as plt
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import plotly.tools as tls
2+
3+
from . import plt
4+
5+
6+
def test_native_legend_enabled_when_matplotlib_legend_present():
7+
"""Test that when matplotlib legend is present, Plotly uses native legend."""
8+
fig, ax = plt.subplots()
9+
ax.plot([0, 1], [0, 1], label="Line 1")
10+
ax.plot([0, 1], [1, 0], label="Line 2")
11+
ax.legend()
12+
13+
plotly_fig = tls.mpl_to_plotly(fig)
14+
15+
# Should enable native legend
16+
assert plotly_fig.layout.showlegend == True
17+
# Should have 2 traces with names
18+
assert len(plotly_fig.data) == 2
19+
assert plotly_fig.data[0].name == "Line 1"
20+
assert plotly_fig.data[1].name == "Line 2"
21+
22+
23+
def test_no_fake_legend_shapes_with_native_legend():
24+
"""Test that fake legend shapes are not created when using native legend."""
25+
fig, ax = plt.subplots()
26+
ax.plot([0, 1], [0, 1], "o-", label="Data with markers")
27+
ax.legend()
28+
29+
plotly_fig = tls.mpl_to_plotly(fig)
30+
31+
# Should use native legend
32+
assert plotly_fig.layout.showlegend == True
33+
# Should not create fake legend elements
34+
assert len(plotly_fig.layout.shapes) == 0
35+
assert len(plotly_fig.layout.annotations) == 0
36+
37+
38+
def test_legend_disabled_when_no_matplotlib_legend():
39+
"""Test that legend is not enabled when no matplotlib legend is present."""
40+
fig, ax = plt.subplots()
41+
ax.plot([0, 1], [0, 1], label="Line 1") # Has label but no legend() call
42+
43+
plotly_fig = tls.mpl_to_plotly(fig)
44+
45+
# Should not have showlegend explicitly set to True
46+
# (Plotly's default behavior when no legend elements exist)
47+
assert not hasattr(plotly_fig.layout, 'showlegend') or plotly_fig.layout.showlegend != True
48+
49+
50+
def test_legend_disabled_when_matplotlib_legend_not_visible():
51+
"""Test that legend is not enabled when no matplotlib legend is not visible."""
52+
fig, ax = plt.subplots()
53+
ax.plot([0, 1], [0, 1], label="Line 1")
54+
legend = ax.legend()
55+
legend.set_visible(False) # Hide the legend
56+
57+
plotly_fig = tls.mpl_to_plotly(fig)
58+
59+
# Should not enable legend when matplotlib legend is hidden
60+
assert not hasattr(plotly_fig.layout, 'showlegend') or plotly_fig.layout.showlegend != True
61+
62+
63+
def test_multiple_traces_native_legend():
64+
"""Test native legend works with multiple traces of different types."""
65+
fig, ax = plt.subplots()
66+
ax.plot([0, 1, 2], [0, 1, 0], '-', label="Line")
67+
ax.plot([0, 1, 2], [1, 0, 1], 'o', label="Markers")
68+
ax.plot([0, 1, 2], [0.5, 0.5, 0.5], 's-', label="Line+Markers")
69+
ax.legend()
70+
71+
plotly_fig = tls.mpl_to_plotly(fig)
72+
73+
assert plotly_fig.layout.showlegend == True
74+
assert len(plotly_fig.data) == 3
75+
assert plotly_fig.data[0].name == "Line"
76+
assert plotly_fig.data[1].name == "Markers"
77+
assert plotly_fig.data[2].name == "Line+Markers"
78+
# Verify modes are correct
79+
assert plotly_fig.data[0].mode == "lines"
80+
assert plotly_fig.data[1].mode == "markers"
81+
assert plotly_fig.data[2].mode == "lines+markers"

0 commit comments

Comments
 (0)