Skip to content

Commit 8d3ca0f

Browse files
committed
Add type annotations to field module
1 parent a6db3f5 commit 8d3ca0f

File tree

2 files changed

+73
-48
lines changed

2 files changed

+73
-48
lines changed

stagpy/field.py

Lines changed: 62 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Plot scalar and vector fields."""
22

3+
from __future__ import annotations
34
from itertools import chain
5+
import typing
46

57
import numpy as np
68
import matplotlib as mpl
@@ -12,12 +14,21 @@
1214
from .error import NotAvailableError
1315
from .stagyydata import StagyyData
1416

17+
if typing.TYPE_CHECKING:
18+
from typing import Tuple, Optional, Any, Iterable, Dict
19+
from numpy import ndarray
20+
from matplotlib.axes import Axes
21+
from .datatypes import Varf
22+
from ._step import Step
23+
1524

1625
# The location is off for vertical velocities: they have an extra
1726
# point in (x,y) instead of z in the output
1827

1928

20-
def _threed_extract(step, var, walls=False):
29+
def _threed_extract(
30+
step: Step, var: str, walls: bool = False
31+
) -> Tuple[Tuple[ndarray, ndarray], Any]:
2132
"""Return suitable slices and coords for 3D fields."""
2233
is_vector = not valid_field_var(var)
2334
hwalls = is_vector or walls
@@ -53,33 +64,33 @@ def _threed_extract(step, var, walls=False):
5364
return (xcoord, ycoord), data
5465

5566

56-
def valid_field_var(var):
67+
def valid_field_var(var: str) -> bool:
5768
"""Whether a field variable is defined.
5869
59-
This function checks if a definition of the variable exists in
60-
:data:`~stagpy.phyvars.FIELD` or :data:`~stagpy.phyvars.FIELD_EXTRA`.
61-
6270
Args:
63-
var (str): the variable name to be checked.
71+
var: the variable name to be checked.
6472
Returns:
65-
bool: True is the var is defined, False otherwise.
73+
whether the var is defined in :data:`~stagpy.phyvars.FIELD` or
74+
:data:`~stagpy.phyvars.FIELD_EXTRA`.
6675
"""
6776
return var in phyvars.FIELD or var in phyvars.FIELD_EXTRA
6877

6978

70-
def get_meshes_fld(step, var, walls=False):
79+
def get_meshes_fld(
80+
step: Step, var: str, walls: bool = False
81+
) -> Tuple[ndarray, ndarray, ndarray, Varf]:
7182
"""Return scalar field along with coordinates meshes.
7283
7384
Only works properly in 2D geometry and 3D cartesian.
7485
7586
Args:
76-
step (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
77-
var (str): scalar field name.
78-
walls (bool): consider the walls as the relevant mesh.
87+
step: a :class:`~stagpy._step.Step` of a StagyyData instance.
88+
var: scalar field name.
89+
walls: consider the walls as the relevant mesh.
7990
Returns:
80-
tuple of :class:`numpy.array`: xmesh, ymesh, fld, meta
81-
2D arrays containing respectively the x position, y position, the
82-
values and the metadata of the requested field.
91+
tuple (xmesh, ymesh, fld, meta). 2D arrays containing respectively the
92+
x position, y position, the values and the metadata of the requested
93+
field.
8394
"""
8495
fld, meta = step.fields[var]
8596
hwalls = (walls or fld.shape[0] != step.geom.nxtot or
@@ -102,18 +113,20 @@ def get_meshes_fld(step, var, walls=False):
102113
return xmesh, ymesh, fld, meta
103114

104115

105-
def get_meshes_vec(step, var):
116+
def get_meshes_vec(
117+
step: Step, var: str
118+
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
106119
"""Return vector field components along with coordinates meshes.
107120
108121
Only works properly in 2D geometry and 3D cartesian.
109122
110123
Args:
111-
step (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
112-
var (str): vector field name.
124+
step: a :class:`~stagpy._step.Step` of a StagyyData instance.
125+
var: vector field name.
113126
Returns:
114-
tuple of :class:`numpy.array`: xmesh, ymesh, fldx, fldy
115-
2D arrays containing respectively the x position, y position, x
116-
component and y component of the requested vector field.
127+
tuple (xmesh, ymesh, fldx, fldy). 2D arrays containing respectively
128+
the x position, y position, x component and y component of the
129+
requested vector field.
117130
"""
118131
if step.geom.threed and step.geom.cartesian:
119132
(xcoord, ycoord), (vec1, vec2) = _threed_extract(step, var)
@@ -140,20 +153,21 @@ def get_meshes_vec(step, var):
140153
return xmesh, ymesh, vec1, vec2
141154

142155

143-
def plot_scalar(step, var, field=None, axis=None, **extra):
156+
def plot_scalar(step: Step, var: str, field: Optional[ndarray] = None,
157+
axis: Optional[Axes] = None, **extra: Any):
144158
"""Plot scalar field.
145159
146160
Args:
147-
step (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
148-
var (str): the scalar field name.
149-
field (:class:`numpy.array`): if not None, it is plotted instead of
150-
step.fields[var]. This is useful to plot a masked or rescaled
151-
array. Note that if conf.scaling.dimensional is True, this
152-
field will be scaled accordingly.
153-
axis (:class:`matplotlib.axes.Axes`): the axis objet where the field
154-
should be plotted. If set to None, a new figure with one subplot
155-
is created.
156-
extra (dict): options that will be passed on to
161+
step: a :class:`~stagpy._step.Step` of a StagyyData instance.
162+
var: the scalar field name.
163+
field: if not None, it is plotted instead of step.fields[var]. This is
164+
useful to plot a masked or rescaled array. Note that if
165+
conf.scaling.dimensional is True, this field will be scaled
166+
accordingly.
167+
axis: the :class:`matplotlib.axes.Axes` object where the field should
168+
be plotted. If set to None, a new figure with one subplot is
169+
created.
170+
extra: options that will be passed on to
157171
:func:`matplotlib.axes.Axes.pcolormesh`.
158172
Returns:
159173
fig, axis, surf, cbar
@@ -241,22 +255,21 @@ def plot_scalar(step, var, field=None, axis=None, **extra):
241255
return fig, axis, surf, cbar
242256

243257

244-
def plot_iso(axis, step, var, **extra):
258+
def plot_iso(axis: Axes, step: Step, var: str, **extra: Any):
245259
"""Plot isocontours of scalar field.
246260
247261
Args:
248-
axis (:class:`matplotlib.axes.Axes`): the axis handler of an
249-
existing matplotlib figure where the isocontours should
250-
be plotted.
251-
step (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
252-
var (str): the scalar field name.
253-
extra (dict): options that will be passed on to
262+
axis: the :class:`matplotlib.axes.Axes` of an existing matplotlib
263+
figure where the isocontours should be plotted.
264+
step: a :class:`~stagpy._step.Step` of a StagyyData instance.
265+
var: the scalar field name.
266+
extra: options that will be passed on to
254267
:func:`matplotlib.axes.Axes.contour`.
255268
"""
256269
xmesh, ymesh, fld, _ = get_meshes_fld(step, var)
257270
if conf.field.shift:
258271
fld = np.roll(fld, conf.field.shift, axis=0)
259-
extra_opts = dict(linewidths=1)
272+
extra_opts: Dict[str, Any] = dict(linewidths=1)
260273
if 'cmap' not in extra and conf.field.isocolors:
261274
extra_opts['colors'] = conf.field.isocolors.split(',')
262275
elif 'colors' not in extra:
@@ -267,15 +280,14 @@ def plot_iso(axis, step, var, **extra):
267280
axis.contour(xmesh, ymesh, fld, **extra_opts)
268281

269282

270-
def plot_vec(axis, step, var):
283+
def plot_vec(axis: Axes, step: Step, var: str):
271284
"""Plot vector field.
272285
273286
Args:
274-
axis (:class:`matplotlib.axes.Axes`): the axis handler of an
275-
existing matplotlib figure where the vector field should
276-
be plotted.
277-
step (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
278-
var (str): the vector field name.
287+
axis: the :class:`matplotlib.axes.Axes` of an existing matplotlib
288+
figure where the vector field should be plotted.
289+
step: a :class:`~stagpy._step.Step` of a StagyyData instance.
290+
var: the vector field name.
279291
"""
280292
xmesh, ymesh, vec1, vec2 = get_meshes_vec(step, var)
281293
dipz = step.geom.nztot // 10
@@ -292,9 +304,11 @@ def plot_vec(axis, step, var):
292304
linewidths=1)
293305

294306

295-
def _findminmax(sdat, sovs):
307+
def _findminmax(
308+
sdat: StagyyData, sovs: Iterable[str]
309+
) -> Dict[str, Tuple[float, float]]:
296310
"""Find min and max values of several fields."""
297-
minmax = {}
311+
minmax: Dict[str, Tuple[float, float]] = {}
298312
for step in sdat.walk.filter(snap=True):
299313
for var in sovs:
300314
if var in step.fields:
@@ -333,7 +347,7 @@ def cmd():
333347
if var[0] not in step.fields:
334348
print(f"{var[0]!r} field on snap {step.isnap} not found")
335349
continue
336-
opts = {}
350+
opts: Dict[str, Any] = {}
337351
if var[0] in minmax:
338352
opts = dict(vmin=minmax[var[0]][0], vmax=minmax[var[0]][1])
339353
plot_scalar(step, var[0], axis=axis, **opts)

stagpy/stagyydata.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,18 @@ def nfields_max(self, nfields: Optional[int]):
832832
raise error.InvalidNfieldsError(nfields)
833833
self._nfields_max = nfields
834834

835+
@typing.overload
835836
def scale(self, data: ndarray, unit: str) -> Tuple[ndarray, str]:
837+
"""Scale a ndarray."""
838+
...
839+
840+
@typing.overload
841+
def scale(self, data: float, unit: str) -> Tuple[float, str]:
842+
"""Scale a float."""
843+
...
844+
845+
def scale(self, data: Union[ndarray, float],
846+
unit: str) -> Tuple[Union[ndarray, float], str]:
836847
"""Scales quantity to obtain dimensionful quantity.
837848
838849
Args:

0 commit comments

Comments
 (0)