Skip to content

Commit 47fb865

Browse files
committed
TYP: Annotate tripwire and optpkg modules
Refactor _check_pkg_version to make types clearer. Partial application and lambdas seem hard to mypy.
1 parent aa0bfff commit 47fb865

File tree

4 files changed

+29
-17
lines changed

4 files changed

+29
-17
lines changed

nibabel/optpkg.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
"""Routines to support optional packages"""
2+
from __future__ import annotations
3+
4+
import typing as ty
5+
from types import ModuleType
6+
27
from packaging.version import Version
38

49
from .tripwire import TripWire
510

611

7-
def _check_pkg_version(pkg, min_version):
8-
# Default version checking function
9-
if isinstance(min_version, str):
10-
min_version = Version(min_version)
11-
try:
12-
return min_version <= Version(pkg.__version__)
13-
except AttributeError:
12+
def _check_pkg_version(min_version: str | Version) -> ty.Callable[[ModuleType], bool]:
13+
min_ver = Version(min_version) if isinstance(min_version, str) else min_version
14+
15+
def check(pkg: ModuleType) -> bool:
16+
pkg_ver = getattr(pkg, '__version__', None)
17+
if isinstance(pkg_ver, str):
18+
return min_ver <= Version(pkg_ver)
1419
return False
1520

21+
return check
22+
1623

17-
def optional_package(name, trip_msg=None, min_version=None):
24+
def optional_package(
25+
name: str,
26+
trip_msg: str | None = None,
27+
min_version: str | Version | ty.Callable[[ModuleType], bool] | None = None,
28+
) -> tuple[ModuleType | TripWire, bool, ty.Callable[[], None]]:
1829
"""Return package-like thing and module setup for package `name`
1930
2031
Parameters
@@ -81,7 +92,7 @@ def optional_package(name, trip_msg=None, min_version=None):
8192
elif min_version is None:
8293
check_version = lambda pkg: True
8394
else:
84-
check_version = lambda pkg: _check_pkg_version(pkg, min_version)
95+
check_version = _check_pkg_version(min_version)
8596
# fromlist=[''] results in submodule being returned, rather than the top
8697
# level module. See help(__import__)
8798
fromlist = [''] if '.' in name else []
@@ -107,11 +118,11 @@ def optional_package(name, trip_msg=None, min_version=None):
107118
trip_msg = (
108119
f'We need package {name} for these functions, but ``import {name}`` raised {exc}'
109120
)
110-
pkg = TripWire(trip_msg)
121+
trip = TripWire(trip_msg)
111122

112-
def setup_module():
123+
def setup_module() -> None:
113124
import unittest
114125

115126
raise unittest.SkipTest(f'No {name} for these tests')
116127

117-
return pkg, False, setup_module
128+
return trip, False, setup_module

nibabel/processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from .optpkg import optional_package
2222

23-
spnd, _, _ = optional_package('scipy.ndimage')
23+
spnd = optional_package('scipy.ndimage')[0]
2424

2525
from .affines import AffineError, append_diag, from_matvec, rescale_affine, to_matvec
2626
from .imageclasses import spatial_axes_first

nibabel/testing/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ..optpkg import optional_package
88

9-
_, have_scipy, _ = optional_package('scipy.io')
9+
have_scipy = optional_package('scipy.io')[1]
1010

1111
from numpy.testing import assert_array_equal
1212

nibabel/tripwire.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Class to raise error for missing modules or other misfortunes
22
"""
3+
from typing import Any
34

45

56
class TripWireError(AttributeError):
@@ -11,7 +12,7 @@ class TripWireError(AttributeError):
1112
# is not present.
1213

1314

14-
def is_tripwire(obj):
15+
def is_tripwire(obj: Any) -> bool:
1516
"""Returns True if `obj` appears to be a TripWire object
1617
1718
Examples
@@ -44,9 +45,9 @@ class TripWire:
4445
TripWireError: We do not have a_module
4546
"""
4647

47-
def __init__(self, msg):
48+
def __init__(self, msg: str):
4849
self._msg = msg
4950

50-
def __getattr__(self, attr_name):
51+
def __getattr__(self, attr_name: str) -> Any:
5152
"""Raise informative error accessing attributes"""
5253
raise TripWireError(self._msg)

0 commit comments

Comments
 (0)