Skip to content

Commit 5791ebc

Browse files
committed
fix image tests
1 parent 3c4ae4f commit 5791ebc

File tree

4 files changed

+34
-21
lines changed

4 files changed

+34
-21
lines changed

matplotlib2tikz/image.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,15 @@ def draw_image(data, obj):
4343
# RGB (+alpha) information at each point
4444
assert len(dims) == 3 and dims[2] in [3, 4]
4545
# convert to PIL image
46-
if obj.origin == "lower":
46+
if obj.origin == 'lower':
4747
img_array = numpy.flipud(img_array)
48-
image = PIL.Image.fromarray(img_array)
48+
49+
# Convert mpl image to PIL
50+
image = PIL.Image.fromarray(numpy.uint8(img_array*255))
51+
52+
# If the input image is PIL:
53+
# image = PIL.Image.fromarray(img_array)
54+
4955
image.save(filename, origin=obj.origin)
5056

5157
# write the corresponding information to the TikZ file

test/test_annotate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def plot():
3939
def test():
4040
plt.close('all')
4141
phash = Phash(plot())
42-
assert phash.phash == 'ab8a71a1549e54be', phash.get_details()
42+
assert phash.phash == 'ab8a79a1549654de', phash.get_details()
4343

4444

4545
if __name__ == '__main__':

test/test_fancybox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def plot():
193193

194194
def test():
195195
phash = helpers.Phash(plot())
196-
assert phash.phash == 'dd2325dc23cdd81a', phash.get_details()
196+
assert phash.phash == 'dd2325d823cdd85a', phash.get_details()
197197

198198

199199
if __name__ == '__main__':

test/test_image_plot.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,51 @@
11
# -*- coding: utf-8 -*-
22
#
33
import helpers
4+
5+
import matplotlib.pyplot as plt
46
import pytest
57

68
# the picture 'lena.png' with origin='lower' is flipped upside-down.
79
# So it has to be upside-down in the pdf-file as well.
810

911

10-
# test for monochrome picture
11-
def plot1():
12+
def plot_upper():
1213
from matplotlib import rcParams
13-
import matplotlib.pyplot as plt
14-
from PIL import Image
14+
import matplotlib.image as mpimg
1515
import os
1616

1717
this_dir = os.path.dirname(os.path.realpath(__file__))
18-
lena = Image.open(os.path.join(this_dir, 'lena.png'))
19-
lena = lena.convert('L')
18+
img = mpimg.imread(os.path.join(this_dir, 'lena.png'))
19+
2020
dpi = rcParams['figure.dpi']
21-
figsize = lena.size[0]/dpi, lena.size[1]/dpi
21+
figsize = img.shape[0]/dpi, img.shape[1]/dpi
2222
fig = plt.figure(figsize=figsize)
2323
ax = plt.axes([0, 0, 1, 1], frameon=False)
2424
ax.set_axis_off()
25-
plt.imshow(lena, cmap='viridis', origin='lower')
25+
26+
plt.imshow(img, cmap='viridis', origin='upper')
27+
2628
# Set the current color map to HSV.
2729
plt.hsv()
2830
plt.colorbar()
2931
return fig
3032

3133

32-
# test for rgb picture
33-
def plot2():
34+
def plot_lower():
3435
from matplotlib import rcParams
35-
import matplotlib.pyplot as plt
36-
from PIL import Image
36+
import matplotlib.image as mpimg
3737
import os
3838

3939
this_dir = os.path.dirname(os.path.realpath(__file__))
40-
lena = Image.open(os.path.join(this_dir, 'lena.png'))
40+
img = mpimg.imread(os.path.join(this_dir, 'lena.png'))
41+
4142
dpi = rcParams['figure.dpi']
42-
figsize = lena.size[0] / dpi, lena.size[1] / dpi
43+
figsize = img.shape[0] / dpi, img.shape[1] / dpi
44+
4345
fig = plt.figure(figsize=figsize)
4446
ax = plt.axes([0, 0, 1, 1], frameon=False)
4547
ax.set_axis_off()
46-
plt.imshow(lena, cmap='viridis', origin='lower')
48+
plt.imshow(img, cmap='viridis', origin='lower')
4749
# Set the current color map to HSV.
4850
plt.hsv()
4951
plt.colorbar()
@@ -52,11 +54,16 @@ def plot2():
5254

5355
@pytest.mark.parametrize(
5456
'plot, reference_phash', [
55-
(plot1, '455361ec211d72fb'),
56-
(plot2, '7558d3b30f634b06'),
57+
(plot_upper, '75c3d36d1f090ba1'),
58+
(plot_lower, '7548d3b34f234b07'),
5759
]
5860
)
5961
def test(plot, reference_phash):
6062
phash = helpers.Phash(plot())
6163
assert phash.phash == reference_phash, phash.get_details()
6264
return
65+
66+
67+
if __name__ == '__main__':
68+
plot_upper()
69+
plt.show()

0 commit comments

Comments
 (0)