Skip to content

Commit 41b5835

Browse files
committed
Add type annotation in _helpers
1 parent ab5a076 commit 41b5835

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

stagpy/_helpers.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
11
"""Various helper functions and classes."""
22

3+
from __future__ import annotations
34
from inspect import getdoc
5+
import typing
46

57
import matplotlib.pyplot as plt
68

79
from . import conf
810

11+
if typing.TYPE_CHECKING:
12+
from typing import Optional, Any, List, Callable
13+
from matplotlib.figure import Figure
14+
from numpy import ndarray
915

10-
def out_name(stem, timestep=None):
16+
17+
def out_name(stem: str, timestep: Optional[int] = None) -> str:
1118
"""Return StagPy out file name.
1219
1320
Args:
14-
stem (str): short description of file content.
15-
timestep (int): timestep if relevant.
21+
stem: short description of file content.
22+
timestep: timestep if relevant.
1623
1724
Returns:
18-
str: the output file name.
25+
the output file name.
1926
2027
Other Parameters:
21-
conf.core.outname (str): the generic name stem, defaults to
22-
``'stagpy'``.
28+
conf.core.outname: the generic name stem, defaults to ``'stagpy'``.
2329
"""
2430
if conf.core.shortname:
2531
return conf.core.outname
@@ -28,32 +34,33 @@ def out_name(stem, timestep=None):
2834
return conf.core.outname + '_' + stem
2935

3036

31-
def scilabel(value, precision=2):
37+
def scilabel(value: float, precision: int = 2) -> str:
3238
"""Build scientific notation of some value.
3339
3440
This is dedicated to use in labels displaying scientific values.
3541
3642
Args:
37-
value (float): numeric value to format.
38-
precision (int): number of decimal digits.
43+
value: numeric value to format.
44+
precision: number of decimal digits.
3945
4046
Returns:
41-
str: the scientific notation the specified value.
47+
the scientific notation of the specified value.
4248
"""
43-
man, exp = f'{value:.{precision}e}'.split('e')
44-
exp = int(exp)
49+
man, exps = f'{value:.{precision}e}'.split('e')
50+
exp = int(exps)
4551
return fr'{man}\times 10^{{{exp}}}'
4652

4753

48-
def saveplot(fig, *name_args, close=True, **name_kwargs):
54+
def saveplot(fig: Figure, *name_args: Any, close: bool = True,
55+
**name_kwargs: Any):
4956
"""Save matplotlib figure.
5057
5158
You need to provide :data:`stem` as a positional or keyword argument (see
5259
:func:`out_name`).
5360
5461
Args:
55-
fig (:class:`matplotlib.figure.Figure`): matplotlib figure.
56-
close (bool): whether to close the figure.
62+
fig: the :class:`matplotlib.figure.Figure` to save.
63+
close: whether to close the figure.
5764
name_args: positional arguments passed on to :func:`out_name`.
5865
name_kwargs: keyword arguments passed on to :func:`out_name`.
5966
"""
@@ -64,7 +71,7 @@ def saveplot(fig, *name_args, close=True, **name_kwargs):
6471
plt.close(fig)
6572

6673

67-
def baredoc(obj):
74+
def baredoc(obj: object) -> str:
6875
"""Return the first line of the docstring of an object.
6976
7077
Trailing periods and spaces as well as leading spaces are removed from the
@@ -82,12 +89,12 @@ def baredoc(obj):
8289
return doc.rstrip(' .').lstrip()
8390

8491

85-
def list_of_vars(arg_plot):
92+
def list_of_vars(arg_plot: str) -> List[List[List[str]]]:
8693
"""Construct list of variables per plot.
8794
8895
Args:
89-
arg_plot (str): string with variable names separated with
90-
``-`` (figures), ``.`` (subplots) and ``,`` (same subplot).
96+
arg_plot: variable names separated with ``-`` (figures),
97+
``.`` (subplots) and ``,`` (same subplot).
9198
Returns:
9299
three nested lists of str
93100
@@ -102,13 +109,13 @@ def list_of_vars(arg_plot):
102109
return [lov for lov in lovs if lov]
103110

104111

105-
def find_in_sorted_arr(value, array, after=False):
112+
def find_in_sorted_arr(value: Any, array: ndarray, after=False) -> int:
106113
"""Return position of element in a sorted array.
107114
108115
Returns:
109-
int: the maximum position i such as array[i] <= value. If after is
110-
True, it returns the min i such as value <= array[i] (or 0 if such
111-
an indices does not exist).
116+
the maximum position i such as array[i] <= value. If after is True, it
117+
returns the min i such as value <= array[i] (or 0 if such an index does
118+
not exist).
112119
"""
113120
ielt = array.searchsorted(value)
114121
if ielt == array.size:
@@ -133,13 +140,13 @@ class CachedReadOnlyProperty:
133140
property is read-only instead of being writeable.
134141
"""
135142

136-
def __init__(self, thunk):
143+
def __init__(self, thunk: Callable[[Any], Any]):
137144
self._thunk = thunk
138145
self._name = thunk.__name__
139146
self._cache_name = f'_cropped_{self._name}'
140147
self.__doc__ = thunk.__doc__
141148

142-
def __get__(self, instance, _):
149+
def __get__(self, instance: Any, _) -> Any:
143150
try:
144151
return getattr(instance, self._cache_name)
145152
except AttributeError:
@@ -148,6 +155,6 @@ def __get__(self, instance, _):
148155
setattr(instance, self._cache_name, cached_value)
149156
return cached_value
150157

151-
def __set__(self, instance, _):
158+
def __set__(self, instance: Any, _):
152159
raise AttributeError(
153160
f'Cannot set {self._name} property of {instance!r}')

0 commit comments

Comments
 (0)