Skip to content

Commit 9f13d76

Browse files
committed
Merge pull request #60 from matthew-brett/dft-optional-packages
Make dicom / PIL.Image optional for tests
2 parents 64720a3 + e569221 commit 9f13d76

File tree

4 files changed

+167
-14
lines changed

4 files changed

+167
-14
lines changed

nibabel/dft.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@
99
# Copyright (C) 2011 Christian Haselgrove
1010

1111
import os
12-
import struct
1312
import tempfile
1413
import StringIO
15-
import numpy
16-
import nibabel
1714
import sqlite3
18-
import dicom
15+
16+
import numpy
17+
18+
from .nifti1 import Nifti1Header
19+
20+
# Shield optional dicom import
21+
from .optpkg import optional_package
22+
dicom, have_dicom, _ = optional_package('dicom')
23+
1924

2025
class DFTError(Exception):
2126
"base class for DFT exceptions"
@@ -176,7 +181,7 @@ def as_nifti(self):
176181

177182
m = numpy.array(m)
178183

179-
hdr = nibabel.nifti1.Nifti1Header(endianness='<')
184+
hdr = Nifti1Header(endianness='<')
180185
hdr.set_intent(0)
181186
hdr.set_qform(m, 1)
182187
hdr.set_xyzt_units(2, 8)

nibabel/optpkg.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
""" Routines to support optional packages """
2+
3+
try:
4+
import nose
5+
except ImportError:
6+
have_nose = False
7+
else:
8+
have_nose = True
9+
10+
from .tripwire import TripWire, is_tripwire
11+
12+
def optional_package(name, trip_msg=None):
13+
""" Return package-like thing and module setup for package `name`
14+
15+
Parameters
16+
----------
17+
name : str
18+
package name
19+
trip_msg : None or str
20+
message to give when someone tries to use the return package, but we
21+
could not import it, and have returned a TripWire object instead.
22+
Default message if None.
23+
24+
Returns
25+
-------
26+
pkg_like : module or ``TripWire`` instance
27+
If we can import the package, return it. Otherwise return an object
28+
raising an error when accessed
29+
have_pkg : bool
30+
True if import for package was successful, false otherwise
31+
module_setup : function
32+
callable usually set as ``setup_module`` in calling namespace, to allow
33+
skipping tests.
34+
35+
Example
36+
-------
37+
Typical use would be something like this at the top of a module using an
38+
optional package:
39+
40+
>>> from nipy.utils.optpkg import optional_package
41+
>>> pkg, have_pkg, setup_module = optional_package('not_a_package')
42+
43+
Of course in this case the package doesn't exist, and so, in the module:
44+
45+
>>> have_pkg
46+
False
47+
48+
and
49+
50+
>>> pkg.some_function()
51+
Traceback (most recent call last):
52+
...
53+
TripWireError: We need package not_a_package for these functions, but ``import not_a_package`` raised an ImportError
54+
55+
If the module does exist - we get the module
56+
57+
>>> pkg, _, _ = optional_package('os')
58+
>>> hasattr(pkg, 'path')
59+
True
60+
61+
Or a submodule if that's what we asked for
62+
63+
>>> subpkg, _, _ = optional_package('os.path')
64+
>>> hasattr(subpkg, 'dirname')
65+
True
66+
"""
67+
# fromlist=[''] results in submodule being returned, rather than the top
68+
# level module. See help(__import__)
69+
try:
70+
pkg = __import__(name, fromlist=[''])
71+
except ImportError:
72+
pass
73+
else: # import worked
74+
# top level module
75+
return pkg, True, lambda : None
76+
if trip_msg is None:
77+
trip_msg = ('We need package %s for these functions, but '
78+
'``import %s`` raised an ImportError'
79+
% (name, name))
80+
pkg = TripWire(trip_msg)
81+
def setup_module():
82+
if have_nose:
83+
raise nose.plugins.skip.SkipTest('No %s for these tests'
84+
% name)
85+
return pkg, False, setup_module
86+

nibabel/tests/test_dft.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,28 @@
33

44
from os.path import join as pjoin, dirname
55
import StringIO
6-
import PIL.Image
7-
from nose.tools import assert_true, assert_false, assert_equal, assert_raises
6+
7+
import numpy as np
8+
89
from .. import dft
910
from .. import nifti1
1011

12+
from nose.tools import (assert_true, assert_false, assert_equal, assert_raises)
13+
14+
# Shield optional package imports
15+
from ..optpkg import optional_package
16+
# setup_module will raise SkipTest if no dicom to import
17+
dicom, have_dicom, setup_module = optional_package('dicom')
18+
PImage, have_pil, _ = optional_package('PIL.Image')
19+
pil_test = np.testing.dec.skipif(not have_pil, 'could not import PIL.Image')
20+
1121
data_dir = pjoin(dirname(__file__), 'data')
1222

23+
1324
def test_init():
1425
dft.clear_cache()
1526
dft.update_cache(data_dir)
16-
return
27+
1728

1829
def test_study():
1930
studies = dft.get_studies(data_dir)
@@ -42,32 +53,35 @@ def test_series():
4253
assert_equal(ser.bits_allocated, 16)
4354
assert_equal(ser.bits_stored, 12)
4455

56+
4557
def test_storage_instances():
4658
studies = dft.get_studies(data_dir)
4759
sis = studies[0].series[0].storage_instances
4860
assert_equal(len(sis), 2)
4961
assert_equal(sis[0].instance_number, 1)
5062
assert_equal(sis[1].instance_number, 2)
51-
assert_equal(sis[0].uid,
63+
assert_equal(sis[0].uid,
5264
'1.3.12.2.1107.5.2.32.35119.2010011420300180088599504.0')
53-
assert_equal(sis[1].uid,
65+
assert_equal(sis[1].uid,
5466
'1.3.12.2.1107.5.2.32.35119.2010011420300180088599504.1')
5567

68+
5669
def test_storage_instance():
57-
return
70+
pass
5871

72+
73+
@pil_test
5974
def test_png():
6075
studies = dft.get_studies(data_dir)
6176
data = studies[0].series[0].as_png()
62-
im = PIL.Image.open(StringIO.StringIO(data))
77+
im = PImage.open(StringIO.StringIO(data))
6378
assert_equal(im.size, (256, 256))
64-
return
79+
6580

6681
def test_nifti():
6782
studies = dft.get_studies(data_dir)
6883
data = studies[0].series[0].as_nifti()
6984
assert_equal(len(data), 352 + 2*256*256*2)
7085
h = nifti1.Nifti1Header(data[:348])
7186
assert_equal(h.get_data_shape(), (256, 256, 2))
72-
return
7387

nibabel/tripwire.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
""" Class to raise error for missing modules or other misfortunes
2+
"""
3+
4+
class TripWireError(Exception):
5+
""" Exception if trying to use TripWire object """
6+
7+
8+
def is_tripwire(obj):
9+
""" Returns True if `obj` appears to be a TripWire object
10+
11+
Examples
12+
--------
13+
>>> is_tripwire(object())
14+
False
15+
>>> is_tripwire(TripWire('some message'))
16+
True
17+
"""
18+
try:
19+
obj.any_attribute
20+
except TripWireError:
21+
return True
22+
except:
23+
pass
24+
return False
25+
26+
27+
class TripWire(object):
28+
""" Class raising error if used
29+
30+
Standard use is to proxy modules that we could not import
31+
32+
Examples
33+
--------
34+
>>> try:
35+
... import silly_module_name
36+
... except ImportError:
37+
... silly_module_name = TripWire('We do not have silly_module_name')
38+
>>> silly_module_name.do_silly_thing('with silly string')
39+
Traceback (most recent call last):
40+
...
41+
TripWireError: We do not have silly_module_name
42+
"""
43+
def __init__(self, msg):
44+
self._msg = msg
45+
46+
def __getattr__(self, attr_name):
47+
''' Raise informative error accessing attributes '''
48+
raise TripWireError(self._msg)

0 commit comments

Comments
 (0)