Skip to content

Commit 72d7eff

Browse files
committed
TYP: Annotate deprecation and versioning machinery
1 parent 47fb865 commit 72d7eff

File tree

3 files changed

+55
-26
lines changed

3 files changed

+55
-26
lines changed

nibabel/deprecated.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
"""
33
from __future__ import annotations
44

5+
import typing as ty
56
import warnings
6-
from typing import Type
77

88
from .deprecator import Deprecator
99
from .pkg_info import cmp_pkg_version
1010

11+
if ty.TYPE_CHECKING: # pragma: no cover
12+
P = ty.ParamSpec('P')
13+
1114

1215
class ModuleProxy:
1316
"""Proxy for module that may not yet have been imported
@@ -30,14 +33,14 @@ class ModuleProxy:
3033
module.
3134
"""
3235

33-
def __init__(self, module_name):
36+
def __init__(self, module_name: str):
3437
self._module_name = module_name
3538

36-
def __getattr__(self, key):
39+
def __getattr__(self, key: str) -> ty.Any:
3740
mod = __import__(self._module_name, fromlist=[''])
3841
return getattr(mod, key)
3942

40-
def __repr__(self):
43+
def __repr__(self) -> str:
4144
return f'<module proxy for {self._module_name}>'
4245

4346

@@ -60,7 +63,7 @@ class FutureWarningMixin:
6063

6164
warn_message = 'This class will be removed in future versions'
6265

63-
def __init__(self, *args, **kwargs):
66+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
6467
warnings.warn(self.warn_message, FutureWarning, stacklevel=2)
6568
super().__init__(*args, **kwargs)
6669

@@ -85,12 +88,12 @@ def alert_future_error(
8588
msg: str,
8689
version: str,
8790
*,
88-
warning_class: Type[Warning] = FutureWarning,
89-
error_class: Type[Exception] = RuntimeError,
91+
warning_class: type[Warning] = FutureWarning,
92+
error_class: type[Exception] = RuntimeError,
9093
warning_rec: str = '',
9194
error_rec: str = '',
9295
stacklevel: int = 2,
93-
):
96+
) -> None:
9497
"""Warn or error with appropriate messages for changing functionality.
9598
9699
Parameters

nibabel/deprecator.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
"""Class for recording and reporting deprecations
22
"""
3+
from __future__ import annotations
34

45
import functools
56
import re
7+
import typing as ty
68
import warnings
79

10+
if ty.TYPE_CHECKING: # pragma: no cover
11+
T = ty.TypeVar('T')
12+
P = ty.ParamSpec('P')
13+
814
_LEADING_WHITE = re.compile(r'^(\s*)')
915

1016
TESTSETUP = """
@@ -38,15 +44,20 @@ class ExpiredDeprecationError(RuntimeError):
3844
pass
3945

4046

41-
def _ensure_cr(text):
47+
def _ensure_cr(text: str) -> str:
4248
"""Remove trailing whitespace and add carriage return
4349
4450
Ensures that `text` always ends with a carriage return
4551
"""
4652
return text.rstrip() + '\n'
4753

4854

49-
def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
55+
def _add_dep_doc(
56+
old_doc: str,
57+
dep_doc: str,
58+
setup: str = '',
59+
cleanup: str = '',
60+
) -> str:
5061
"""Add deprecation message `dep_doc` to docstring in `old_doc`
5162
5263
Parameters
@@ -55,6 +66,10 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
5566
Docstring from some object.
5667
dep_doc : str
5768
Deprecation warning to add to top of docstring, after initial line.
69+
setup : str, optional
70+
Doctest setup text
71+
cleanup : str, optional
72+
Doctest teardown text
5873
5974
Returns
6075
-------
@@ -76,7 +91,9 @@ def _add_dep_doc(old_doc, dep_doc, setup='', cleanup=''):
7691
if next_line >= len(old_lines):
7792
# nothing following first paragraph, just append message
7893
return old_doc + '\n' + dep_doc
79-
indent = _LEADING_WHITE.match(old_lines[next_line]).group()
94+
leading_white = _LEADING_WHITE.match(old_lines[next_line])
95+
assert leading_white is not None # Type narrowing, since this always matches
96+
indent = leading_white.group()
8097
setup_lines = [indent + L for L in setup.splitlines()]
8198
dep_lines = [indent + L for L in [''] + dep_doc.splitlines() + ['']]
8299
cleanup_lines = [indent + L for L in cleanup.splitlines()]
@@ -113,15 +130,15 @@ class Deprecator:
113130

114131
def __init__(
115132
self,
116-
version_comparator,
117-
warn_class=DeprecationWarning,
118-
error_class=ExpiredDeprecationError,
119-
):
133+
version_comparator: ty.Callable[[str], int],
134+
warn_class: type[Warning] = DeprecationWarning,
135+
error_class: type[Exception] = ExpiredDeprecationError,
136+
) -> None:
120137
self.version_comparator = version_comparator
121138
self.warn_class = warn_class
122139
self.error_class = error_class
123140

124-
def is_bad_version(self, version_str):
141+
def is_bad_version(self, version_str: str) -> bool:
125142
"""Return True if `version_str` is too high
126143
127144
Tests `version_str` with ``self.version_comparator``
@@ -139,7 +156,14 @@ def is_bad_version(self, version_str):
139156
"""
140157
return self.version_comparator(version_str) == -1
141158

142-
def __call__(self, message, since='', until='', warn_class=None, error_class=None):
159+
def __call__(
160+
self,
161+
message: str,
162+
since: str = '',
163+
until: str = '',
164+
warn_class: type[Warning] | None = None,
165+
error_class: type[Exception] | None = None,
166+
) -> ty.Callable[[ty.Callable[P, T]], ty.Callable[P, T]]:
143167
"""Return decorator function function for deprecation warning / error
144168
145169
Parameters
@@ -164,8 +188,8 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non
164188
deprecator : func
165189
Function returning a decorator.
166190
"""
167-
warn_class = warn_class or self.warn_class
168-
error_class = error_class or self.error_class
191+
exception = error_class if error_class is not None else self.error_class
192+
warning = warn_class if warn_class is not None else self.warn_class
169193
messages = [message]
170194
if (since, until) != ('', ''):
171195
messages.append('')
@@ -174,19 +198,21 @@ def __call__(self, message, since='', until='', warn_class=None, error_class=Non
174198
if until:
175199
messages.append(
176200
f"* {'Raises' if self.is_bad_version(until) else 'Will raise'} "
177-
f'{error_class} as of version: {until}'
201+
f'{exception} as of version: {until}'
178202
)
179203
message = '\n'.join(messages)
180204

181-
def deprecator(func):
205+
def deprecator(func: ty.Callable[P, T]) -> ty.Callable[P, T]:
182206
@functools.wraps(func)
183-
def deprecated_func(*args, **kwargs):
207+
def deprecated_func(*args: P.args, **kwargs: P.kwargs) -> T:
184208
if until and self.is_bad_version(until):
185-
raise error_class(message)
186-
warnings.warn(message, warn_class, stacklevel=2)
209+
raise exception(message)
210+
warnings.warn(message, warning, stacklevel=2)
187211
return func(*args, **kwargs)
188212

189213
keep_doc = deprecated_func.__doc__
214+
if keep_doc is None:
215+
keep_doc = ''
190216
setup = TESTSETUP
191217
cleanup = TESTCLEANUP
192218
# After expiration, remove all but the first paragraph.

nibabel/pkg_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
COMMIT_HASH = '$Format:%h$'
1515

1616

17-
def _cmp(a, b) -> int:
17+
def _cmp(a: Version, b: Version) -> int:
1818
"""Implementation of ``cmp`` for Python 3"""
1919
return (a > b) - (a < b)
2020

@@ -113,7 +113,7 @@ def pkg_commit_hash(pkg_path: str | None = None) -> tuple[str, str]:
113113
return '(none found)', '<not found>'
114114

115115

116-
def get_pkg_info(pkg_path: str) -> dict:
116+
def get_pkg_info(pkg_path: str) -> dict[str, str]:
117117
"""Return dict describing the context of this package
118118
119119
Parameters

0 commit comments

Comments
 (0)