Skip to content

Commit e27c189

Browse files
committed
modernize annotations
1 parent ff0fcb7 commit e27c189

File tree

10 files changed

+81
-95
lines changed

10 files changed

+81
-95
lines changed

stagpy/_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Optional
1212

1313
from matplotlib.figure import Figure
14-
from numpy import ndarray
14+
from numpy.typing import NDArray
1515

1616
from .config import Config
1717
from .stagyydata import StagyyData, _StepsView
@@ -102,7 +102,7 @@ def baredoc(obj: object) -> str:
102102
return doc.rstrip(" .").lstrip()
103103

104104

105-
def find_in_sorted_arr(value: Any, array: ndarray, after: bool = False) -> int:
105+
def find_in_sorted_arr(value: Any, array: NDArray, after: bool = False) -> int:
106106
"""Return position of element in a sorted array.
107107
108108
Returns:

stagpy/_step.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,9 @@
2020
from .dimensions import Scales
2121

2222
if typing.TYPE_CHECKING:
23-
from typing import (
24-
Any,
25-
Callable,
26-
Dict,
27-
Iterator,
28-
List,
29-
Mapping,
30-
NoReturn,
31-
Optional,
32-
Tuple,
33-
)
34-
35-
from numpy import ndarray
23+
from typing import Any, Callable, Iterator, Mapping, NoReturn, Optional
24+
25+
from numpy.typing import NDArray
3626
from pandas import DataFrame, Series
3727

3828
from .datatypes import Varf
@@ -46,18 +36,18 @@ class _Geometry:
4636
output by StagYY.
4737
"""
4838

49-
def __init__(self, header: Dict[str, Any], step: Step):
39+
def __init__(self, header: dict[str, Any], step: Step):
5040
self._header = header
5141
self._step = step
52-
self._shape: Dict[str, Any] = {
42+
self._shape: dict[str, Any] = {
5343
"sph": False,
5444
"cyl": False,
5545
"axi": False,
5646
"ntot": list(header["nts"]) + [header["ntb"]],
5747
}
5848
self._init_shape()
5949

60-
def _scale_radius_mo(self, radius: ndarray) -> ndarray:
50+
def _scale_radius_mo(self, radius: NDArray) -> NDArray:
6151
"""Rescale radius for evolving MO runs."""
6252
if self._step.sdat.par.get("magma_oceans_in", "evolving_magma_oceans", False):
6353
return self._header["mo_thick_sol"] * (radius + self._header["mo_lambda"])
@@ -103,7 +93,7 @@ def nztot(self) -> int:
10393
return self.nrtot
10494

10595
@cached_property
106-
def r_walls(self) -> ndarray:
96+
def r_walls(self) -> NDArray:
10797
"""Position of FV walls along the z/r direction."""
10898
rgeom = self._header.get("rgeom")
10999
if rgeom is not None:
@@ -114,7 +104,7 @@ def r_walls(self) -> ndarray:
114104
return self._scale_radius_mo(walls)
115105

116106
@cached_property
117-
def r_centers(self) -> ndarray:
107+
def r_centers(self) -> NDArray:
118108
"""Position of FV centers along the z/r direction."""
119109
rgeom = self._header.get("rgeom")
120110
if rgeom is not None:
@@ -124,7 +114,7 @@ def r_centers(self) -> ndarray:
124114
return self._scale_radius_mo(walls)
125115

126116
@cached_property
127-
def t_walls(self) -> ndarray:
117+
def t_walls(self) -> NDArray:
128118
"""Position of FV walls along x/theta."""
129119
if self.threed or self.twod_xz:
130120
if self.yinyang:
@@ -143,12 +133,12 @@ def t_walls(self) -> ndarray:
143133
return np.array([center - d_t, center + d_t])
144134

145135
@cached_property
146-
def t_centers(self) -> ndarray:
136+
def t_centers(self) -> NDArray:
147137
"""Position of FV centers along x/theta."""
148138
return (self.t_walls[:-1] + self.t_walls[1:]) / 2
149139

150140
@cached_property
151-
def p_walls(self) -> ndarray:
141+
def p_walls(self) -> NDArray:
152142
"""Position of FV walls along y/phi."""
153143
if self.threed or self.twod_yz:
154144
if self.yinyang:
@@ -165,37 +155,37 @@ def p_walls(self) -> ndarray:
165155
return np.array([-d_p, d_p])
166156

167157
@cached_property
168-
def p_centers(self) -> ndarray:
158+
def p_centers(self) -> NDArray:
169159
"""Position of FV centers along y/phi."""
170160
return (self.p_walls[:-1] + self.p_walls[1:]) / 2
171161

172162
@property
173-
def z_walls(self) -> ndarray:
163+
def z_walls(self) -> NDArray:
174164
"""Same as r_walls."""
175165
return self.r_walls
176166

177167
@property
178-
def z_centers(self) -> ndarray:
168+
def z_centers(self) -> NDArray:
179169
"""Same as r_centers."""
180170
return self.r_centers
181171

182172
@property
183-
def x_walls(self) -> ndarray:
173+
def x_walls(self) -> NDArray:
184174
"""Same as t_walls."""
185175
return self.t_walls
186176

187177
@property
188-
def x_centers(self) -> ndarray:
178+
def x_centers(self) -> NDArray:
189179
"""Same as t_centers."""
190180
return self.t_centers
191181

192182
@property
193-
def y_walls(self) -> ndarray:
183+
def y_walls(self) -> NDArray:
194184
"""Same as p_walls."""
195185
return self.p_walls
196186

197187
@property
198-
def y_centers(self) -> ndarray:
188+
def y_centers(self) -> NDArray:
199189
"""Same as p_centers."""
200190
return self.p_centers
201191

@@ -297,15 +287,15 @@ def __init__(
297287
step: Step,
298288
variables: Mapping[str, Varf],
299289
extravars: Mapping[str, Callable[[Step], Field]],
300-
files: Mapping[str, List[str]],
301-
filesh5: Mapping[str, List[str]],
290+
files: Mapping[str, list[str]],
291+
filesh5: Mapping[str, list[str]],
302292
):
303293
self.step = step
304294
self._vars = variables
305295
self._extra = extravars
306296
self._files = files
307297
self._filesh5 = filesh5
308-
self._data: Dict[str, Field] = {}
298+
self._data: dict[str, Field] = {}
309299
super().__init__()
310300

311301
def __getitem__(self, name: str) -> Field:
@@ -329,7 +319,7 @@ def __getitem__(self, name: str) -> Field:
329319
return self._data[name]
330320

331321
@cached_property
332-
def _present_fields(self) -> List[str]:
322+
def _present_fields(self) -> list[str]:
333323
return [fld for fld in chain(self._vars, self._extra) if fld in self]
334324

335325
def __iter__(self) -> Iterator[str]:
@@ -347,7 +337,7 @@ def __len__(self) -> int:
347337
def __eq__(self, other: object) -> bool:
348338
return self is other
349339

350-
def _get_raw_data(self, name: str) -> Tuple[List[str], Any]:
340+
def _get_raw_data(self, name: str) -> tuple[list[str], Any]:
351341
"""Find file holding data and return its content."""
352342
# try legacy first, then hdf5
353343
filestem = ""
@@ -384,7 +374,7 @@ def _get_raw_data(self, name: str) -> Tuple[List[str], Any]:
384374
break
385375
return list_fvar, parsed_data
386376

387-
def _set(self, name: str, fld: ndarray) -> None:
377+
def _set(self, name: str, fld: NDArray) -> None:
388378
sdat = self.step.sdat
389379
col_fld = sdat._collected_fields
390380
col_fld.append((self.step.istep, name))
@@ -399,7 +389,7 @@ def __delitem__(self, name: str) -> None:
399389
del self._data[name]
400390

401391
@cached_property
402-
def _header(self) -> Optional[Dict[str, Any]]:
392+
def _header(self) -> Optional[dict[str, Any]]:
403393
if self.step.isnap is None:
404394
return None
405395
binfiles = self.step.sdat._binfiles_set(self.step.isnap)
@@ -439,9 +429,9 @@ class _Tracers:
439429

440430
def __init__(self, step: Step):
441431
self.step = step
442-
self._data: Dict[str, Optional[List[ndarray]]] = {}
432+
self._data: dict[str, Optional[list[NDArray]]] = {}
443433

444-
def __getitem__(self, name: str) -> Optional[List[ndarray]]:
434+
def __getitem__(self, name: str) -> Optional[list[NDArray]]:
445435
if name in self._data:
446436
return self._data[name]
447437
if self.step.isnap is None:
@@ -480,7 +470,7 @@ class _Rprofs:
480470

481471
def __init__(self, step: Step):
482472
self.step = step
483-
self._cached_extra: Dict[str, Rprof] = {}
473+
self._cached_extra: dict[str, Rprof] = {}
484474

485475
@cached_property
486476
def _data(self) -> Optional[DataFrame]:
@@ -527,12 +517,12 @@ def stepstr(self) -> str:
527517
return str(self.step.istep)
528518

529519
@cached_property
530-
def centers(self) -> ndarray:
520+
def centers(self) -> NDArray:
531521
"""Radial position of cell centers."""
532522
return self._rprofs["r"].to_numpy() + self.bounds[0]
533523

534524
@cached_property
535-
def walls(self) -> ndarray:
525+
def walls(self) -> NDArray:
536526
"""Radial position of cell walls."""
537527
rbot, rtop = self.bounds
538528
try:
@@ -547,7 +537,7 @@ def walls(self) -> ndarray:
547537
return walls
548538

549539
@cached_property
550-
def bounds(self) -> Tuple[float, float]:
540+
def bounds(self) -> tuple[float, float]:
551541
"""Radial or vertical position of boundaries.
552542
553543
Radial/vertical positions of boundaries of the domain.

stagpy/args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .config import Config
2727

2828
if typing.TYPE_CHECKING:
29-
from typing import Any, Callable, List, Optional
29+
from typing import Any, Callable, Optional
3030

3131

3232
def _sub(cmd: Any, *sections: str) -> Subcmd:
@@ -72,7 +72,7 @@ def _load_mplstyle(conf: Config) -> None:
7272

7373

7474
def parse_args(
75-
conf: Config, arglist: Optional[List[str]] = None
75+
conf: Config, arglist: Optional[list[str]] = None
7676
) -> Callable[[Config], None]:
7777
"""Parse cmd line arguments.
7878

stagpy/commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .config import CONFIG_LOCAL, Config
1515

1616
if typing.TYPE_CHECKING:
17-
from typing import Callable, Iterable, Mapping, Optional, Sequence, Tuple, Union
17+
from typing import Callable, Iterable, Mapping, Optional, Sequence, Union
1818

1919
from loam.base import Section
2020

@@ -52,7 +52,7 @@ def info_cmd(conf: Config) -> None:
5252

5353

5454
def _pretty_print(
55-
key_val: Sequence[Tuple[str, str]],
55+
key_val: Sequence[tuple[str, str]],
5656
sep: str = ": ",
5757
min_col_width: int = 39,
5858
text_width: Optional[int] = None,

stagpy/field.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from .stagyydata import StagyyData
1818

1919
if typing.TYPE_CHECKING:
20-
from typing import Any, Dict, Iterable, Optional, Tuple, Union
20+
from typing import Any, Iterable, Optional, Union
2121

2222
from matplotlib.axes import Axes
2323
from matplotlib.collections import QuadMesh
2424
from matplotlib.colorbar import Colorbar
2525
from matplotlib.figure import Figure
26-
from numpy import ndarray
26+
from numpy.typing import NDArray
2727

2828
from ._step import Step
2929
from .datatypes import Varf
@@ -36,7 +36,7 @@
3636

3737
def _threed_extract(
3838
conf: Config, step: Step, var: str, walls: bool = False
39-
) -> Tuple[Tuple[ndarray, ndarray], ndarray]:
39+
) -> tuple[tuple[NDArray, NDArray], NDArray]:
4040
"""Return suitable slices and coords for 3D fields."""
4141
is_vector = not valid_field_var(var)
4242
hwalls = is_vector or walls
@@ -89,7 +89,7 @@ def valid_field_var(var: str) -> bool:
8989

9090
def get_meshes_fld(
9191
conf: Config, step: Step, var: str, walls: bool = False
92-
) -> Tuple[ndarray, ndarray, ndarray, Varf]:
92+
) -> tuple[NDArray, NDArray, NDArray, Varf]:
9393
"""Return scalar field along with coordinates meshes.
9494
9595
Only works properly in 2D geometry and 3D cartesian.
@@ -129,7 +129,7 @@ def get_meshes_fld(
129129

130130
def get_meshes_vec(
131131
conf: Config, step: Step, var: str
132-
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
132+
) -> tuple[NDArray, NDArray, NDArray, NDArray]:
133133
"""Return vector field components along with coordinates meshes.
134134
135135
Only works properly in 2D geometry and 3D cartesian.
@@ -170,11 +170,11 @@ def get_meshes_vec(
170170
def plot_scalar(
171171
step: Step,
172172
var: str,
173-
field: Optional[ndarray] = None,
173+
field: Optional[NDArray] = None,
174174
axis: Optional[Axes] = None,
175175
conf: Optional[Config] = None,
176176
**extra: Any,
177-
) -> Tuple[Figure, Axes, QuadMesh, Optional[Colorbar]]:
177+
) -> tuple[Figure, Axes, QuadMesh, Optional[Colorbar]]:
178178
"""Plot scalar field.
179179
180180
Args:
@@ -264,7 +264,7 @@ def plot_iso(
264264
axis: Axes,
265265
step: Step,
266266
var: str,
267-
field: Optional[ndarray] = None,
267+
field: Optional[NDArray] = None,
268268
conf: Optional[Config] = None,
269269
**extra: Any,
270270
) -> None:
@@ -289,7 +289,7 @@ def plot_iso(
289289

290290
if conf.field.shift:
291291
fld = np.roll(fld, conf.field.shift, axis=0)
292-
extra_opts: Dict[str, Any] = dict(linewidths=1)
292+
extra_opts: dict[str, Any] = dict(linewidths=1)
293293
if "cmap" not in extra and conf.field.isocolors:
294294
extra_opts["colors"] = conf.field.isocolors
295295
elif "colors" not in extra:
@@ -337,9 +337,9 @@ def plot_vec(
337337

338338
def _findminmax(
339339
view: _StepsView, sovs: Iterable[str]
340-
) -> Dict[str, Tuple[float, float]]:
340+
) -> dict[str, tuple[float, float]]:
341341
"""Find min and max values of several fields."""
342-
minmax: Dict[str, Tuple[float, float]] = {}
342+
minmax: dict[str, tuple[float, float]] = {}
343343
for step in view.filter(snap=True):
344344
for var in sovs:
345345
if var in step.fields:
@@ -375,7 +375,7 @@ def cmd(conf: Config) -> None:
375375
if var[0] not in step.fields:
376376
print(f"{var[0]!r} field on snap {step.isnap} not found")
377377
continue
378-
opts: Dict[str, Any] = {}
378+
opts: dict[str, Any] = {}
379379
if var[0] in minmax:
380380
opts = dict(vmin=minmax[var[0]][0], vmax=minmax[var[0]][1])
381381
plot_scalar(step, var[0], axis=axis, conf=conf, **opts)

0 commit comments

Comments
 (0)