Skip to content

Commit cdc7c6d

Browse files
committed
tighten NDArray annotations
1 parent 918066c commit cdc7c6d

File tree

9 files changed

+78
-57
lines changed

9 files changed

+78
-57
lines changed

src/stagpy/_helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from inspect import getdoc
77

88
import matplotlib.pyplot as plt
9+
import numpy as np
910

1011
if typing.TYPE_CHECKING:
1112
from matplotlib.figure import Figure
@@ -101,7 +102,9 @@ def baredoc(obj: object) -> str:
101102
return doc.rstrip(" .").lstrip()
102103

103104

104-
def find_in_sorted_arr(value: float, array: NDArray, after: bool = False) -> int:
105+
def find_in_sorted_arr(
106+
value: float, array: NDArray[np.float64], after: bool = False
107+
) -> int:
105108
"""Return position of element in a sorted array.
106109
107110
Returns:

src/stagpy/datatypes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import typing
66
from dataclasses import dataclass
77

8+
import numpy as np
9+
810
if typing.TYPE_CHECKING:
911
from numpy.typing import NDArray
1012

@@ -23,7 +25,7 @@ class Varf:
2325
class Field:
2426
"""Scalar field and associated metadata."""
2527

26-
values: NDArray
28+
values: NDArray[np.float64]
2729
"""values of field."""
2830
meta: Varf
2931
"""metadata."""
@@ -45,9 +47,9 @@ class Varr:
4547
class Rprof:
4648
"""Radial profile with associated radius and metadata."""
4749

48-
values: NDArray
50+
values: NDArray[np.float64]
4951
"""values of profile."""
50-
rad: NDArray
52+
rad: NDArray[np.float64]
5153
"""radial position of profile."""
5254
meta: Varr
5355
"""metadata."""
@@ -69,9 +71,9 @@ class Vart:
6971
class Tseries:
7072
"""A time series with associated time and metadata."""
7173

72-
values: NDArray
74+
values: NDArray[np.float64]
7375
"""values of time series."""
74-
time: NDArray
76+
time: NDArray[np.float64]
7577
"""time position of time series."""
7678
meta: Vart
7779
"""metadata."""

src/stagpy/dimensions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
from dataclasses import dataclass
55
from functools import cached_property
66

7+
import numpy as np
8+
79
from . import phyvars
810
from .config import Scaling
911

1012
if typing.TYPE_CHECKING:
11-
from typing import TypeVar
12-
1313
from numpy.typing import NDArray
1414

1515
from .parfile import StagyyPar
1616
from .stagyydata import StagyyData
1717

18-
T = TypeVar("T", float, NDArray)
18+
T = float | NDArray[np.float64]
1919

2020

2121
@dataclass(frozen=True)

src/stagpy/field.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
def _threed_extract(
3939
conf: Config, step: Step, var: str, walls: bool = False
40-
) -> tuple[tuple[NDArray, NDArray], NDArray]:
40+
) -> tuple[tuple[NDArray[np.float64], NDArray[np.float64]], NDArray[np.float64]]:
4141
"""Return suitable slices and coords for 3D fields."""
4242
is_vector = not valid_field_var(var)
4343
hwalls = is_vector or walls
@@ -91,7 +91,7 @@ def valid_field_var(var: str) -> bool:
9191

9292
def get_meshes_fld(
9393
conf: Config, step: Step, var: str, walls: bool = False
94-
) -> tuple[NDArray, NDArray, NDArray, Varf]:
94+
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64], Varf]:
9595
"""Return scalar field along with coordinates meshes.
9696
9797
Only works properly in 2D geometry and 3D cartesian.
@@ -134,7 +134,9 @@ def get_meshes_fld(
134134

135135
def get_meshes_vec(
136136
conf: Config, step: Step, var: str
137-
) -> tuple[NDArray, NDArray, NDArray, NDArray]:
137+
) -> tuple[
138+
NDArray[np.float64], NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]
139+
]:
138140
"""Return vector field components along with coordinates meshes.
139141
140142
Only works properly in 2D geometry and 3D cartesian.
@@ -178,7 +180,7 @@ def get_meshes_vec(
178180
def plot_scalar(
179181
step: Step,
180182
var: str,
181-
field: NDArray | None = None,
183+
field: NDArray[np.float64] | None = None,
182184
axis: Axes | None = None,
183185
conf: Config | None = None,
184186
**extra: Any,
@@ -272,7 +274,7 @@ def plot_iso(
272274
axis: Axes,
273275
step: Step,
274276
var: str,
275-
field: NDArray | None = None,
277+
field: NDArray[np.float64] | None = None,
276278
conf: Config | None = None,
277279
**extra: Any,
278280
) -> None:

src/stagpy/plates.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .step import Geometry, Step
2828

2929

30-
def _vzcheck(iphis: Sequence[int], snap: Step, vz_thres: float) -> NDArray:
30+
def _vzcheck(iphis: Sequence[int], snap: Step, vz_thres: float) -> NDArray[np.int32]:
3131
"""Remove positions where vz is below threshold."""
3232
# verifying vertical velocity
3333
vzabs = np.abs(snap.fields["v3"].values[0, ..., 0])
@@ -40,7 +40,9 @@ def _vzcheck(iphis: Sequence[int], snap: Step, vz_thres: float) -> NDArray:
4040

4141

4242
@lru_cache
43-
def detect_plates(snap: Step, vz_thres_ratio: float = 0) -> tuple[NDArray, NDArray]:
43+
def detect_plates(
44+
snap: Step, vz_thres_ratio: float = 0
45+
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
4446
"""Detect plate limits using derivative of horizontal velocity.
4547
4648
This function is cached for convenience.
@@ -98,7 +100,9 @@ def detect_plates(snap: Step, vz_thres_ratio: float = 0) -> tuple[NDArray, NDArr
98100
return itrenches, iridges
99101

100102

101-
def _plot_plate_limits(axis: Axes, trenches: NDArray, ridges: NDArray) -> None:
103+
def _plot_plate_limits(
104+
axis: Axes, trenches: NDArray[np.float64], ridges: NDArray[np.float64]
105+
) -> None:
102106
"""Plot lines designating ridges and trenches."""
103107
for trench in trenches:
104108
axis.axvline(x=trench, color="red", ls="dashed", alpha=0.4)
@@ -183,7 +187,7 @@ def _surf_diag(snap: Step, name: str) -> Field:
183187
raise error.UnknownVarError(name)
184188

185189

186-
def _continents_location(snap: Step, at_surface: bool = True) -> NDArray:
190+
def _continents_location(snap: Step, at_surface: bool = True) -> NDArray[np.bool]:
187191
"""Location of continents as a boolean array.
188192
189193
If at_surface is True, it is evaluated only at the surface, otherwise it is

src/stagpy/processing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def delta_r(step: Step) -> Rprof:
119119
return Rprof((edges[1:] - edges[:-1]), step.rprofs.centers, meta)
120120

121121

122-
def _scale_prof(step: Step, rprof: NDArray, rad: NDArray | None = None) -> NDArray:
122+
def _scale_prof(
123+
step: Step, rprof: NDArray[np.float64], rad: NDArray[np.float64] | None = None
124+
) -> NDArray[np.float64]:
123125
"""Scale profile to take sphericity into account."""
124126
rbot, rtop = step.rprofs.bounds
125127
if rbot == 0: # not spherical

src/stagpy/stagyydata.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,12 @@ def tslice(
212212
)
213213

214214
@property
215-
def time(self) -> NDArray:
215+
def time(self) -> NDArray[np.float64]:
216216
"""Time vector."""
217217
return self._tseries["t"].to_numpy()
218218

219219
@property
220-
def isteps(self) -> NDArray:
220+
def isteps(self) -> NDArray[np.float64]:
221221
"""Step indices.
222222
223223
This is such that `time[istep]` is at step `isteps[istep]`.
@@ -282,11 +282,11 @@ def _first_rprofs(self) -> step.RprofsInstant:
282282
return first_step.rprofs
283283

284284
@property
285-
def centers(self) -> NDArray:
285+
def centers(self) -> NDArray[np.float64]:
286286
return self._first_rprofs.centers
287287

288288
@property
289-
def walls(self) -> NDArray:
289+
def walls(self) -> NDArray[np.float64]:
290290
return self._first_rprofs.walls
291291

292292
@property

src/stagpy/stagyyparsers.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ def field_header(fieldfile: Path) -> dict[str, Any] | None:
477477
return hdr.header
478478

479479

480-
def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray] | None:
480+
def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray[np.float64]] | None:
481481
"""Extract fields data.
482482
483483
Args:
@@ -555,7 +555,7 @@ def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray] | None:
555555
return header, flds
556556

557557

558-
def tracers(tracersfile: Path) -> dict[str, list[NDArray]] | None:
558+
def tracers(tracersfile: Path) -> dict[str, list[NDArray[np.float64]]] | None:
559559
"""Extract tracers data.
560560
561561
Args:
@@ -566,7 +566,7 @@ def tracers(tracersfile: Path) -> dict[str, list[NDArray]] | None:
566566
"""
567567
if not tracersfile.is_file():
568568
return None
569-
tra: dict[str, list[NDArray]] = {}
569+
tra: dict[str, list[NDArray[np.float64]]] = {}
570570
with tracersfile.open("rb") as fid:
571571
readbin = partial(_readbin, fid)
572572
magic = readbin()
@@ -603,7 +603,7 @@ def tracers(tracersfile: Path) -> dict[str, list[NDArray]] | None:
603603
return tra
604604

605605

606-
def _read_group_h5(filename: Path, groupname: str) -> NDArray:
606+
def _read_group_h5(filename: Path, groupname: str) -> NDArray[np.float64]:
607607
"""Return group content.
608608
609609
Args:
@@ -623,7 +623,7 @@ def _read_group_h5(filename: Path, groupname: str) -> NDArray:
623623
return data # need to be reshaped
624624

625625

626-
def _make_3d(field: NDArray, twod: str | None) -> NDArray:
626+
def _make_3d(field: NDArray[np.float64], twod: str | None) -> NDArray[np.float64]:
627627
"""Add a dimension to field if necessary.
628628
629629
Args:
@@ -641,7 +641,9 @@ def _make_3d(field: NDArray, twod: str | None) -> NDArray:
641641
return field.reshape(shp)
642642

643643

644-
def _ncores(meshes: list[dict[str, NDArray]], twod: str | None) -> NDArray:
644+
def _ncores(
645+
meshes: list[dict[str, NDArray[np.float64]]], twod: str | None
646+
) -> NDArray[np.float64]:
645647
"""Compute number of nodes in each direction."""
646648
nnpb = len(meshes) # number of nodes per block
647649
nns = [1, 1, 1] # number of nodes in x, y, z directions
@@ -672,8 +674,8 @@ def _ncores(meshes: list[dict[str, NDArray]], twod: str | None) -> NDArray:
672674

673675

674676
def _conglomerate_meshes(
675-
meshin: list[dict[str, NDArray]], header: dict[str, Any]
676-
) -> dict[str, NDArray]:
677+
meshin: list[dict[str, NDArray[np.float64]]], header: dict[str, Any]
678+
) -> dict[str, NDArray[np.float64]]:
677679
"""Conglomerate meshes from several cores into one."""
678680
meshout = {}
679681
npc = header["nts"] // header["ncs"]
@@ -870,7 +872,7 @@ def read_geom_h5(xdmf: FieldXmf, snapshot: int) -> dict[str, Any]:
870872
header["mo_thick_sol"] = entry.mo_thick_sol
871873
header["ntb"] = 2 if entry.yin_yang else 1
872874

873-
all_meshes: list[dict[str, NDArray]] = []
875+
all_meshes: list[dict[str, NDArray[np.float64]]] = []
874876
for h5file in entry.coord_files_yin(xdmf.path.parent):
875877
all_meshes.append({})
876878
with h5py.File(h5file, "r") as h5f:
@@ -927,7 +929,9 @@ def read_geom_h5(xdmf: FieldXmf, snapshot: int) -> dict[str, Any]:
927929
return header
928930

929931

930-
def _to_spherical(flds: NDArray, header: dict[str, Any]) -> NDArray:
932+
def _to_spherical(
933+
flds: NDArray[np.float64], header: dict[str, Any]
934+
) -> NDArray[np.float64]:
931935
"""Convert vector field to spherical."""
932936
cth = np.cos(header["t_mesh"][:, :, :-1])
933937
sth = np.sin(header["t_mesh"][:, :, :-1])
@@ -962,7 +966,9 @@ def _flds_shape(fieldname: str, header: dict[str, Any]) -> list[int]:
962966
return shp
963967

964968

965-
def _post_read_flds(flds: NDArray, header: dict[str, Any]) -> NDArray:
969+
def _post_read_flds(
970+
flds: NDArray[np.float64], header: dict[str, Any]
971+
) -> NDArray[np.float64]:
966972
"""Process flds to handle sphericity."""
967973
if flds.shape[0] >= 3 and header["rcmb"] > 0:
968974
# spherical vector
@@ -982,7 +988,7 @@ def read_field_h5(
982988
fieldname: str,
983989
snapshot: int,
984990
header: dict[str, Any] | None = None,
985-
) -> tuple[dict[str, Any], NDArray] | None:
991+
) -> tuple[dict[str, Any], NDArray[np.float64]] | None:
986992
"""Extract field data from hdf5 files.
987993
988994
Args:
@@ -1146,7 +1152,9 @@ def __getitem__(self, isnap: int) -> XmfTracersEntry:
11461152
raise ParsingError(self.path, f"no data for snapshot {isnap}")
11471153

11481154

1149-
def read_tracers_h5(xdmf: TracersXmf, infoname: str, snapshot: int) -> list[NDArray]:
1155+
def read_tracers_h5(
1156+
xdmf: TracersXmf, infoname: str, snapshot: int
1157+
) -> list[NDArray[np.float64]]:
11501158
"""Extract tracers data from hdf5 files.
11511159
11521160
Args:
@@ -1157,11 +1165,11 @@ def read_tracers_h5(xdmf: TracersXmf, infoname: str, snapshot: int) -> list[NDAr
11571165
Returns:
11581166
Tracers data organized by attribute and block.
11591167
"""
1160-
tra: list[list[NDArray]] = [[], []] # [block][core]
1168+
tra: list[list[NDArray[np.float64]]] = [[], []] # [block][core]
11611169
for tsub in xdmf[snapshot].tra_subdomains(xdmf.path.parent, infoname):
11621170
tra[tsub.iblock].append(_read_group_h5(tsub.file, tsub.dataset))
11631171

1164-
tra_concat: list[NDArray] = []
1172+
tra_concat: list[NDArray[np.float64]] = []
11651173
for trab in tra:
11661174
if trab:
11671175
tra_concat.append(np.concatenate(trab))

0 commit comments

Comments
 (0)