Skip to content

Commit c36823d

Browse files
authored
Merge pull request #278 from nschloe/legend-fixes
Legend fixes
2 parents 5e9a2ef + 4de76ae commit c36823d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+409
-383
lines changed

matplotlib2tikz/axes.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(self, data, obj):
126126

127127
# actually print the thing
128128
if self.is_subplot:
129-
self.content.append("\\nextgroupplot")
129+
self.content.append("\n\\nextgroupplot")
130130
else:
131131
self.content.append("\\begin{axis}")
132132

@@ -405,28 +405,25 @@ def _subplot(self, obj, data):
405405
if "is_in_groupplot_env" not in data or not data["is_in_groupplot_env"]:
406406
self.content.append(
407407
"\\begin{{groupplot}}[group style="
408-
"{{group size={} by {}}}]\n".format(geom[1], geom[0])
408+
"{{group size={} by {}}}]".format(geom[1], geom[0])
409409
)
410410
data["is_in_groupplot_env"] = True
411411
data["pgfplots libs"].add("groupplots")
412412

413413
return
414414

415415

416-
def _get_label_rotation_and_horizontal_alignment(obj, data, axes_obj):
416+
def _get_label_rotation_and_horizontal_alignment(obj, data, x_or_y):
417417
tick_label_text_width = None
418-
tick_label_text_width_identifier = "{} tick label text width".format(axes_obj)
418+
tick_label_text_width_identifier = "{} tick label text width".format(x_or_y)
419419
if tick_label_text_width_identifier in data["extra axis options"]:
420-
tick_label_text_width = data["extra axis options [base]"][
421-
tick_label_text_width_identifier
422-
]
423-
del data["extra axis options"][tick_label_text_width_identifier]
420+
data["extra axis options"].remove(tick_label_text_width_identifier)
424421

425422
label_style = ""
426423

427424
major_tick_labels = (
428425
obj.xaxis.get_majorticklabels()
429-
if axes_obj == "x"
426+
if x_or_y == "x"
430427
else obj.yaxis.get_majorticklabels()
431428
)
432429

@@ -456,9 +453,7 @@ def _get_label_rotation_and_horizontal_alignment(obj, data, axes_obj):
456453
values.append("text width={}".format(tick_label_text_width))
457454

458455
if values:
459-
label_style = "{}ticklabel style = {{{}}}".format(
460-
axes_obj, ",".join(values)
461-
)
456+
label_style = "{}ticklabel style = {{{}}}".format(x_or_y, ",".join(values))
462457
else:
463458
values = []
464459

@@ -478,12 +473,12 @@ def _get_label_rotation_and_horizontal_alignment(obj, data, axes_obj):
478473
else:
479474
for idx, x in enumerate(tick_labels_horizontal_alignment):
480475
label_style += "{}_tick_label_ha_{}/.initial = {}".format(
481-
axes_obj, idx, x
476+
x_or_y, idx, x
482477
)
483478

484479
values.append(
485480
"align=\\pgfkeysvalueof{{/pgfplots/{}_tick_label_ha_\\ticknum}}".format(
486-
axes_obj
481+
x_or_y
487482
)
488483
)
489484
values.append("text width={}".format(tick_label_text_width))
@@ -493,13 +488,13 @@ def _get_label_rotation_and_horizontal_alignment(obj, data, axes_obj):
493488
"Horizontal alignment will be ignored as no '{} tick "
494489
"label text width' has been passed in the 'extra' "
495490
"parameter"
496-
).format(axes_obj)
491+
).format(x_or_y)
497492
)
498493

499494
label_style = (
500495
"every {} tick label/.style = {{\n"
501496
"{}\n"
502-
"}}".format(axes_obj, ",\n".join(values))
497+
"}}".format(x_or_y, ",\n".join(values))
503498
)
504499

505500
return label_style

matplotlib2tikz/legend.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ def draw_legend(data, obj):
1616
texts.append("{}".format(text.get_text()))
1717
children_alignment.append("{}".format(text.get_horizontalalignment()))
1818

19-
cont = "legend entries={{{{{}}}}}".format("},{".join(texts))
20-
data["extra axis options"].add(cont)
21-
2219
# Get the location.
2320
# http://matplotlib.org/api/legend_api.html
2421
loc = obj._loc if obj._loc != 0 else _get_location_from_best(obj)
@@ -84,25 +81,6 @@ def draw_legend(data, obj):
8481
if obj._ncol != 1:
8582
data["extra axis options"].add("legend columns={}".format(obj._ncol))
8683

87-
# Set color of lines in legend
88-
for handle in obj.legendHandles:
89-
try:
90-
# when using matplotlib colours like "darkred" or "darkorange",
91-
# `handle.get_color` will create nested RGBA codes
92-
# e.g. `[[ 0.54509804, 0., 0., 1.]]` which casuse mpl to throw an error.
93-
# catch this error, `numpy.squeeze` the colour code and try again
94-
try:
95-
data, legend_color, _ = mycol.mpl_color2xcolor(data, handle.get_color())
96-
except ValueError:
97-
data, legend_color, _ = mycol.mpl_color2xcolor(
98-
data, numpy.squeeze(handle.get_color())
99-
)
100-
data["legend colors"].append(
101-
"\\addlegendimage{{no markers, {}}}\n".format(legend_color)
102-
)
103-
except AttributeError:
104-
pass
105-
10684
# Write styles to data
10785
if legend_style:
10886
style = "legend style={{{}}}".format(", ".join(legend_style))

matplotlib2tikz/line2d.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,26 @@
99
from . import files
1010

1111

12-
def get_legend_label_(line):
13-
"""Check if line is in legend
12+
def _has_legend(axes):
13+
return axes.get_legend() is not None
14+
15+
16+
def _get_legend_text(line):
17+
"""Check if line is in legend.
1418
"""
19+
leg = line.axes.get_legend()
20+
if leg is None:
21+
return None
22+
23+
keys = [l.get_label() for l in leg.get_lines()]
24+
values = [l.get_text() for l in leg.texts]
1525

1626
label = line.get_label()
17-
try:
18-
ax = line.axes
19-
leg = ax.get_legend()
20-
return label in [l.get_label() for l in leg.get_lines()]
21-
except AttributeError:
22-
return None
27+
d = dict(zip(keys, values))
28+
if label in d:
29+
return d[label]
30+
31+
return None
2332

2433

2534
def draw_line2d(data, obj):
@@ -73,19 +82,22 @@ def draw_line2d(data, obj):
7382
if marker and not show_line:
7483
addplot_options.append("only marks")
7584

76-
# Check if a line is not in a legend and forget it if so,
77-
# fixes bug #167:
78-
if not get_legend_label_(obj):
85+
# Check if a line is in a legend and forget it if not.
86+
# Fixes <https://github.com/nschloe/matplotlib2tikz/issues/167>.
87+
legend_text = _get_legend_text(obj)
88+
if legend_text is None and _has_legend(obj.axes):
7989
addplot_options.append("forget plot")
8090

8191
# process options
8292
content.append("\\addplot ")
8393
if addplot_options:
84-
options = ", ".join(addplot_options)
85-
content.append("[" + options + "]\n")
94+
content.append("[{}]\n".format(", ".join(addplot_options)))
8695

8796
_table(obj, content, data)
8897

98+
if legend_text is not None:
99+
content.append("\\addlegendentry{{{}}}\n".format(legend_text))
100+
89101
return data, content
90102

91103

@@ -147,9 +159,10 @@ def draw_linecollection(data, obj):
147159
options.append(linestyle)
148160

149161
# TODO what about masks?
150-
data, cont = mypath.draw_path(data, path, draw_options=options, simplify=False)
151-
152-
content.append(cont)
162+
data, cont, _, _ = mypath.draw_path(
163+
data, path, draw_options=options, simplify=False
164+
)
165+
content.append(cont + "\n")
153166

154167
return data, content
155168

matplotlib2tikz/patch.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def draw_patch(data, obj):
2121
return _draw_ellipse(data, obj, draw_options)
2222

2323
# regular patch
24-
return mypath.draw_path(data, obj.get_path(), draw_options=draw_options)
24+
data, path_command, _, _ = mypath.draw_path(
25+
data, obj.get_path(), draw_options=draw_options
26+
)
27+
return data, path_command
2528

2629

2730
def draw_patchcollection(data, obj):
@@ -40,57 +43,76 @@ def draw_patchcollection(data, obj):
4043
face_color = None
4144

4245
data, draw_options = mypath.get_draw_options(data, edge_color, face_color)
43-
for path in obj.get_paths():
44-
data, cont = mypath.draw_path(data, path, draw_options=draw_options)
46+
47+
paths = obj.get_paths()
48+
for path in paths:
49+
data, cont, draw_options, is_area = mypath.draw_path(
50+
data, path, draw_options=draw_options
51+
)
4552
content.append(cont)
53+
54+
if _is_in_legend(obj):
55+
# Unfortunately, patch legend entries need \addlegendimage in Pgfplots.
56+
tpe = "area legend" if is_area else "line legend"
57+
do = ", ".join([tpe] + draw_options) if draw_options else ""
58+
content += [
59+
"\\addlegendimage{{{}}}\n".format(do),
60+
"\\addlegendentry{{{}}}\n\n".format(obj.get_label()),
61+
]
62+
else:
63+
content.append("\n")
64+
4665
return data, content
4766

4867

68+
def _is_in_legend(obj):
69+
label = obj.get_label()
70+
leg = obj.axes.get_legend()
71+
if leg is None:
72+
return False
73+
return label in [txt.get_text() for txt in leg.get_texts()]
74+
75+
4976
def _draw_rectangle(data, obj, draw_options):
5077
"""Return the PGFPlots code for rectangles.
5178
"""
52-
53-
# Objects with labels are plot objects (from bar charts, etc).
54-
# Even those without labels explicitly set have a label of
55-
# "_nolegend_". Everything else should be skipped because
56-
# they likely correspong to axis/legend objects which are
57-
# handled by PGFPlots
79+
# Objects with labels are plot objects (from bar charts, etc). Even those without
80+
# labels explicitly set have a label of "_nolegend_". Everything else should be
81+
# skipped because they likely correspong to axis/legend objects which are handled by
82+
# PGFPlots
5883
label = obj.get_label()
5984
if label == "":
6085
return data, []
6186

62-
# get real label, bar charts by default only give rectangles
63-
# labels of "_nolegend_"
64-
# See
65-
# <http://stackoverflow.com/questions/35881290/how-to-get-the-label-on-bar-plot-stacked-bar-plot-in-matplotlib>
87+
# Get actual label, bar charts by default only give rectangles labels of
88+
# "_nolegend_". See <https://stackoverflow.com/q/35881290/353337>.
6689
handles, labels = obj.axes.get_legend_handles_labels()
6790
labelsFound = [
6891
label for h, label in zip(handles, labels) if obj in h.get_children()
6992
]
7093
if len(labelsFound) == 1:
7194
label = labelsFound[0]
7295

73-
legend = ""
74-
if label != "_nolegend_" and label not in data["rectangle_legends"]:
75-
data["rectangle_legends"].add(label)
76-
legend = ("\\addlegendimage{{ybar,ybar legend,{}}};\n").format(
77-
",".join(draw_options)
78-
)
79-
8096
left_lower_x = obj.get_x()
8197
left_lower_y = obj.get_y()
8298
ff = data["float format"]
8399
cont = (
84-
"{}\\draw[{}] (axis cs:" + ff + "," + ff + ") "
100+
"\\draw[{}] (axis cs:" + ff + "," + ff + ") "
85101
"rectangle (axis cs:" + ff + "," + ff + ");\n"
86102
).format(
87-
legend,
88103
",".join(draw_options),
89104
left_lower_x,
90105
left_lower_y,
91106
left_lower_x + obj.get_width(),
92107
left_lower_y + obj.get_height(),
93108
)
109+
110+
if label != "_nolegend_" and label not in data["rectangle_legends"]:
111+
data["rectangle_legends"].add(label)
112+
cont += "\\addlegendimage{{ybar,ybar legend,{}}};\n".format(
113+
",".join(draw_options)
114+
)
115+
cont += "\\addlegendentry{{{}}}\n\n".format(label)
94116
return data, cont
95117

96118

matplotlib2tikz/path.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def draw_path(data, path, draw_options=None, simplify=None):
1818
and all(path.vertices[0] == path.vertices[1])
1919
and "fill opacity=0" in draw_options
2020
):
21-
return data, ""
21+
return data, "", None, False
2222

2323
nodes = []
2424
ff = data["float format"]
@@ -34,6 +34,7 @@ def draw_path(data, path, draw_options=None, simplify=None):
3434
# For path codes see: http://matplotlib.org/api/path_api.html
3535
#
3636
# if code == mpl.path.Path.STOP: pass
37+
is_area = False
3738
if code == mpl.path.Path.MOVETO:
3839
nodes.append(("(axis cs:" + ff + "," + ff + ")").format(*vert))
3940
elif code == mpl.path.Path.LINETO:
@@ -103,14 +104,15 @@ def draw_path(data, path, draw_options=None, simplify=None):
103104
else:
104105
assert code == mpl.path.Path.CLOSEPOLY
105106
nodes.append("--cycle")
107+
is_area = True
106108

107109
# Store the previous point for quadratic Beziers.
108110
prev = vert[0:2]
109111

110112
do = "[{}]".format(", ".join(draw_options)) if draw_options else ""
111-
path_command = "\\path {} {};\n\n".format(do, "\n".join(nodes))
113+
path_command = "\\path {}\n{};\n".format(do, "\n".join(nodes))
112114

113-
return data, path_command
115+
return data, path_command, draw_options, is_area
114116

115117

116118
def draw_pathcollection(data, obj):

0 commit comments

Comments
 (0)