Skip to content

Commit c95d367

Browse files
authored
Merge pull request #163 from danielhkl/master
extra_axis_parameters and extra_tikzpicture_parameters
2 parents 736ef34 + 5674d1b commit c95d367

File tree

5 files changed

+73
-27
lines changed

5 files changed

+73
-27
lines changed

matplotlib2tikz/image.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ def draw_image(data, obj):
2727
# store the image as in a file
2828
img_array = obj.get_array()
2929

30+
if obj.origin == "lower":
31+
img_array = numpy.flipud(img_array)
32+
3033
dims = img_array.shape
3134
if len(dims) == 2: # the values are given as one real number: look at cmap
3235
clims = obj.get_clim()
3336

34-
if obj.origin == "lower":
35-
img_array = numpy.flipud(img_array)
36-
3737
mpl.pyplot.imsave(fname=filename,
3838
arr=img_array,
3939
cmap=obj.get_cmap(),
@@ -44,9 +44,9 @@ def draw_image(data, obj):
4444
else:
4545
# RGB (+alpha) information at each point
4646
assert len(dims) == 3 and dims[2] in [3, 4]
47-
# convert to PIL image (after upside-down flip)
48-
image = PIL.Image.fromarray(numpy.flipud(img_array))
49-
image.save(filename)
47+
# convert to PIL image
48+
image = PIL.Image.fromarray(img_array)
49+
image.save(filename, origin=obj.origin)
5050

5151
# write the corresponding information to the TikZ file
5252
extent = obj.get_extent()

matplotlib2tikz/save.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ def get_tikz_code(
2626
tex_relative_path_to_data=None,
2727
strict=False,
2828
wrap=True,
29-
extra=None,
29+
extra_axis_parameters=None,
30+
extra_tikzpicture_parameters=None,
3031
dpi=None,
3132
show_info=True
32-
):
33+
):
3334
'''Main function. Here, the recursion into the image starts and the
3435
contents are picked up. The actual file gets written in this routine.
3536
@@ -83,9 +84,15 @@ def get_tikz_code(
8384
Default is ``True``.
8485
:type wrap: bool
8586
86-
:param extra: Extra axis options to be passed (as a set) to pgfplots.
87-
Default is ``None``.
88-
:type extra: a set of strings for the pfgplots axes.
87+
:param extra_axis_parameters: Extra axis options to be passed (as a set)
88+
to pgfplots. Default is ``None``.
89+
:type extra_axis_parameters: a set of strings for the pfgplots axes.
90+
91+
:param extra_tikzpicture_parameters: Extra tikzpicture options to be passed
92+
(as a set) to pgfplots.
93+
94+
:type extra_tikzpicture_parameters: a set of strings for the pfgplots
95+
tikzpicture.
8996
9097
:param dpi: The resolution in dots per inch of the rendered image in case
9198
of QuadMesh plots. If ``None`` it will default to the value
@@ -117,12 +124,13 @@ def get_tikz_code(
117124
data['pgfplots libs'] = set()
118125
data['font size'] = textsize
119126
data['custom colors'] = {}
127+
data['extra tikzpicture parameters'] = extra_tikzpicture_parameters
120128
# rectangle_legends is used to keep track of which rectangles have already
121129
# had \addlegendimage added. There should be only one \addlegenimage per
122130
# bar chart data series.
123131
data['rectangle_legends'] = set()
124-
if extra:
125-
data['extra axis options [base]'] = extra.copy()
132+
if extra_axis_parameters:
133+
data['extra axis options [base]'] = extra_axis_parameters.copy()
126134
else:
127135
data['extra axis options [base]'] = set()
128136

@@ -152,6 +160,9 @@ def get_tikz_code(
152160
# write the contents
153161
if wrap:
154162
code += '\\begin{tikzpicture}\n\n'
163+
if extra_tikzpicture_parameters:
164+
code += ',\n'.join(data['extra tikzpicture parameters'])
165+
code += '\n'
155166

156167
coldefs = _get_color_definitions(data)
157168
if coldefs:
@@ -255,7 +266,7 @@ def _recurse(data, obj):
255266
content = _ContentManager()
256267
for child in obj.get_children():
257268
if isinstance(child, mpl.axes.Axes):
258-
# Reset 'extra axis options' for every new Axes environment.
269+
# Reset 'extra axis parameters' for every new Axes environment.
259270
data['extra axis options'] = \
260271
data['extra axis options [base]'].copy()
261272

test/helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import matplotlib2tikz
44

55
import os
6+
import shutil
67
import tempfile
78
import subprocess
89
from PIL import Image
@@ -93,6 +94,9 @@ def assert_phash(fig, reference_phash):
9394
compute_phash(fig)
9495

9596
if reference_phash != phash:
97+
# Copy pdf_file in test directory
98+
shutil.copy(pdf_file, os.path.dirname(os.path.abspath(__file__)))
99+
96100
# Compute the Hamming distance between the two 64-bit numbers
97101
hamming_dist = \
98102
bin(int(phash, 16) ^ int(reference_phash, 16)).count('1')

test/test_image_plot.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,61 @@
11
# -*- coding: utf-8 -*-
22
#
33
import helpers
4+
from matplotlib import rcParams
5+
from matplotlib import pyplot as pp
6+
import pytest
7+
import os
48

9+
try:
10+
from PIL import Image
11+
except ImportError:
12+
raise RuntimeError('PIL must be installed to run this example')
513

6-
def plot():
7-
from matplotlib import rcParams
8-
from matplotlib import pyplot as pp
9-
import os
10-
try:
11-
from PIL import Image
12-
except ImportError:
13-
raise RuntimeError('PIL must be installed to run this example')
14+
# the picture 'lena.png' with origin='lower' is flipped upside-down.
15+
# So it has to be upside-down in the pdf-file as well.
16+
17+
18+
# test for monochrome picture
19+
def plot1():
1420

1521
this_dir = os.path.dirname(os.path.realpath(__file__))
1622
lena = Image.open(os.path.join(this_dir, 'lena.png'))
23+
lena = lena.convert('L')
1724
dpi = rcParams['figure.dpi']
1825
figsize = lena.size[0]/dpi, lena.size[1]/dpi
1926
fig = pp.figure(figsize=figsize)
2027
ax = pp.axes([0, 0, 1, 1], frameon=False)
2128
ax.set_axis_off()
22-
pp.imshow(lena, origin='lower')
29+
pp.imshow(lena, cmap='viridis', origin='lower')
30+
# Set the current color map to HSV.
31+
pp.hsv()
32+
pp.colorbar()
33+
return fig
34+
35+
36+
# test for rgb picture
37+
def plot2():
38+
39+
this_dir = os.path.dirname(os.path.realpath(__file__))
40+
lena = Image.open(os.path.join(this_dir, 'lena.png'))
41+
dpi = rcParams['figure.dpi']
42+
figsize = lena.size[0] / dpi, lena.size[1] / dpi
43+
fig = pp.figure(figsize=figsize)
44+
ax = pp.axes([0, 0, 1, 1], frameon=False)
45+
ax.set_axis_off()
46+
pp.imshow(lena, cmap='viridis', origin='lower')
2347
# Set the current color map to HSV.
2448
pp.hsv()
2549
pp.colorbar()
2650
return fig
2751

2852

29-
def test():
30-
helpers.assert_phash(plot(), '7558d3b30f634b06')
53+
@pytest.mark.parametrize(
54+
'plot, phash', [
55+
(plot1, '455361ec211d72fb'),
56+
(plot2, '7558d3b30f634b06'),
57+
]
58+
)
59+
def test(plot, phash):
60+
helpers.assert_phash(plot(), phash)
61+
return

test/test_rotated_labels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_rotated_labels_parameters(x_alignment, y_alignment,
6262
matplotlib2tikz.save(
6363
tikz_file,
6464
figurewidth='7.5cm',
65-
extra=extra_dict
65+
extra_axis_parameters=extra_dict
6666
)
6767

6868
# close figure
@@ -101,7 +101,7 @@ def test_rotated_labels_parameters_different_values(x_tick_label_width,
101101
matplotlib2tikz.save(
102102
tikz_file,
103103
figurewidth='7.5cm',
104-
extra=extra_dict
104+
extra_axis_parameters=extra_dict
105105
)
106106

107107
# close figure

0 commit comments

Comments
 (0)