diff --git a/nibabel/deprecated.py b/nibabel/deprecated.py index eb3252fe7e..c353071954 100644 --- a/nibabel/deprecated.py +++ b/nibabel/deprecated.py @@ -2,12 +2,15 @@ """ from __future__ import annotations +import typing as ty import warnings -from typing import Type from .deprecator import Deprecator from .pkg_info import cmp_pkg_version +if ty.TYPE_CHECKING: # pragma: no cover + P = ty.ParamSpec('P') + class ModuleProxy: """Proxy for module that may not yet have been imported @@ -30,14 +33,14 @@ class ModuleProxy: module. """ - def __init__(self, module_name): + def __init__(self, module_name: str) -> None: self._module_name = module_name - def __getattr__(self, key): + def __getattr__(self, key: str) -> ty.Any: mod = __import__(self._module_name, fromlist=['']) return getattr(mod, key) - def __repr__(self): + def __repr__(self) -> str: return f'' @@ -60,7 +63,7 @@ class FutureWarningMixin: warn_message = 'This class will be removed in future versions' - def __init__(self, *args, **kwargs): + def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: warnings.warn(self.warn_message, FutureWarning, stacklevel=2) super().__init__(*args, **kwargs) @@ -85,12 +88,12 @@ def alert_future_error( msg: str, version: str, *, - warning_class: Type[Warning] = FutureWarning, - error_class: Type[Exception] = RuntimeError, + warning_class: type[Warning] = FutureWarning, + error_class: type[Exception] = RuntimeError, warning_rec: str = '', error_rec: str = '', stacklevel: int = 2, -): +) -> None: """Warn or error with appropriate messages for changing functionality. Parameters diff --git a/nibabel/deprecator.py b/nibabel/deprecator.py index 251e10d64c..3ef6b45066 100644 --- a/nibabel/deprecator.py +++ b/nibabel/deprecator.py @@ -1,10 +1,16 @@ """Class for recording and reporting deprecations """ +from __future__ import annotations import functools import re +import typing as ty import warnings +if ty.TYPE_CHECKING: # pragma: no cover + T = ty.TypeVar('T') + P = ty.ParamSpec('P') + _LEADING_WHITE = re.compile(r'^(\s*)') TESTSETUP = """ @@ -38,7 +44,7 @@ class ExpiredDeprecationError(RuntimeError): pass -def _ensure_cr(text): +def _ensure_cr(text: str) -> str: """Remove trailing whitespace and add carriage return Ensures that `text` always ends with a carriage return @@ -46,7 +52,12 @@ def _ensure_cr(text): return text.rstrip() + '\n' -def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''): +def _add_dep_doc( + old_doc: str, + dep_doc: str, + setup: str = '', + cleanup: str = '', +) -> str: """Add deprecation message `dep_doc` to docstring in `old_doc` Parameters @@ -55,6 +66,10 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''): Docstring from some object. dep_doc : str Deprecation warning to add to top of docstring, after initial line. + setup : str, optional + Doctest setup text + cleanup : str, optional + Doctest teardown text Returns ------- @@ -76,7 +91,9 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''): if next_line >= len(old_lines): # nothing following first paragraph, just append message return old_doc + '\n' + dep_doc - indent = _LEADING_WHITE.match(old_lines[next_line]).group() + leading_white = _LEADING_WHITE.match(old_lines[next_line]) + assert leading_white is not None # Type narrowing, since this always matches + indent = leading_white.group() setup_lines = [indent + L for L in setup.splitlines()] dep_lines = [indent + L for L in [''] + dep_doc.splitlines() + ['']] cleanup_lines = [indent + L for L in cleanup.splitlines()] @@ -113,15 +130,15 @@ class Deprecator: def __init__( self, - version_comparator, - warn_class=DeprecationWarning, - error_class=ExpiredDeprecationError, - ): + version_comparator: ty.Callable[[str], int], + warn_class: type[Warning] = DeprecationWarning, + error_class: type[Exception] = ExpiredDeprecationError, + ) -> None: self.version_comparator = version_comparator self.warn_class = warn_class self.error_class = error_class - def is_bad_version(self, version_str): + def is_bad_version(self, version_str: str) -> bool: """Return True if `version_str` is too high Tests `version_str` with ``self.version_comparator`` @@ -139,7 +156,14 @@ def is_bad_version(self, version_str): """ return self.version_comparator(version_str) == -1 - def __call__(self, message, since='', until='', warn_class=None, error_class=None): + def __call__( + self, + message: str, + since: str = '', + until: str = '', + warn_class: type[Warning] | None = None, + error_class: type[Exception] | None = None, + ) -> ty.Callable[[ty.Callable[P, T]], ty.Callable[P, T]]: """Return decorator function function for deprecation warning / error Parameters @@ -164,8 +188,8 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non deprecator : func Function returning a decorator. """ - warn_class = warn_class or self.warn_class - error_class = error_class or self.error_class + exception = error_class if error_class is not None else self.error_class + warning = warn_class if warn_class is not None else self.warn_class messages = [message] if (since, until) != ('', ''): messages.append('') @@ -174,19 +198,21 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non if until: messages.append( f"* {'Raises' if self.is_bad_version(until) else 'Will raise'} " - f'{error_class} as of version: {until}' + f'{exception} as of version: {until}' ) message = '\n'.join(messages) - def deprecator(func): + def deprecator(func: ty.Callable[P, T]) -> ty.Callable[P, T]: @functools.wraps(func) - def deprecated_func(*args, **kwargs): + def deprecated_func(*args: P.args, **kwargs: P.kwargs) -> T: if until and self.is_bad_version(until): - raise error_class(message) - warnings.warn(message, warn_class, stacklevel=2) + raise exception(message) + warnings.warn(message, warning, stacklevel=2) return func(*args, **kwargs) keep_doc = deprecated_func.__doc__ + if keep_doc is None: + keep_doc = '' setup = TESTSETUP cleanup = TESTCLEANUP # After expiration, remove all but the first paragraph. diff --git a/nibabel/onetime.py b/nibabel/onetime.py index 8156b1a403..7c723d4c83 100644 --- a/nibabel/onetime.py +++ b/nibabel/onetime.py @@ -19,6 +19,12 @@ [2] Python data model, https://docs.python.org/reference/datamodel.html """ +from __future__ import annotations + +import typing as ty + +InstanceT = ty.TypeVar('InstanceT') +T = ty.TypeVar('T') from nibabel.deprecated import deprecate_with_version @@ -96,26 +102,24 @@ class ResetMixin: 10.0 """ - def reset(self): + def reset(self) -> None: """Reset all OneTimeProperty attributes that may have fired already.""" - instdict = self.__dict__ - classdict = self.__class__.__dict__ # To reset them, we simply remove them from the instance dict. At that # point, it's as if they had never been computed. On the next access, # the accessor function from the parent class will be called, simply # because that's how the python descriptor protocol works. - for mname, mval in classdict.items(): - if mname in instdict and isinstance(mval, OneTimeProperty): + for mname, mval in self.__class__.__dict__.items(): + if mname in self.__dict__ and isinstance(mval, OneTimeProperty): delattr(self, mname) -class OneTimeProperty: +class OneTimeProperty(ty.Generic[T]): """A descriptor to make special properties that become normal attributes. This is meant to be used mostly by the auto_attr decorator in this module. """ - def __init__(self, func): + def __init__(self, func: ty.Callable[[InstanceT], T]) -> None: """Create a OneTimeProperty instance. Parameters @@ -128,24 +132,35 @@ def __init__(self, func): """ self.getter = func self.name = func.__name__ + self.__doc__ = func.__doc__ + + @ty.overload + def __get__( + self, obj: None, objtype: type[InstanceT] | None = None + ) -> ty.Callable[[InstanceT], T]: + ... # pragma: no cover + + @ty.overload + def __get__(self, obj: InstanceT, objtype: type[InstanceT] | None = None) -> T: + ... # pragma: no cover - def __get__(self, obj, type=None): + def __get__( + self, obj: InstanceT | None, objtype: type[InstanceT] | None = None + ) -> T | ty.Callable[[InstanceT], T]: """This will be called on attribute access on the class or instance.""" if obj is None: # Being called on the class, return the original function. This # way, introspection works on the class. - # return func return self.getter - # Errors in the following line are errors in setting a - # OneTimeProperty + # Errors in the following line are errors in setting a OneTimeProperty val = self.getter(obj) - setattr(obj, self.name, val) + obj.__dict__[self.name] = val return val -def auto_attr(func): +def auto_attr(func: ty.Callable[[InstanceT], T]) -> OneTimeProperty[T]: """Decorator to create OneTimeProperty attributes. Parameters diff --git a/nibabel/optpkg.py b/nibabel/optpkg.py index d1eb9d17d5..b59a89bb35 100644 --- a/nibabel/optpkg.py +++ b/nibabel/optpkg.py @@ -1,20 +1,31 @@ """Routines to support optional packages""" +from __future__ import annotations + +import typing as ty +from types import ModuleType + from packaging.version import Version from .tripwire import TripWire -def _check_pkg_version(pkg, min_version): - # Default version checking function - if isinstance(min_version, str): - min_version = Version(min_version) - try: - return min_version <= Version(pkg.__version__) - except AttributeError: +def _check_pkg_version(min_version: str | Version) -> ty.Callable[[ModuleType], bool]: + min_ver = Version(min_version) if isinstance(min_version, str) else min_version + + def check(pkg: ModuleType) -> bool: + pkg_ver = getattr(pkg, '__version__', None) + if isinstance(pkg_ver, str): + return min_ver <= Version(pkg_ver) return False + return check + -def optional_package(name, trip_msg=None, min_version=None): +def optional_package( + name: str, + trip_msg: str | None = None, + min_version: str | Version | ty.Callable[[ModuleType], bool] | None = None, +) -> tuple[ModuleType | TripWire, bool, ty.Callable[[], None]]: """Return package-like thing and module setup for package `name` Parameters @@ -81,7 +92,7 @@ def optional_package(name, trip_msg=None, min_version=None): elif min_version is None: check_version = lambda pkg: True else: - check_version = lambda pkg: _check_pkg_version(pkg, min_version) + check_version = _check_pkg_version(min_version) # fromlist=[''] results in submodule being returned, rather than the top # level module. See help(__import__) fromlist = [''] if '.' in name else [] @@ -107,11 +118,11 @@ def optional_package(name, trip_msg=None, min_version=None): trip_msg = ( f'We need package {name} for these functions, but ``import {name}`` raised {exc}' ) - pkg = TripWire(trip_msg) + trip = TripWire(trip_msg) - def setup_module(): + def setup_module() -> None: import unittest raise unittest.SkipTest(f'No {name} for these tests') - return pkg, False, setup_module + return trip, False, setup_module diff --git a/nibabel/pkg_info.py b/nibabel/pkg_info.py index 73dfd92ed2..061cc3e6d1 100644 --- a/nibabel/pkg_info.py +++ b/nibabel/pkg_info.py @@ -14,7 +14,7 @@ COMMIT_HASH = '$Format:%h$' -def _cmp(a, b) -> int: +def _cmp(a: Version, b: Version) -> int: """Implementation of ``cmp`` for Python 3""" return (a > b) - (a < b) @@ -113,7 +113,7 @@ def pkg_commit_hash(pkg_path: str | None = None) -> tuple[str, str]: return '(none found)', '' -def get_pkg_info(pkg_path: str) -> dict: +def get_pkg_info(pkg_path: str) -> dict[str, str]: """Return dict describing the context of this package Parameters diff --git a/nibabel/processing.py b/nibabel/processing.py index d0a01b52b3..c7bd3888de 100644 --- a/nibabel/processing.py +++ b/nibabel/processing.py @@ -20,7 +20,7 @@ from .optpkg import optional_package -spnd, _, _ = optional_package('scipy.ndimage') +spnd = optional_package('scipy.ndimage')[0] from .affines import AffineError, append_diag, from_matvec, rescale_affine, to_matvec from .imageclasses import spatial_axes_first diff --git a/nibabel/testing/helpers.py b/nibabel/testing/helpers.py index 35b13049f1..2f25a354d7 100644 --- a/nibabel/testing/helpers.py +++ b/nibabel/testing/helpers.py @@ -6,7 +6,7 @@ from ..optpkg import optional_package -_, have_scipy, _ = optional_package('scipy.io') +have_scipy = optional_package('scipy.io')[1] from numpy.testing import assert_array_equal diff --git a/nibabel/tripwire.py b/nibabel/tripwire.py index 3b6ecfbb40..d0c3d4c50c 100644 --- a/nibabel/tripwire.py +++ b/nibabel/tripwire.py @@ -1,5 +1,6 @@ """Class to raise error for missing modules or other misfortunes """ +from typing import Any class TripWireError(AttributeError): @@ -11,7 +12,7 @@ class TripWireError(AttributeError): # is not present. -def is_tripwire(obj): +def is_tripwire(obj: Any) -> bool: """Returns True if `obj` appears to be a TripWire object Examples @@ -44,9 +45,9 @@ class TripWire: TripWireError: We do not have a_module """ - def __init__(self, msg): + def __init__(self, msg: str) -> None: self._msg = msg - def __getattr__(self, attr_name): + def __getattr__(self, attr_name: str) -> Any: """Raise informative error accessing attributes""" raise TripWireError(self._msg)