Skip to content

Commit 53d55e2

Browse files
dougndnschloe
authored andcommitted
Add ability to render bar charts and legends for bar charts (#135)
* Add ability to render bar chart legends Bar charts in matplotlib consist of a set of rectangles. This commit finds the first rectange of each label, and provides a \addlegendimage with the appropriate draw options. * Removed draw_rectanges, added barchart test with no legend
1 parent 5f4f7a4 commit 53d55e2

File tree

6 files changed

+98
-16
lines changed

6 files changed

+98
-16
lines changed

matplotlib2tikz/patch.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,36 @@ def draw_patchcollection(data, obj):
5151
def _draw_rectangle(data, obj, draw_options):
5252
'''Return the PGFPlots code for rectangles.
5353
'''
54-
if not data['draw rectangles']:
54+
55+
# Objects with labels are plot objects (from bar charts, etc).
56+
# Even those without labels explicitly set have a label of
57+
# "_nolegend_". Everything else should be skipped because
58+
# they likely correspong to axis/legend objects which are
59+
# handled by PGFPlots
60+
label = obj.get_label()
61+
if label == '':
5562
return data, []
5663

64+
# get real label, bar charts by default only give rectangles
65+
# labels of "_nolegend_"
66+
# See http://stackoverflow.com/questions/35881290/how-to-get-the-label-on-bar-plot-stacked-bar-plot-in-matplotlib
67+
handles,labels = obj.axes.get_legend_handles_labels()
68+
labelsFound = [label for h,label in zip(handles, labels) if obj in h.get_children()]
69+
if len(labelsFound) == 1:
70+
label = labelsFound[0]
71+
72+
legend = ''
73+
if label != '_nolegend_' and label not in data['rectangle_legends']:
74+
data['rectangle_legends'].add(label)
75+
legend = ('\\addlegendimage{ybar,ybar legend,%s};\n'
76+
) % (','.join(draw_options))
77+
5778
left_lower_x = obj.get_x()
5879
left_lower_y = obj.get_y()
59-
cont = ('\draw[%s] (axis cs:%.15g,%.15g) '
80+
cont = ('%s\draw[%s] (axis cs:%.15g,%.15g) '
6081
'rectangle (axis cs:%.15g,%.15g);\n'
61-
) % (','.join(draw_options),
82+
) % (legend,
83+
','.join(draw_options),
6284
left_lower_x,
6385
left_lower_y,
6486
left_lower_x + obj.get_width(),

matplotlib2tikz/save.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def save(filepath,
2222
textsize=10.0,
2323
tex_relative_path_to_data=None,
2424
strict=False,
25-
draw_rectangles=False,
2625
wrap=True,
2726
extra=None,
2827
dpi=None,
@@ -75,16 +74,6 @@ def save(filepath,
7574
can decide where to put the ticks.
7675
:type strict: bool
7776
78-
:param draw_rectangles: Whether or not to draw Rectangle objects.
79-
You normally don't want that as legend, axes, and
80-
other entities which are natively taken care of by
81-
PGFPlots are represented as rectangles in
82-
matplotlib. Some plot types (such as bar plots)
83-
cannot otherwise be represented though.
84-
Don't expect working or clean output when using
85-
this option.
86-
:type draw_rectangles: bool
87-
8877
:param wrap: Whether ``'\\begin{tikzpicture}'`` and
8978
``'\\end{tikzpicture}'`` will be written. One might need to
9079
provide custom arguments to the environment (eg. scale= etc.).
@@ -121,11 +110,14 @@ def save(filepath,
121110
data['output dir'] = os.path.dirname(filepath)
122111
data['base name'] = os.path.splitext(os.path.basename(filepath))[0]
123112
data['strict'] = strict
124-
data['draw rectangles'] = draw_rectangles
125113
data['tikz libs'] = set()
126114
data['pgfplots libs'] = set()
127115
data['font size'] = textsize
128116
data['custom colors'] = {}
117+
# rectangle_legends is used to keep track of which rectangles have already
118+
# had \addlegendimage added. There should be only one \addlegenimage per
119+
# bar chart data series.
120+
data['rectangle_legends'] = set()
129121
if extra:
130122
data['extra axis options'] = extra.copy()
131123
else:

test/test_hashes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def test_hash(name):
3232
figureheight='7.5cm',
3333
show_info=True,
3434
strict=True,
35-
draw_rectangles=True
3635
)
3736

3837
# save reference figure

test/testfunctions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
'annotate',
55
'basic_sin',
66
'boxplot',
7+
'barchart_legend',
8+
'barchart',
79
'dual_axis',
810
'errorband',
911
'errorbar',

test/testfunctions/barchart.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# -*- coding: utf-8 -*-
2+
""" Bar Chart test
3+
4+
This tests plots a simple bar chart. Bar charts are plotted as
5+
rectangle patches witch are difficult to tell from other rectangle
6+
patches that should not be plotted in PGFPlots (e.g. axis, legend)
7+
8+
"""
9+
desc = 'Bar Chart'
10+
phash = '5f09a9e6b3728742'
11+
12+
def plot():
13+
import matplotlib.pyplot as plt
14+
import numpy as np
15+
16+
# plot data
17+
fig = plt.figure()
18+
ax = fig.add_subplot(111)
19+
20+
x = np.arange(3)
21+
y1 = [1, 2, 3]
22+
y2 = [3, 2, 4]
23+
y3 = [5, 3, 1]
24+
w = 0.25
25+
26+
ax.bar(x-w, y1, w, color='b', align='center')
27+
ax.bar(x, y2, w, color='g', align='center')
28+
ax.bar(x+w, y3, w, color='r', align='center')
29+
30+
return fig
31+

test/testfunctions/barchart_legend.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# -*- coding: utf-8 -*-
2+
""" Bar Chart Legend test
3+
4+
This tests plots a simple bar chart. Bar charts are plotted as
5+
rectangle patches witch are difficult to tell from other rectangle
6+
patches that should not be plotted in PGFPlots (e.g. axis, legend)
7+
8+
This also tests legends on barcharts. Which are difficult because
9+
in PGFPlots, they have no \\addplot, and thus legend must be
10+
manually added.
11+
12+
"""
13+
desc = 'Bar Chart'
14+
phash = '5f09a9e633728dc4'
15+
16+
def plot():
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
20+
# plot data
21+
fig = plt.figure()
22+
ax = fig.add_subplot(111)
23+
24+
x = np.arange(3)
25+
y1 = [1, 2, 3]
26+
y2 = [3, 2, 4]
27+
y3 = [5, 3, 1]
28+
w = 0.25
29+
30+
ax.bar(x-w, y1, w, color='b', align='center', label='Data 1')
31+
ax.bar(x, y2, w, color='g', align='center', label='Data 2')
32+
ax.bar(x+w, y3, w, color='r', align='center', label='Data 3')
33+
ax.legend()
34+
35+
return fig
36+

0 commit comments

Comments
 (0)