Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions nibabel/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
"""
from __future__ import annotations

import typing as ty
import warnings
from typing import Type

from .deprecator import Deprecator
from .pkg_info import cmp_pkg_version

if ty.TYPE_CHECKING: # pragma: no cover
P = ty.ParamSpec('P')


class ModuleProxy:
"""Proxy for module that may not yet have been imported
Expand All @@ -30,14 +33,14 @@ class ModuleProxy:
module.
"""

def __init__(self, module_name):
def __init__(self, module_name: str) -> None:
self._module_name = module_name

def __getattr__(self, key):
def __getattr__(self, key: str) -> ty.Any:
mod = __import__(self._module_name, fromlist=[''])
return getattr(mod, key)

def __repr__(self):
def __repr__(self) -> str:
return f'<module proxy for {self._module_name}>'


Expand All @@ -60,7 +63,7 @@ class FutureWarningMixin:

warn_message = 'This class will be removed in future versions'

def __init__(self, *args, **kwargs):
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
warnings.warn(self.warn_message, FutureWarning, stacklevel=2)
super().__init__(*args, **kwargs)

Expand All @@ -85,12 +88,12 @@ def alert_future_error(
msg: str,
version: str,
*,
warning_class: Type[Warning] = FutureWarning,
error_class: Type[Exception] = RuntimeError,
warning_class: type[Warning] = FutureWarning,
error_class: type[Exception] = RuntimeError,
warning_rec: str = '',
error_rec: str = '',
stacklevel: int = 2,
):
) -> None:
"""Warn or error with appropriate messages for changing functionality.

Parameters
Expand Down
58 changes: 42 additions & 16 deletions nibabel/deprecator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""Class for recording and reporting deprecations
"""
from __future__ import annotations

import functools
import re
import typing as ty
import warnings

if ty.TYPE_CHECKING: # pragma: no cover
T = ty.TypeVar('T')
P = ty.ParamSpec('P')

_LEADING_WHITE = re.compile(r'^(\s*)')

TESTSETUP = """
Expand Down Expand Up @@ -38,15 +44,20 @@ class ExpiredDeprecationError(RuntimeError):
pass


def _ensure_cr(text):
def _ensure_cr(text: str) -> str:
"""Remove trailing whitespace and add carriage return

Ensures that `text` always ends with a carriage return
"""
return text.rstrip() + '\n'


def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
def _add_dep_doc(
old_doc: str,
dep_doc: str,
setup: str = '',
cleanup: str = '',
) -> str:
"""Add deprecation message `dep_doc` to docstring in `old_doc`

Parameters
Expand All @@ -55,6 +66,10 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
Docstring from some object.
dep_doc : str
Deprecation warning to add to top of docstring, after initial line.
setup : str, optional
Doctest setup text
cleanup : str, optional
Doctest teardown text

Returns
-------
Expand All @@ -76,7 +91,9 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
if next_line >= len(old_lines):
# nothing following first paragraph, just append message
return old_doc + '\n' + dep_doc
indent = _LEADING_WHITE.match(old_lines[next_line]).group()
leading_white = _LEADING_WHITE.match(old_lines[next_line])
assert leading_white is not None # Type narrowing, since this always matches
indent = leading_white.group()
setup_lines = [indent + L for L in setup.splitlines()]
dep_lines = [indent + L for L in [''] + dep_doc.splitlines() + ['']]
cleanup_lines = [indent + L for L in cleanup.splitlines()]
Expand Down Expand Up @@ -113,15 +130,15 @@ class Deprecator:

def __init__(
self,
version_comparator,
warn_class=DeprecationWarning,
error_class=ExpiredDeprecationError,
):
version_comparator: ty.Callable[[str], int],
warn_class: type[Warning] = DeprecationWarning,
error_class: type[Exception] = ExpiredDeprecationError,
) -> None:
self.version_comparator = version_comparator
self.warn_class = warn_class
self.error_class = error_class

def is_bad_version(self, version_str):
def is_bad_version(self, version_str: str) -> bool:
"""Return True if `version_str` is too high

Tests `version_str` with ``self.version_comparator``
Expand All @@ -139,7 +156,14 @@ def is_bad_version(self, version_str):
"""
return self.version_comparator(version_str) == -1

def __call__(self, message, since='', until='', warn_class=None, error_class=None):
def __call__(
self,
message: str,
since: str = '',
until: str = '',
warn_class: type[Warning] | None = None,
error_class: type[Exception] | None = None,
) -> ty.Callable[[ty.Callable[P, T]], ty.Callable[P, T]]:
"""Return decorator function function for deprecation warning / error

Parameters
Expand All @@ -164,8 +188,8 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non
deprecator : func
Function returning a decorator.
"""
warn_class = warn_class or self.warn_class
error_class = error_class or self.error_class
exception = error_class if error_class is not None else self.error_class
warning = warn_class if warn_class is not None else self.warn_class
messages = [message]
if (since, until) != ('', ''):
messages.append('')
Expand All @@ -174,19 +198,21 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non
if until:
messages.append(
f"* {'Raises' if self.is_bad_version(until) else 'Will raise'} "
f'{error_class} as of version: {until}'
f'{exception} as of version: {until}'
)
message = '\n'.join(messages)

def deprecator(func):
def deprecator(func: ty.Callable[P, T]) -> ty.Callable[P, T]:
@functools.wraps(func)
def deprecated_func(*args, **kwargs):
def deprecated_func(*args: P.args, **kwargs: P.kwargs) -> T:
if until and self.is_bad_version(until):
raise error_class(message)
warnings.warn(message, warn_class, stacklevel=2)
raise exception(message)
warnings.warn(message, warning, stacklevel=2)
return func(*args, **kwargs)

keep_doc = deprecated_func.__doc__
if keep_doc is None:
keep_doc = ''
setup = TESTSETUP
cleanup = TESTCLEANUP
# After expiration, remove all but the first paragraph.
Expand Down
41 changes: 28 additions & 13 deletions nibabel/onetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

[2] Python data model, https://docs.python.org/reference/datamodel.html
"""
from __future__ import annotations

import typing as ty

InstanceT = ty.TypeVar('InstanceT')
T = ty.TypeVar('T')

from nibabel.deprecated import deprecate_with_version

Expand Down Expand Up @@ -96,26 +102,24 @@ class ResetMixin:
10.0
"""

def reset(self):
def reset(self) -> None:
"""Reset all OneTimeProperty attributes that may have fired already."""
instdict = self.__dict__
classdict = self.__class__.__dict__
# To reset them, we simply remove them from the instance dict. At that
# point, it's as if they had never been computed. On the next access,
# the accessor function from the parent class will be called, simply
# because that's how the python descriptor protocol works.
for mname, mval in classdict.items():
if mname in instdict and isinstance(mval, OneTimeProperty):
for mname, mval in self.__class__.__dict__.items():
if mname in self.__dict__ and isinstance(mval, OneTimeProperty):
delattr(self, mname)


class OneTimeProperty:
class OneTimeProperty(ty.Generic[T]):
"""A descriptor to make special properties that become normal attributes.

This is meant to be used mostly by the auto_attr decorator in this module.
"""

def __init__(self, func):
def __init__(self, func: ty.Callable[[InstanceT], T]) -> None:
"""Create a OneTimeProperty instance.

Parameters
Expand All @@ -128,24 +132,35 @@ def __init__(self, func):
"""
self.getter = func
self.name = func.__name__
self.__doc__ = func.__doc__

@ty.overload
def __get__(
self, obj: None, objtype: type[InstanceT] | None = None
) -> ty.Callable[[InstanceT], T]:
... # pragma: no cover

@ty.overload
def __get__(self, obj: InstanceT, objtype: type[InstanceT] | None = None) -> T:
... # pragma: no cover

def __get__(self, obj, type=None):
def __get__(
self, obj: InstanceT | None, objtype: type[InstanceT] | None = None
) -> T | ty.Callable[[InstanceT], T]:
"""This will be called on attribute access on the class or instance."""
if obj is None:
# Being called on the class, return the original function. This
# way, introspection works on the class.
# return func
return self.getter

# Errors in the following line are errors in setting a
# OneTimeProperty
# Errors in the following line are errors in setting a OneTimeProperty
val = self.getter(obj)

setattr(obj, self.name, val)
obj.__dict__[self.name] = val
return val


def auto_attr(func):
def auto_attr(func: ty.Callable[[InstanceT], T]) -> OneTimeProperty[T]:
"""Decorator to create OneTimeProperty attributes.

Parameters
Expand Down
35 changes: 23 additions & 12 deletions nibabel/optpkg.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
"""Routines to support optional packages"""
from __future__ import annotations

import typing as ty
from types import ModuleType

from packaging.version import Version

from .tripwire import TripWire


def _check_pkg_version(pkg, min_version):
# Default version checking function
if isinstance(min_version, str):
min_version = Version(min_version)
try:
return min_version <= Version(pkg.__version__)
except AttributeError:
def _check_pkg_version(min_version: str | Version) -> ty.Callable[[ModuleType], bool]:
min_ver = Version(min_version) if isinstance(min_version, str) else min_version

def check(pkg: ModuleType) -> bool:
pkg_ver = getattr(pkg, '__version__', None)
if isinstance(pkg_ver, str):
return min_ver <= Version(pkg_ver)
return False

return check


def optional_package(name, trip_msg=None, min_version=None):
def optional_package(
name: str,
trip_msg: str | None = None,
min_version: str | Version | ty.Callable[[ModuleType], bool] | None = None,
) -> tuple[ModuleType | TripWire, bool, ty.Callable[[], None]]:
"""Return package-like thing and module setup for package `name`

Parameters
Expand Down Expand Up @@ -81,7 +92,7 @@ def optional_package(name, trip_msg=None, min_version=None):
elif min_version is None:
check_version = lambda pkg: True
else:
check_version = lambda pkg: _check_pkg_version(pkg, min_version)
check_version = _check_pkg_version(min_version)
# fromlist=[''] results in submodule being returned, rather than the top
# level module. See help(__import__)
fromlist = [''] if '.' in name else []
Expand All @@ -107,11 +118,11 @@ def optional_package(name, trip_msg=None, min_version=None):
trip_msg = (
f'We need package {name} for these functions, but ``import {name}`` raised {exc}'
)
pkg = TripWire(trip_msg)
trip = TripWire(trip_msg)

def setup_module():
def setup_module() -> None:
import unittest

raise unittest.SkipTest(f'No {name} for these tests')

return pkg, False, setup_module
return trip, False, setup_module
4 changes: 2 additions & 2 deletions nibabel/pkg_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
COMMIT_HASH = '$Format:%h$'


def _cmp(a, b) -> int:
def _cmp(a: Version, b: Version) -> int:
"""Implementation of ``cmp`` for Python 3"""
return (a > b) - (a < b)

Expand Down Expand Up @@ -113,7 +113,7 @@ def pkg_commit_hash(pkg_path: str | None = None) -> tuple[str, str]:
return '(none found)', '<not found>'


def get_pkg_info(pkg_path: str) -> dict:
def get_pkg_info(pkg_path: str) -> dict[str, str]:
"""Return dict describing the context of this package

Parameters
Expand Down
2 changes: 1 addition & 1 deletion nibabel/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .optpkg import optional_package

spnd, _, _ = optional_package('scipy.ndimage')
spnd = optional_package('scipy.ndimage')[0]

from .affines import AffineError, append_diag, from_matvec, rescale_affine, to_matvec
from .imageclasses import spatial_axes_first
Expand Down
2 changes: 1 addition & 1 deletion nibabel/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..optpkg import optional_package

_, have_scipy, _ = optional_package('scipy.io')
have_scipy = optional_package('scipy.io')[1]

from numpy.testing import assert_array_equal

Expand Down
Loading