Skip to content

Commit 35a3a83

Browse files
authored
Merge pull request #275 from nschloe/annotations
Annotations
2 parents 1625a89 + e476413 commit 35a3a83

12 files changed

+356
-40
lines changed

matplotlib2tikz/axes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,19 @@ def __init__(self, data, obj):
3636

3737
# get plot title
3838
title = obj.get_title()
39+
data["current axis title"] = title
3940
if title:
40-
self.axis_options.append("title={{{}}}".format(title))
41+
self.axis_options.append(u"title={{{}}}".format(title))
4142

4243
# get axes titles
4344
xlabel = obj.get_xlabel()
4445
if xlabel:
4546
xlabel = mpl_backend_pgf.common_texification(xlabel)
46-
self.axis_options.append("xlabel={{{}}}".format(xlabel))
47+
self.axis_options.append(u"xlabel={{{}}}".format(xlabel))
4748
ylabel = obj.get_ylabel()
4849
if ylabel:
4950
ylabel = mpl_backend_pgf.common_texification(ylabel)
50-
self.axis_options.append("ylabel={{{}}}".format(ylabel))
51+
self.axis_options.append(u"ylabel={{{}}}".format(ylabel))
5152

5253
# Axes limits.
5354
# Sort the limits so make sure that the smaller of the two is actually

matplotlib2tikz/save.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from . import quadmesh as qmsh
1919
from . import path
2020
from . import patch
21-
from . import text as mytext
21+
from . import text
2222

2323
from .__about__ import __version__
2424

@@ -223,7 +223,6 @@ def get_tikz_code(
223223
\\usepackage[utf8]{{inputenc}}
224224
\\usepackage{{pgfplots}}
225225
\\usepgfplotslibrary{{groupplots}}
226-
\\usetikzlibrary{{shapes.arrows}}
227226
\\pgfplotsset{{compat=newest}}
228227
\\begin{{document}}
229228
{}
@@ -277,7 +276,7 @@ def _print_pgfplot_libs_message(data):
277276
pgfplotslibs = ",".join(list(data["pgfplots libs"]))
278277
tikzlibs = ",".join(list(data["tikz libs"]))
279278

280-
print("=========================================================")
279+
print(70 * "=")
281280
print("Please add the following lines to your LaTeX preamble:\n")
282281
print("\\usepackage[utf8]{inputenc}")
283282
print("\\usepackage{fontspec} % This line only for XeLaTeX and LuaLaTeX")
@@ -286,7 +285,7 @@ def _print_pgfplot_libs_message(data):
286285
print("\\usetikzlibrary{" + tikzlibs + "}")
287286
if pgfplotslibs:
288287
print("\\usepgfplotslibrary{" + pgfplotslibs + "}")
289-
print("=========================================================")
288+
print(70 * "=")
290289
return
291290

292291

@@ -360,9 +359,6 @@ def _recurse(data, obj):
360359
elif isinstance(child, mpl.image.AxesImage):
361360
data, cont = img.draw_image(data, child)
362361
content.extend(cont, child.get_zorder())
363-
# # Really necessary?
364-
# data, children_content = _recurse(data, child)
365-
# content.extend(children_content)
366362
elif isinstance(child, mpl.patches.Patch):
367363
data, cont = patch.draw_patch(data, child)
368364
content.extend(cont, child.get_zorder())
@@ -384,19 +380,15 @@ def _recurse(data, obj):
384380
data = legend.draw_legend(data, child)
385381
if data["legend colors"]:
386382
content.extend(data["legend colors"], 0)
387-
elif isinstance(
388-
child, (mpl.axis.XAxis, mpl.axis.YAxis, mpl.spines.Spine, mpl.text.Text)
389-
):
383+
elif isinstance(child, (mpl.text.Text, mpl.text.Annotation)):
384+
data, cont = text.draw_text(data, child)
385+
content.extend(cont, child.get_zorder())
386+
elif isinstance(child, (mpl.axis.XAxis, mpl.axis.YAxis, mpl.spines.Spine)):
390387
pass
391388
else:
392389
warnings.warn(
393390
"matplotlib2tikz: Don't know how to handle object {}.".format(
394391
type(child)
395392
)
396393
)
397-
# XXX: This is ugly
398-
if isinstance(obj, (mpl.axes.Subplot, mpl.figure.Figure)):
399-
for text in obj.texts:
400-
data, cont = mytext.draw_text(data, text)
401-
content.extend(cont, text.get_zorder())
402394
return data, content.flatten()

matplotlib2tikz/text.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ def draw_text(data, obj):
2121
# -------1--------2---3--4--
2222
pos = obj.get_position()
2323
text = obj.get_text()
24+
25+
if text in ["", data["current axis title"]]:
26+
# Text nodes which are direct children of Axes are typically titles. They are
27+
# already captured by the `title` property of pgfplots axes, so skip them here.
28+
return data, content
29+
2430
size = obj.get_size()
2531
bbox = obj.get_bbox_patch()
2632
converter = mpl.colors.ColorConverter()
@@ -109,8 +115,8 @@ def draw_text(data, obj):
109115
text = text.replace("\n ", "\\\\")
110116

111117
content.append(
112-
"\\node at {}[\n {}\n]{{{} {}}};\n".format(
113-
tikz_pos, ",\n ".join(properties), " ".join(style), text
118+
"\\node at {}[\n {}\n]{{{}}};\n".format(
119+
tikz_pos, ",\n ".join(properties), " ".join(style + [text])
114120
)
115121
)
116122
return data, content
@@ -144,11 +150,27 @@ def _annotation(obj, data, content):
144150
)
145151
return data, content
146152
else: # Create a basic tikz arrow
153+
arrow_translate = {
154+
"-": ["-"],
155+
"->": ["->"],
156+
"<-": ["<-"],
157+
"<->": ["<->"],
158+
"|-|": ["|-|"],
159+
"-|>": ["-latex"],
160+
"<|-": ["latex-"],
161+
"<|-|>": ["latex-latex"],
162+
"]-[": ["|-|"],
163+
"-[": ["-|"],
164+
"]-": ["|-"],
165+
"fancy": ["-latex", "very thick"],
166+
"simple": ["-latex", "very thick"],
167+
"wedge": ["-latex", "very thick"],
168+
}
147169
arrow_style = []
148170
if obj.arrowprops is not None:
149171
if obj.arrowprops["arrowstyle"] is not None:
150-
if obj.arrowprops["arrowstyle"] in ["-", "->", "<-", "<->"]:
151-
arrow_style.append(obj.arrowprops["arrowstyle"])
172+
if obj.arrowprops["arrowstyle"] in arrow_translate:
173+
arrow_style += arrow_translate[obj.arrowprops["arrowstyle"]]
152174
data, col, _ = color.mpl_color2xcolor(
153175
data, obj.arrow_patch.get_ec()
154176
)

test/helpers.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def assert_equality(plot, filename):
4444
assert reference == code, _unidiff_output(reference, code)
4545

4646
code = matplotlib2tikz.get_tikz_code(include_disclaimer=False, standalone=True)
47-
assert _does_compile(code)
47+
assert _compile(code) is not None
4848
return
4949

5050

51-
def _does_compile(code):
51+
def _compile(code):
5252
_, tmp_base = tempfile.mkstemp()
5353

5454
tex_file = tmp_base + ".tex"
@@ -65,12 +65,33 @@ def _does_compile(code):
6565
stderr=subprocess.STDOUT,
6666
)
6767
except subprocess.CalledProcessError as e:
68-
print("Command output:")
68+
print("pdflatex output:")
6969
print("=" * 70)
70-
print(e.output)
70+
print(e.output.decode("utf-8"))
7171
print("=" * 70)
72-
does_compile = False
72+
output_pdf = None
7373
else:
74-
does_compile = True
74+
output_pdf = tmp_base + ".pdf"
7575

76-
return does_compile
76+
return output_pdf
77+
78+
79+
def compare_mpl_latex(plot):
80+
plot()
81+
code = matplotlib2tikz.get_tikz_code(standalone=True)
82+
directory = os.getcwd()
83+
filename = "test-0.png"
84+
plt.savefig(filename)
85+
plt.close()
86+
87+
pdf_file = _compile(code)
88+
pdf_dirname = os.path.dirname(pdf_file)
89+
90+
# Convert PDF to PNG.
91+
subprocess.check_output(
92+
["pdftoppm", "-r", "1000", "-png", pdf_file, "test"], stderr=subprocess.STDOUT
93+
)
94+
png_path = os.path.join(pdf_dirname, "test-1.png")
95+
96+
os.rename(png_path, os.path.join(directory, "test-1.png"))
97+
return

test/test_annotate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,10 @@ def plot():
3434
def test():
3535
assert_equality(plot, "test_annotate_reference.tex")
3636
return
37+
38+
39+
if __name__ == "__main__":
40+
import helpers
41+
42+
# helpers.compare_mpl_latex(plot)
43+
helpers.print_tree(plot())

test/test_annotate_reference.tex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@
4242
anchor=base west,
4343
text=black,
4444
rotate=0.0
45-
]{ text};
45+
]{text};
4646
\node at (axis cs:-50,30)[
4747
scale=0.5,
4848
anchor=base west,
4949
text=black,
5050
rotate=0.0
51-
]{ arrowstyle};
51+
]{arrowstyle};
5252
\end{axis}
5353

5454
\end{tikzpicture}

test/test_arrows.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
import matplotlib.patches as mpatches
4+
import matplotlib.pyplot as plt
5+
6+
7+
# https://matplotlib.org/examples/pylab_examples/fancyarrow_demo.html
8+
def plot():
9+
styles = mpatches.ArrowStyle.get_styles()
10+
11+
ncol = 2
12+
nrow = (len(styles) + 1) // ncol
13+
figheight = nrow + 0.5
14+
fig1 = plt.figure(1, (4.0 * ncol / 1.5, figheight / 1.5))
15+
fontsize = 0.2 * 70
16+
17+
ax = fig1.add_axes([0, 0, 1, 1], frameon=False, aspect=1.0)
18+
19+
ax.set_xlim(0, 4 * ncol)
20+
ax.set_ylim(0, figheight)
21+
22+
def to_texstring(s):
23+
s = s.replace("<", r"$<$")
24+
s = s.replace(">", r"$>$")
25+
s = s.replace("|", r"$|$")
26+
return s
27+
28+
for i, (stylename, styleclass) in enumerate(sorted(styles.items())):
29+
x = 3.2 + (i // nrow) * 4
30+
y = figheight - 0.7 - i % nrow # /figheight
31+
p = mpatches.Circle((x, y), 0.2)
32+
ax.add_patch(p)
33+
34+
ax.annotate(
35+
to_texstring(stylename),
36+
(x, y),
37+
(x - 1.2, y),
38+
# xycoords="figure fraction", textcoords="figure fraction",
39+
ha="right",
40+
va="center",
41+
size=fontsize,
42+
arrowprops=dict(
43+
arrowstyle=stylename,
44+
patchB=p,
45+
shrinkA=5,
46+
shrinkB=5,
47+
fc="k",
48+
ec="k",
49+
connectionstyle="arc3,rad=-0.05",
50+
),
51+
bbox=dict(boxstyle="square", fc="w"),
52+
)
53+
54+
ax.xaxis.set_visible(False)
55+
ax.yaxis.set_visible(False)
56+
return plt.gcf()
57+
58+
59+
# if __name__ == "__main__":
60+
# import helpers
61+
#
62+
# # fig = plot()
63+
# # helpers.print_tree(fig)
64+
# # plt.show()
65+
#
66+
# plot()
67+
# import matplotlib2tikz
68+
# code = matplotlib2tikz.get_tikz_code(include_disclaimer=False, standalone=True)
69+
# plt.close()
70+
# helpers._does_compile(code)
71+
72+
if __name__ == "__main__":
73+
import helpers
74+
75+
helpers.compare_mpl_latex(plot)
76+
# helpers.print_tree(plot())

0 commit comments

Comments
 (0)