
Commit daa14d9

stagyyparsers: annotate with builtins

1 parent 2732b33 commit daa14d9

1 file changed: +32, -43 lines
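The change below switches the module's annotations from the deprecated typing.List/Dict/Tuple aliases to the builtin generics list/dict/tuple (PEP 585), and replaces the bare numpy.ndarray with numpy.typing.NDArray. As a rough, self-contained sketch of that style (not taken from stagpy: the functions load_profiles and normalize are purely illustrative), assuming postponed annotation evaluation via from __future__ import annotations so that typing-only imports can stay behind typing.TYPE_CHECKING:

from __future__ import annotations

import typing

import numpy as np

if typing.TYPE_CHECKING:
    # Needed only by type checkers; never imported at runtime.
    from pathlib import Path
    from typing import Any, Optional

    from numpy.typing import NDArray


def load_profiles(datafile: Path, colnames: list[str]) -> Optional[dict[str, Any]]:
    """Illustrative only: builtin generics instead of typing.List/Dict."""
    if not datafile.is_file():
        return None
    return {name: str(datafile) for name in colnames}


def normalize(field: NDArray) -> NDArray:
    """Illustrative only: numpy.typing.NDArray instead of numpy.ndarray."""
    return field / np.abs(field).max()

With postponed evaluation the annotations are never executed at runtime, so this also runs on Python 3.8; without it, subscripting list, dict and tuple in annotations requires Python 3.9 or later.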

stagpy/stagyyparsers.py

Lines changed: 32 additions & 43 deletions
@@ -25,26 +25,15 @@
 
 if typing.TYPE_CHECKING:
     from pathlib import Path
-    from typing import (
-        Any,
-        BinaryIO,
-        Callable,
-        Dict,
-        Iterator,
-        List,
-        Mapping,
-        Optional,
-        Tuple,
-    )
+    from typing import Any, BinaryIO, Callable, Iterator, Mapping, Optional
     from xml.etree.ElementTree import Element
 
-    from numpy import ndarray
     from numpy.typing import NDArray
     from pandas import DataFrame
 
 
 def _tidy_names(
-    names: List[str], nnames: int, extra_names: Optional[List[str]] = None
+    names: list[str], nnames: int, extra_names: Optional[list[str]] = None
 ) -> None:
     """Truncate or extend names so that its len is nnames.
 
@@ -62,7 +51,7 @@ def _tidy_names(
     del names[nnames:]
 
 
-def time_series(timefile: Path, colnames: List[str]) -> Optional[DataFrame]:
+def time_series(timefile: Path, colnames: list[str]) -> Optional[DataFrame]:
     """Read temporal series text file.
 
     If :data:`colnames` is too long, it will be truncated. If it is too short,
@@ -113,7 +102,7 @@ def time_series(timefile: Path, colnames: List[str]) -> Optional[DataFrame]:
     return data
 
 
-def time_series_h5(timefile: Path, colnames: List[str]) -> Optional[DataFrame]:
+def time_series_h5(timefile: Path, colnames: list[str]) -> Optional[DataFrame]:
     """Read temporal series HDF5 file.
 
     If :data:`colnames` is too long, it will be truncated. If it is too short,
@@ -144,7 +133,7 @@ def time_series_h5(timefile: Path, colnames: List[str]) -> Optional[DataFrame]:
 
 def _extract_rsnap_isteps(
     rproffile: Path, data: DataFrame
-) -> List[Tuple[int, float, DataFrame]]:
+) -> list[tuple[int, float, DataFrame]]:
     """Extract istep, time and build separate rprof df."""
     step_regex = re.compile(r"^\*+step:\s*(\d+) ; time =\s*(\S+)")
     isteps = []  # list of (istep, time, df)
@@ -179,8 +168,8 @@ def _extract_rsnap_isteps(
 
 
 def rprof(
-    rproffile: Path, colnames: List[str]
-) -> Tuple[Dict[int, DataFrame], Optional[DataFrame]]:
+    rproffile: Path, colnames: list[str]
+) -> tuple[dict[int, DataFrame], Optional[DataFrame]]:
     """Extract radial profiles data.
 
     If :data:`colnames` is too long, it will be truncated. If it is too short,
@@ -228,8 +217,8 @@ def rprof(
 
 
 def rprof_h5(
-    rproffile: Path, colnames: List[str]
-) -> Tuple[Dict[int, DataFrame], Optional[DataFrame]]:
+    rproffile: Path, colnames: list[str]
+) -> tuple[dict[int, DataFrame], Optional[DataFrame]]:
     """Extract radial profiles data.
 
     If :data:`colnames` is too long, it will be truncated. If it is too short,
@@ -266,7 +255,7 @@ def rprof_h5(
     return data, df_times
 
 
-def _clean_names_refstate(names: List[str]) -> List[str]:
+def _clean_names_refstate(names: list[str]) -> list[str]:
     """Uniformization of refstate profile names."""
     to_clean = {
         "Tref": "T",
@@ -278,15 +267,15 @@ def _clean_names_refstate(names: List[str]) -> List[str]:
 
 def refstate(
     reffile: Path, ncols: int = 8
-) -> Optional[Tuple[List[List[DataFrame]], List[DataFrame]]]:
+) -> Optional[tuple[list[list[DataFrame]], list[DataFrame]]]:
     """Extract reference state profiles.
 
     Args:
         reffile: path of the refstate file.
         ncols: number of columns.
 
     Returns:
-        Tuple (syst, adia).
+        tuple (syst, adia).
 
         :data:`syst` is a list of list of
         :class:`pandas.DataFrame` containing the reference state profiles for
@@ -312,8 +301,8 @@ def refstate(
     # drop lines corresponding to metadata
     data.dropna(subset=[0], inplace=True)
     isystem = -1
-    systems: List[List[List[str]]] = [[]]
-    adiabats: List[List[str]] = []
+    systems: list[list[list[str]]] = [[]]
+    adiabats: list[list[str]] = []
     with reffile.open() as rsf:
         for line in rsf:
             line = line.lstrip()
@@ -329,8 +318,8 @@ def refstate(
     nprofs = sum(map(len, systems)) + len(adiabats)
     nzprof = len(data) // nprofs
     iprof = 0
-    syst: List[List[DataFrame]] = []
-    adia: List[DataFrame] = []
+    syst: list[list[DataFrame]] = []
+    adia: list[DataFrame] = []
     for isys, layers in enumerate(systems):
         syst.append([])
         for layer in layers:
@@ -479,7 +468,7 @@ def field_istep(fieldfile: Path) -> Optional[int]:
     return hdr.header["ti_step"]
 
 
-def field_header(fieldfile: Path) -> Optional[Dict[str, Any]]:
+def field_header(fieldfile: Path) -> Optional[dict[str, Any]]:
     """Read header info from binary field file.
 
     Args:
@@ -495,7 +484,7 @@ def field_header(fieldfile: Path) -> Optional[Dict[str, Any]]:
     return hdr.header
 
 
-def fields(fieldfile: Path) -> Optional[Tuple[Dict[str, Any], ndarray]]:
+def fields(fieldfile: Path) -> Optional[tuple[dict[str, Any], NDArray]]:
     """Extract fields data.
 
     Args:
@@ -577,7 +566,7 @@ def fields(fieldfile: Path) -> Optional[Tuple[Dict[str, Any], ndarray]]:
     return header, flds
 
 
-def tracers(tracersfile: Path) -> Optional[Dict[str, List[ndarray]]]:
+def tracers(tracersfile: Path) -> Optional[dict[str, list[NDArray]]]:
     """Extract tracers data.
 
     Args:
@@ -588,7 +577,7 @@ def tracers(tracersfile: Path) -> Optional[Dict[str, List[ndarray]]]:
     """
     if not tracersfile.is_file():
        return None
-    tra: Dict[str, List[ndarray]] = {}
+    tra: dict[str, list[NDArray]] = {}
     with tracersfile.open("rb") as fid:
         readbin = partial(_readbin, fid)
         magic = readbin()
@@ -625,7 +614,7 @@ def tracers(tracersfile: Path) -> Optional[Dict[str, List[ndarray]]]:
     return tra
 
 
-def _read_group_h5(filename: Path, groupname: str) -> ndarray:
+def _read_group_h5(filename: Path, groupname: str) -> NDArray:
     """Return group content.
 
     Args:
@@ -644,7 +633,7 @@ def _read_group_h5(filename: Path, groupname: str) -> ndarray:
     return data  # need to be reshaped
 
 
-def _make_3d(field: ndarray, twod: Optional[str]) -> ndarray:
+def _make_3d(field: NDArray, twod: Optional[str]) -> NDArray:
     """Add a dimension to field if necessary.
 
     Args:
@@ -661,7 +650,7 @@ def _make_3d(field: ndarray, twod: Optional[str]) -> ndarray:
     return field.reshape(shp)
 
 
-def _ncores(meshes: List[Dict[str, ndarray]], twod: Optional[str]) -> ndarray:
+def _ncores(meshes: list[dict[str, NDArray]], twod: Optional[str]) -> NDArray:
     """Compute number of nodes in each direction."""
     nnpb = len(meshes)  # number of nodes per block
     nns = [1, 1, 1]  # number of nodes in x, y, z directions
@@ -692,8 +681,8 @@ def _ncores(meshes: List[Dict[str, ndarray]], twod: Optional[str]) -> ndarray:
 
 
 def _conglomerate_meshes(
-    meshin: List[Dict[str, ndarray]], header: Dict[str, Any]
-) -> Dict[str, ndarray]:
+    meshin: list[dict[str, NDArray]], header: dict[str, Any]
+) -> dict[str, NDArray]:
     """Conglomerate meshes from several cores into one."""
     meshout = {}
     npc = header["nts"] // header["ncs"]
@@ -882,7 +871,7 @@ def read_geom_h5(xdmf: FieldXmf, snapshot: int) -> dict[str, Any]:
     Returns:
         geometry information.
     """
-    header: Dict[str, Any] = {}
+    header: dict[str, Any] = {}
 
     entry = xdmf[snapshot]
     header["ti_ad"] = entry.time
@@ -947,7 +936,7 @@ def read_geom_h5(xdmf: FieldXmf, snapshot: int) -> dict[str, Any]:
     return header
 
 
-def _to_spherical(flds: ndarray, header: Dict[str, Any]) -> ndarray:
+def _to_spherical(flds: NDArray, header: dict[str, Any]) -> NDArray:
     """Convert vector field to spherical."""
     cth = np.cos(header["t_mesh"][:, :, :-1])
     sth = np.sin(header["t_mesh"][:, :, :-1])
@@ -960,7 +949,7 @@ def _to_spherical(flds: ndarray, header: Dict[str, Any]) -> ndarray:
     return fout
 
 
-def _flds_shape(fieldname: str, header: Dict[str, Any]) -> List[int]:
+def _flds_shape(fieldname: str, header: dict[str, Any]) -> list[int]:
     """Compute shape of flds variable."""
     shp = list(header["nts"])
     shp.append(header["ntb"])
@@ -982,7 +971,7 @@ def _flds_shape(fieldname: str, header: Dict[str, Any]) -> List[int]:
     return shp
 
 
-def _post_read_flds(flds: ndarray, header: Dict[str, Any]) -> ndarray:
+def _post_read_flds(flds: NDArray, header: dict[str, Any]) -> NDArray:
     """Process flds to handle sphericity."""
     if flds.shape[0] >= 3 and header["rcmb"] > 0:
         # spherical vector
@@ -1001,8 +990,8 @@ def read_field_h5(
     xdmf: FieldXmf,
     fieldname: str,
     snapshot: int,
-    header: Optional[Dict[str, Any]] = None,
-) -> Optional[Tuple[Dict[str, Any], ndarray]]:
+    header: Optional[dict[str, Any]] = None,
+) -> Optional[tuple[dict[str, Any], NDArray]]:
     """Extract field data from hdf5 files.
 
     Args:
@@ -1187,7 +1176,7 @@ def read_tracers_h5(xdmf: TracersXmf, infoname: str, snapshot: int) -> list[NDArray]:
     return tra_concat
 
 
-def read_time_h5(h5folder: Path) -> Iterator[Tuple[int, int]]:
+def read_time_h5(h5folder: Path) -> Iterator[tuple[int, int]]:
     """Iterate through (isnap, istep) recorded in h5folder/'time_botT.h5'.
 
     Args:
