Skip to content

Commit 4d2404d

Browse files
committed
NF: add min version argumemt to optpkg and test
Add ability to specify minimum version of package when asking for an optional package.
1 parent 53a8606 commit 4d2404d

File tree

2 files changed

+107
-4
lines changed

2 files changed

+107
-4
lines changed

nibabel/optpkg.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
""" Routines to support optional packages """
2+
from distutils.version import LooseVersion
3+
4+
from .externals.six import string_types, callable
25

36
try:
47
import nose
@@ -10,7 +13,17 @@
1013
from .tripwire import TripWire
1114

1215

13-
def optional_package(name, trip_msg=None):
16+
def _check_pkg_version(pkg, min_version):
17+
# Default version checking function
18+
if isinstance(min_version, string_types):
19+
min_version = LooseVersion(min_version)
20+
try:
21+
return min_version <= pkg.__version__
22+
except AttributeError:
23+
return False
24+
25+
26+
def optional_package(name, trip_msg=None, min_version=None):
1427
""" Return package-like thing and module setup for package `name`
1528
1629
Parameters
@@ -19,8 +32,14 @@ def optional_package(name, trip_msg=None):
1932
package name
2033
trip_msg : None or str
2134
message to give when someone tries to use the return package, but we
22-
could not import it, and have returned a TripWire object instead.
23-
Default message if None.
35+
could not import it at an acceptable version, and have returned a
36+
TripWire object instead. Default message if None.
37+
min_version : None or str or LooseVersion or callable
38+
If None, do not specify a minimum version. If str, convert to a
39+
`distutils.version.LooseVersion`. If str or LooseVersion` compare to
40+
version of package `name` with ``min_version <= pkg.__version__``. If
41+
callable, accepts imported ``pkg`` as argument, and returns value of
42+
callable is True for acceptable package versions, False otherwise.
2443
2544
Returns
2645
-------
@@ -66,6 +85,12 @@ def optional_package(name, trip_msg=None):
6685
>>> hasattr(subpkg, 'dirname')
6786
True
6887
"""
88+
if callable(min_version):
89+
check_version = min_version
90+
elif min_version is None:
91+
check_version = lambda pkg: True
92+
else:
93+
check_version = lambda pkg: _check_pkg_version(pkg, min_version)
6994
# fromlist=[''] results in submodule being returned, rather than the top
7095
# level module. See help(__import__)
7196
fromlist = [''] if '.' in name else []
@@ -75,7 +100,15 @@ def optional_package(name, trip_msg=None):
75100
pass
76101
else: # import worked
77102
# top level module
78-
return pkg, True, lambda: None
103+
if check_version(pkg):
104+
return pkg, True, lambda: None
105+
# Failed version check
106+
if trip_msg is None:
107+
if callable(min_version):
108+
trip_msg = 'Package %s fails version check' % min_version
109+
else:
110+
trip_msg = ('These functions need %s version >= %s' %
111+
(name, min_version))
79112
if trip_msg is None:
80113
trip_msg = ('We need package %s for these functions, but '
81114
'``import %s`` raised an ImportError'

nibabel/tests/test_optpkg.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
""" Testing optpkg module
2+
"""
3+
4+
import types
5+
import sys
6+
from distutils.version import LooseVersion
7+
8+
from nose import SkipTest
9+
from nose.tools import (assert_true, assert_false, assert_raises,
10+
assert_equal)
11+
12+
13+
from nibabel.optpkg import optional_package
14+
from nibabel.tripwire import TripWire, TripWireError
15+
16+
17+
def assert_good(pkg_name, min_version=None):
18+
pkg, have_pkg, setup = optional_package(pkg_name, min_version=min_version)
19+
assert_true(have_pkg)
20+
assert_equal(sys.modules[pkg_name], pkg)
21+
assert_equal(setup(), None)
22+
23+
24+
def assert_bad(pkg_name, min_version=None):
25+
pkg, have_pkg, setup = optional_package(pkg_name, min_version=min_version)
26+
assert_false(have_pkg)
27+
assert_true(isinstance(pkg, TripWire))
28+
assert_raises(TripWireError, getattr, pkg, 'a_method')
29+
assert_raises(SkipTest, setup)
30+
31+
32+
def test_basic():
33+
# We always have os
34+
assert_good('os')
35+
# Subpackage
36+
assert_good('os.path')
37+
# We never have package _not_a_package
38+
assert_bad('_not_a_package')
39+
40+
41+
def test_versions():
42+
fake_name = '_a_fake_package'
43+
fake_pkg = types.ModuleType(fake_name)
44+
assert_false('fake_pkg' in sys.modules)
45+
# Not inserted yet
46+
assert_bad(fake_name)
47+
try:
48+
sys.modules[fake_name] = fake_pkg
49+
# No __version__ yet
50+
assert_good(fake_name) # With no version check
51+
assert_bad(fake_name, '1.0')
52+
# We can make an arbitrary callable to check version
53+
assert_good(fake_name, lambda pkg: True)
54+
# Now add a version
55+
fake_pkg.__version__ = '2.0'
56+
# We have fake_pkg > 1.0
57+
for min_ver in (None, '1.0', LooseVersion('1.0'), lambda pkg: True):
58+
assert_good(fake_name, min_ver)
59+
# We never have fake_pkg > 100.0
60+
for min_ver in ('100.0', LooseVersion('100.0'), lambda pkg: False):
61+
assert_bad(fake_name, min_ver)
62+
# Check error string for bad version
63+
pkg, _, _ = optional_package(fake_name, min_version='3.0')
64+
try:
65+
pkg.some_method
66+
except TripWireError as err:
67+
assert_equal(str(err),
68+
'These functions need _a_fake_package version >= 3.0')
69+
finally:
70+
del sys.modules[fake_name]

0 commit comments

Comments
 (0)