28
28
#
29
29
# https://github.com/astrofrog/wcsaxes
30
30
31
- from functools import wraps
32
-
33
- import contextlib
34
31
import io
35
32
import os
36
- import sys
37
33
import json
38
34
import shutil
35
+ import hashlib
39
36
import inspect
40
37
import tempfile
41
38
import warnings
42
- import hashlib
43
- from distutils .version import LooseVersion
39
+ import contextlib
44
40
from pathlib import Path
41
+ from functools import wraps
42
+ from urllib .request import urlopen
45
43
46
44
import pytest
47
45
48
- if sys .version_info [0 ] == 2 :
49
- from urllib import urlopen
50
- string_types = basestring # noqa
51
- else :
52
- from urllib .request import urlopen
53
- string_types = str
54
-
55
-
56
46
SHAPE_MISMATCH_ERROR = """Error: Image dimensions did not match.
57
47
Expected shape: {expected_shape}
58
48
{expected_path}
@@ -74,11 +64,11 @@ def _download_file(baseline, filename):
74
64
else :
75
65
raise Exception ("Could not download baseline image from any of the "
76
66
"available URLs" )
77
- result_dir = tempfile .mkdtemp ()
78
- filename = os . path . join ( result_dir , 'downloaded' )
79
- with open (filename , 'wb' ) as tmpfile :
67
+ result_dir = Path ( tempfile .mkdtemp () )
68
+ filename = result_dir / 'downloaded'
69
+ with open (str ( filename ) , 'wb' ) as tmpfile :
80
70
tmpfile .write (content )
81
- return filename
71
+ return Path ( filename )
82
72
83
73
84
74
def _hash_file (in_stream ):
@@ -212,6 +202,10 @@ def get_marker(item, marker_name):
212
202
return item .keywords .get (marker_name )
213
203
214
204
205
+ def path_is_not_none (apath ):
206
+ return Path (apath ) if apath is not None else apath
207
+
208
+
215
209
class ImageComparison (object ):
216
210
217
211
def __init__ (self ,
@@ -225,13 +219,13 @@ def __init__(self,
225
219
):
226
220
self .config = config
227
221
self .baseline_dir = baseline_dir
228
- self .baseline_relative_dir = baseline_relative_dir
229
- self .generate_dir = generate_dir
230
- self .results_dir = results_dir
231
- self .hash_library = hash_library
232
- self .generate_hash_library = generate_hash_library
233
- if self .results_dir and not os . path .exists (self . results_dir ):
234
- os . mkdir ( self .results_dir )
222
+ self .baseline_relative_dir = path_is_not_none ( baseline_relative_dir )
223
+ self .generate_dir = path_is_not_none ( generate_dir )
224
+ self .results_dir = path_is_not_none ( results_dir )
225
+ self .hash_library = path_is_not_none ( hash_library )
226
+ self .generate_hash_library = path_is_not_none ( generate_hash_library )
227
+ if self .results_dir and not self . results_dir .exists ():
228
+ self .results_dir . mkdir ( )
235
229
236
230
# We need global state to store all the hashes generated over the run
237
231
self ._generated_hash_library = {}
@@ -261,7 +255,7 @@ def make_results_dir(self, item):
261
255
"""
262
256
Generate the directory to put the results in.
263
257
"""
264
- return tempfile .mkdtemp (dir = self .results_dir )
258
+ return Path ( tempfile .mkdtemp (dir = self .results_dir ) )
265
259
266
260
def get_baseline_directory (self , item ):
267
261
"""
@@ -274,21 +268,19 @@ def get_baseline_directory(self, item):
274
268
baseline_dir = compare .kwargs .get ('baseline_dir' , None )
275
269
if baseline_dir is None :
276
270
if self .baseline_dir is None :
277
- baseline_dir = os . path . join ( os . path . dirname ( item .fspath . strpath ), 'baseline' )
271
+ baseline_dir = Path ( item .fspath ). parent / 'baseline'
278
272
else :
279
273
if self .baseline_relative_dir :
280
274
# baseline dir is relative to the current test
281
- baseline_dir = os .path .join (
282
- os .path .dirname (item .fspath .strpath ),
283
- self .baseline_relative_dir
284
- )
275
+ baseline_dir = Path (item .fspath ).parent / self .baseline_relative_dir
285
276
else :
286
277
# baseline dir is relative to where pytest was run
287
278
baseline_dir = self .baseline_dir
288
279
289
- baseline_remote = baseline_dir .startswith (('http://' , 'https://' ))
280
+ baseline_remote = (isinstance (baseline_dir , str ) and # noqa
281
+ baseline_dir .startswith (('http://' , 'https://' )))
290
282
if not baseline_remote :
291
- return os . path . join ( os . path . dirname ( item .fspath . strpath ), baseline_dir )
283
+ return Path ( item .fspath ). parent / baseline_dir
292
284
293
285
return baseline_dir
294
286
@@ -301,13 +293,14 @@ def obtain_baseline_image(self, item, target_dir):
301
293
"""
302
294
filename = self .generate_filename (item )
303
295
baseline_dir = self .get_baseline_directory (item )
304
- baseline_remote = baseline_dir .startswith (('http://' , 'https://' ))
296
+ baseline_remote = (isinstance (baseline_dir , str ) and # noqa
297
+ baseline_dir .startswith (('http://' , 'https://' )))
305
298
if baseline_remote :
306
299
# baseline_dir can be a list of URLs when remote, so we have to
307
300
# pass base and filename to download
308
301
baseline_image = _download_file (baseline_dir , filename )
309
302
else :
310
- baseline_image = os . path . abspath ( os . path . join ( baseline_dir , filename ))
303
+ baseline_image = ( baseline_dir / filename ). absolute ( )
311
304
312
305
return baseline_image
313
306
@@ -321,10 +314,11 @@ def generate_baseline_image(self, item, fig):
321
314
if not os .path .exists (self .generate_dir ):
322
315
os .makedirs (self .generate_dir )
323
316
324
- fig .savefig (os . path . abspath ( os . path . join (self .generate_dir , self .generate_filename (item ))),
317
+ fig .savefig (str ( (self .generate_dir / self .generate_filename (item )). absolute ( )),
325
318
** savefig_kwargs )
319
+
326
320
close_mpl_figure (fig )
327
- pytest .skip ("Skipping test, since generating data " )
321
+ pytest .skip ("Skipping test, since generating image " )
328
322
329
323
def generate_hash_name (self , item ):
330
324
"""
@@ -363,8 +357,8 @@ def compare_image_to_baseline(self, item, fig, result_dir):
363
357
364
358
baseline_image_ref = self .obtain_baseline_image (item , result_dir )
365
359
366
- test_image = os . path . abspath ( os . path . join ( result_dir , self .generate_filename (item )))
367
- fig .savefig (test_image , ** savefig_kwargs )
360
+ test_image = ( result_dir / self .generate_filename (item )). absolute ( )
361
+ fig .savefig (str ( test_image ) , ** savefig_kwargs )
368
362
369
363
if not os .path .exists (baseline_image_ref ):
370
364
pytest .fail ("Image file not found for comparison test in: "
@@ -376,38 +370,32 @@ def compare_image_to_baseline(self, item, fig, result_dir):
376
370
377
371
# distutils may put the baseline images in non-accessible places,
378
372
# copy to our tmpdir to be sure to keep them in case of failure
379
- baseline_image = os .path .abspath (
380
- os .path .join (result_dir ,
381
- 'baseline-' + self .generate_filename (item ))
382
- )
373
+ baseline_image = (result_dir / f"baseline-{ self .generate_filename (item )} " ).absolute ()
383
374
shutil .copyfile (baseline_image_ref , baseline_image )
384
375
385
376
# Compare image size ourselves since the Matplotlib
386
377
# exception is a bit cryptic in this case and doesn't show
387
378
# the filenames
388
- expected_shape = imread (baseline_image ).shape [:2 ]
389
- actual_shape = imread (test_image ).shape [:2 ]
379
+ expected_shape = imread (str ( baseline_image ) ).shape [:2 ]
380
+ actual_shape = imread (str ( test_image ) ).shape [:2 ]
390
381
if expected_shape != actual_shape :
391
382
error = SHAPE_MISMATCH_ERROR .format (expected_path = baseline_image ,
392
383
expected_shape = expected_shape ,
393
384
actual_path = test_image ,
394
385
actual_shape = actual_shape )
395
386
pytest .fail (error , pytrace = False )
396
387
397
- return compare_images (baseline_image , test_image , tol = tolerance )
388
+ return compare_images (str ( baseline_image ), str ( test_image ) , tol = tolerance )
398
389
399
390
def load_hash_library (self , library_path ):
400
- with open (library_path ) as fp :
391
+ with open (str ( library_path ) ) as fp :
401
392
return json .load (fp )
402
393
403
394
def compare_image_to_hash_library (self , item , fig , result_dir ):
404
395
compare = self .get_compare (item )
405
396
406
397
hash_library_filename = self .hash_library or compare .kwargs .get ('hash_library' , None )
407
- hash_library_filename = os .path .abspath (
408
- os .path .join (os .path .dirname (item .fspath .strpath ),
409
- hash_library_filename )
410
- )
398
+ hash_library_filename = (Path (item .fspath ).parent / hash_library_filename ).absolute ()
411
399
412
400
if not Path (hash_library_filename ).exists ():
413
401
pytest .fail (f"Can't find hash library at path { hash_library_filename } " )
@@ -432,23 +420,17 @@ def pytest_runtest_setup(self, item): # noqa
432
420
if compare is None :
433
421
return
434
422
435
- import matplotlib
436
423
import matplotlib .pyplot as plt
437
424
try :
438
425
from matplotlib .testing .decorators import remove_ticks_and_titles
439
426
except ImportError :
440
427
from matplotlib .testing .decorators import ImageComparisonTest as MplImageComparisonTest
441
428
remove_ticks_and_titles = MplImageComparisonTest .remove_text
442
429
443
- MPL_LT_15 = LooseVersion (matplotlib .__version__ ) < LooseVersion ('1.5' )
444
-
445
430
style = compare .kwargs .get ('style' , 'classic' )
446
431
remove_text = compare .kwargs .get ('remove_text' , False )
447
432
backend = compare .kwargs .get ('backend' , 'agg' )
448
433
449
- if MPL_LT_15 and style == 'classic' :
450
- style = os .path .join (os .path .dirname (__file__ ), 'classic.mplstyle' )
451
-
452
434
original = item .function
453
435
454
436
@wraps (item .function )
0 commit comments