Skip to content

Commit 56bfe74

Browse files
committed
Add type annotations to plates module
1 parent 8195079 commit 56bfe74

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

stagpy/plates.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Plate analysis."""
22

3+
from __future__ import annotations
34
from contextlib import suppress
45
from functools import lru_cache
6+
import typing
57

68
import matplotlib.pyplot as plt
79
from matplotlib import colors
@@ -10,11 +12,17 @@
1012

1113
from . import conf, error, field, phyvars, _helpers
1214
from ._helpers import saveplot
13-
from ._step import Field
15+
from .datatypes import Field
1416
from .stagyydata import StagyyData
1517

18+
if typing.TYPE_CHECKING:
19+
from typing import Iterable, Tuple, TextIO, Union
20+
from matplotlib.axes import Axes
21+
from numpy import ndarray
22+
from ._step import Step, _Geometry
1623

17-
def _vzcheck(iphis, snap, vz_thres):
24+
25+
def _vzcheck(iphis: Iterable[int], snap: Step, vz_thres: float) -> ndarray:
1826
"""Remove positions where vz is below threshold."""
1927
# verifying vertical velocity
2028
vzabs = np.abs(snap.fields['v3'].values[0, ..., 0])
@@ -27,20 +35,21 @@ def _vzcheck(iphis, snap, vz_thres):
2735

2836

2937
@lru_cache()
30-
def detect_plates(snap, vz_thres_ratio=0):
38+
def detect_plates(snap: Step,
39+
vz_thres_ratio: float = 0) -> Tuple[ndarray, ndarray]:
3140
"""Detect plate limits using derivative of horizontal velocity.
3241
3342
This function is cached for convenience.
3443
3544
Args:
36-
snap (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
37-
vz_thres_ratio (float): if above zero, an additional check based on the
45+
snap: a :class:`~stagpy._step.Step` of a StagyyData instance.
46+
vz_thres_ratio: if above zero, an additional check based on the
3847
vertical velocities is performed. Limits detected above a region
3948
where the vertical velocity is below vz_thres_ratio * mean(vzabs)
4049
are ignored.
4150
Returns:
42-
tuple of :class:`numpy.array`: itrenches, iridges
43-
1D arrays containing phi-index of detected trenches and ridges.
51+
tuple (itrenches, iridges). 1D arrays containing phi-index of detected
52+
trenches and ridges.
4453
"""
4554
dvphi = _surf_diag(snap, 'dv2').values
4655

@@ -83,15 +92,17 @@ def detect_plates(snap, vz_thres_ratio=0):
8392
return itrenches, iridges
8493

8594

86-
def _plot_plate_limits(axis, trenches, ridges):
95+
def _plot_plate_limits(axis: Axes, trenches: ndarray, ridges: ndarray):
8796
"""Plot lines designating ridges and trenches."""
8897
for trench in trenches:
8998
axis.axvline(x=trench, color='red', ls='dashed', alpha=0.4)
9099
for ridge in ridges:
91100
axis.axvline(x=ridge, color='green', ls='dashed', alpha=0.4)
92101

93102

94-
def _annot_pos(geom, iphi):
103+
def _annot_pos(
104+
geom: _Geometry, iphi: int
105+
) -> Tuple[Tuple[float, float], Tuple[float, float]]:
95106
"""Position of arrows to mark limit positions."""
96107
phi = geom.p_centers[iphi]
97108
rtot = geom.r_walls[-1]
@@ -106,7 +117,7 @@ def _annot_pos(geom, iphi):
106117
return p_beg, p_end
107118

108119

109-
def _plot_plate_limits_field(axis, snap):
120+
def _plot_plate_limits_field(axis: Axes, snap: Step):
110121
"""Plot arrows designating ridges and trenches in 2D field plots."""
111122
itrenches, iridges = detect_plates(snap, conf.plates.vzratio)
112123
for itrench in itrenches:
@@ -121,7 +132,7 @@ def _plot_plate_limits_field(axis, snap):
121132
annotation_clip=False)
122133

123134

124-
def _isurf(snap):
135+
def _isurf(snap: Step) -> int:
125136
"""Return index of surface accounting for air layer."""
126137
if snap.sdat.par['boundaries']['air_layer']:
127138
dsa = snap.sdat.par['boundaries']['air_thickness']
@@ -134,7 +145,7 @@ def _isurf(snap):
134145
return isurf
135146

136147

137-
def _surf_diag(snap, name):
148+
def _surf_diag(snap: Step, name: str) -> Field:
138149
"""Get a surface field.
139150
140151
Can be a sfield, a regular scalar field evaluated at the surface,
@@ -157,12 +168,13 @@ def _surf_diag(snap, name):
157168
raise error.UnknownVarError(name)
158169

159170

160-
def _continents_location(snap, at_surface=True):
171+
def _continents_location(snap: Step, at_surface: bool = True) -> ndarray:
161172
"""Location of continents as a boolean array.
162173
163174
If at_surface is True, it is evaluated only at the surface, otherwise it is
164175
evaluated in the entire domain.
165176
"""
177+
icont: Union[int, slice]
166178
if at_surface:
167179
if snap.sdat.par['boundaries']['air_layer']:
168180
icont = _isurf(snap) - 6
@@ -182,13 +194,13 @@ def _continents_location(snap, at_surface=True):
182194
return csurf >= 2
183195

184196

185-
def plot_at_surface(snap, names):
197+
def plot_at_surface(snap: Step, names: str):
186198
"""Plot surface diagnostics.
187199
188200
Args:
189-
snap (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
190-
names (str): names of requested surface diagnotics. They are separated
191-
by ``-`` (figures), ``.`` (subplots) and ``,`` (same subplot).
201+
snap: a :class:`~stagpy._step.Step` of a StagyyData instance.
202+
names: names of requested surface diagnotics. They are separated by
203+
``-`` (figures), ``.`` (subplots) and ``,`` (same subplot).
192204
Surface diagnotics can be valid surface field names, field names,
193205
or `"dv2"` which is d(vphi)/dphi.
194206
"""
@@ -226,8 +238,9 @@ def plot_at_surface(snap, names):
226238
saveplot(fig, fname, snap.isnap)
227239

228240

229-
def _write_trench_diagnostics(step, vrms_surf, fid):
241+
def _write_trench_diagnostics(step: Step, vrms_surf: float, fid: TextIO):
230242
"""Print out some trench diagnostics."""
243+
assert step.isnap is not None
231244
itrenches, _ = detect_plates(step, conf.plates.vzratio)
232245
time = step.time * vrms_surf *\
233246
conf.scaling.ttransit / conf.scaling.yearins / 1.e6
@@ -279,12 +292,12 @@ def _write_trench_diagnostics(step, vrms_surf, fid):
279292
agetrenches[isubd]))
280293

281294

282-
def plot_scalar_field(snap, fieldname):
295+
def plot_scalar_field(snap: Step, fieldname: str):
283296
"""Plot scalar field with plate information.
284297
285298
Args:
286-
snap (:class:`~stagpy._step.Step`): a step of a StagyyData instance.
287-
fieldname (str): name of the field that should be decorated with plate
299+
snap: a :class:`~stagpy._step.Step` of a StagyyData instance.
300+
fieldname: name of the field that should be decorated with plate
288301
informations.
289302
"""
290303
fig, axis, _, _ = field.plot_scalar(snap, fieldname)

0 commit comments

Comments
 (0)