Skip to content

Commit 0cdc85a

Browse files
committed
Replace NamedTuples with dataclasses
1 parent 302123d commit 0cdc85a

File tree

11 files changed

+126
-86
lines changed

11 files changed

+126
-86
lines changed

stagpy/_step.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,10 +508,16 @@ def __getitem__(self, name: str) -> Rprof:
508508
else:
509509
meta = Varr(name, "", "1")
510510
elif name in self._cached_extra:
511-
rprof, rad, meta = self._cached_extra[name]
511+
rpf = self._cached_extra[name]
512+
rprof = rpf.values
513+
rad = rpf.rad
514+
meta = rpf.meta
512515
elif name in phyvars.RPROF_EXTRA:
513516
self._cached_extra[name] = phyvars.RPROF_EXTRA[name](step)
514-
rprof, rad, meta = self._cached_extra[name]
517+
rpf = self._cached_extra[name]
518+
rprof = rpf.values
519+
rad = rpf.rad
520+
meta = rpf.meta
515521
else:
516522
raise error.UnknownRprofVarError(name)
517523
rprof, _ = step.sdat.scale(rprof, meta.dim)

stagpy/datatypes.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, NamedTuple
5+
import typing
6+
from dataclasses import dataclass
67

7-
if TYPE_CHECKING:
8-
from numpy import ndarray
8+
if typing.TYPE_CHECKING:
9+
from numpy.typing import NDArray
910

1011

11-
class Varf(NamedTuple):
12+
@dataclass(frozen=True)
13+
class Varf:
1214
"""Metadata of scalar field.
1315
1416
Attributes:
@@ -21,19 +23,21 @@ class Varf(NamedTuple):
2123
dim: str
2224

2325

24-
class Field(NamedTuple):
26+
@dataclass(frozen=True)
27+
class Field:
2528
"""Scalar field and associated metadata.
2629
2730
Attributes:
2831
values: the field itself.
2932
meta: the metadata of the field.
3033
"""
3134

32-
values: ndarray
35+
values: NDArray
3336
meta: Varf
3437

3538

36-
class Varr(NamedTuple):
39+
@dataclass(frozen=True)
40+
class Varr:
3741
"""Metadata of radial profiles.
3842
3943
Attributes:
@@ -46,12 +50,12 @@ class Varr(NamedTuple):
4650
"""
4751

4852
description: str
49-
# Callable[[Step], Tuple[ndarray, ndarray]]]
5053
kind: str
5154
dim: str
5255

5356

54-
class Rprof(NamedTuple):
57+
@dataclass(frozen=True)
58+
class Rprof:
5559
"""Radial profile with associated radius and metadata.
5660
5761
Attributes:
@@ -60,12 +64,13 @@ class Rprof(NamedTuple):
6064
meta: the metadata of the profile.
6165
"""
6266

63-
values: ndarray
64-
rad: ndarray
67+
values: NDArray
68+
rad: NDArray
6569
meta: Varr
6670

6771

68-
class Vart(NamedTuple):
72+
@dataclass(frozen=True)
73+
class Vart:
6974
"""Metadata of time series.
7075
7176
Attributes:
@@ -82,7 +87,8 @@ class Vart(NamedTuple):
8287
dim: str
8388

8489

85-
class Tseries(NamedTuple):
90+
@dataclass(frozen=True)
91+
class Tseries:
8692
"""A time series with associated time and metadata.
8793
8894
Attributes:
@@ -91,6 +97,6 @@ class Tseries(NamedTuple):
9197
meta: the metadata of the series.
9298
"""
9399

94-
values: ndarray
95-
time: ndarray
100+
values: NDArray
101+
time: NDArray
96102
meta: Vart

stagpy/field.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
def _threed_extract(
3636
step: Step, var: str, walls: bool = False
37-
) -> Tuple[Tuple[ndarray, ndarray], Any]:
37+
) -> Tuple[Tuple[ndarray, ndarray], ndarray]:
3838
"""Return suitable slices and coords for 3D fields."""
3939
is_vector = not valid_field_var(var)
4040
hwalls = is_vector or walls
@@ -101,24 +101,28 @@ def get_meshes_fld(
101101
x position, y position, the values and the metadata of the requested
102102
field.
103103
"""
104-
fld, meta = step.fields[var]
105-
hwalls = walls or fld.shape[0] != step.geom.nxtot or fld.shape[1] != step.geom.nytot
104+
fld = step.fields[var]
105+
hwalls = (
106+
walls
107+
or fld.values.shape[0] != step.geom.nxtot
108+
or fld.values.shape[1] != step.geom.nytot
109+
)
106110
if step.geom.threed and step.geom.cartesian:
107-
(xcoord, ycoord), fld = _threed_extract(step, var, walls)
111+
(xcoord, ycoord), vals = _threed_extract(step, var, walls)
108112
elif step.geom.twod_xz:
109113
xcoord = step.geom.x_walls if hwalls else step.geom.x_centers
110114
ycoord = step.geom.z_walls if walls else step.geom.z_centers
111-
fld = fld[:, 0, :, 0]
115+
vals = fld.values[:, 0, :, 0]
112116
else: # twod_yz
113117
xcoord = step.geom.y_walls if hwalls else step.geom.y_centers
114118
ycoord = step.geom.z_walls if walls else step.geom.z_centers
115119
if step.geom.curvilinear:
116120
pmesh, rmesh = np.meshgrid(xcoord, ycoord, indexing="ij")
117121
xmesh, ymesh = rmesh * np.cos(pmesh), rmesh * np.sin(pmesh)
118-
fld = fld[0, :, :, 0]
122+
vals = fld.values[0, :, :, 0]
119123
if step.geom.cartesian:
120124
xmesh, ymesh = np.meshgrid(xcoord, ycoord, indexing="ij")
121-
return xmesh, ymesh, fld, meta
125+
return xmesh, ymesh, vals, fld.meta
122126

123127

124128
def get_meshes_vec(step: Step, var: str) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
@@ -336,15 +340,15 @@ def _findminmax(
336340
for step in sdat.walk.filter(snap=True):
337341
for var in sovs:
338342
if var in step.fields:
339-
field, meta = step.fields[var]
340-
field, _ = sdat.scale(field, meta.dim)
343+
field = step.fields[var]
344+
vals, _ = sdat.scale(field.values, field.meta.dim)
341345
if var in minmax:
342346
minmax[var] = (
343-
min(minmax[var][0], np.nanmin(field)),
344-
max(minmax[var][1], np.nanmax(field)),
347+
min(minmax[var][0], np.nanmin(vals)),
348+
max(minmax[var][1], np.nanmax(vals)),
345349
)
346350
else:
347-
minmax[var] = np.nanmin(field), np.nanmax(field)
351+
minmax[var] = np.nanmin(vals), np.nanmax(vals)
348352
return minmax
349353

350354

stagpy/plates.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ def _surf_diag(snap: Step, name: str) -> Field:
166166
return snap.sfields[name]
167167
isurf = _isurf(snap)
168168
with suppress(error.UnknownVarError):
169-
field, meta = snap.fields[name]
170-
return Field(field[0, :, isurf, 0], meta)
169+
field = snap.fields[name]
170+
return Field(field.values[0, :, isurf, 0], field.meta)
171171
if name == "dv2":
172172
vphi = snap.fields["v2"].values[0, :, isurf, 0]
173173
if snap.geom.cartesian:
@@ -229,14 +229,14 @@ def plot_at_surface(snap: Step, names: Sequence[Sequence[Sequence[str]]]) -> Non
229229
fname += "_".join(vplt) + "_"
230230
label = ""
231231
for name in vplt:
232-
data, meta = _surf_diag(snap, name)
233-
label = meta.description
232+
field = _surf_diag(snap, name)
233+
label = field.meta.description
234234
phi = (
235235
snap.geom.p_centers
236-
if data.size == snap.geom.nptot
236+
if field.values.size == snap.geom.nptot
237237
else snap.geom.p_walls
238238
)
239-
axis.plot(phi, data, label=label)
239+
axis.plot(phi, field.values, label=label)
240240
axis.set_ylim([conf.plot.vmin, conf.plot.vmax])
241241
if conf.plates.continents:
242242
continents = _continents_location(snap)

stagpy/processing.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def dt_dt(sdat: StagyyData) -> Tseries:
5050
Returns:
5151
derivative of temperature and time arrays.
5252
"""
53-
temp, time, _ = sdat.tseries["Tmean"]
53+
series = sdat.tseries["Tmean"]
54+
temp = series.values
55+
time = series.time
5456
dtdt = (temp[1:] - temp[:-1]) / (time[1:] - time[:-1])
5557
return Tseries(dtdt, time[:-1], Vart("Derivative of temperature", r"dT/dt", "K/s"))
5658

@@ -73,12 +75,12 @@ def ebalance(sdat: StagyyData) -> Tseries:
7375
else:
7476
coefsurf = 1.0
7577
volume = 1.0
76-
dtdt, time, _ = dt_dt(sdat)
78+
dtdt = dt_dt(sdat)
7779
ftop = sdat.tseries["ftop"].values * coefsurf
7880
fbot = sdat.tseries["fbot"].values
7981
radio = sdat.tseries["H_int"].values
80-
ebal = ftop[1:] - fbot[1:] + volume * (dtdt - radio[1:])
81-
return Tseries(ebal, time, Vart("Energy balance", r"$\mathrm{Nu}$", "1"))
82+
ebal = ftop[1:] - fbot[1:] + volume * (dtdt.values - radio[1:])
83+
return Tseries(ebal, dtdt.time, Vart("Energy balance", r"$\mathrm{Nu}$", "1"))
8284

8385

8486
def mobility(sdat: StagyyData) -> Tseries:
@@ -154,9 +156,9 @@ def diffs_prof(step: Step) -> Rprof:
154156
Returns:
155157
the diffusion and radius.
156158
"""
157-
diff, rad, _ = diff_prof(step)
159+
rpf = diff_prof(step)
158160
meta = Varr("Scaled diffusion", "Heat flux", "W/m2")
159-
return Rprof(_scale_prof(step, diff, rad), rad, meta)
161+
return Rprof(_scale_prof(step, rpf.values, rpf.rad), rpf.rad, meta)
160162

161163

162164
def advts_prof(step: Step) -> Rprof:
@@ -220,10 +222,12 @@ def energy_prof(step: Step) -> Rprof:
220222
Returns:
221223
the energy flux and radius.
222224
"""
223-
diff, rad, _ = diffs_prof(step)
224-
adv, _, _ = advts_prof(step)
225+
diff_p = diffs_prof(step)
226+
adv_p = advts_prof(step)
225227
return Rprof(
226-
diff + np.append(adv, 0), rad, Varr("Total heat flux", "Heat flux", "W/m2")
228+
diff_p.values + np.append(adv_p.values, 0),
229+
diff_p.rad,
230+
Varr("Total heat flux", "Heat flux", "W/m2"),
227231
)
228232

229233

stagpy/rprof.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ def plot_rprofs(rprofs: _Rprofs, names: Sequence[Sequence[Sequence[str]]]) -> No
3636
xlabel = None
3737
profs_on_plt = (rprofs[rvar] for rvar in vplt)
3838
fname += "_".join(vplt) + "_"
39-
for ivar, (rprof, rad, meta) in enumerate(profs_on_plt):
39+
for ivar, rpf in enumerate(profs_on_plt):
40+
rprof = rpf.values
41+
rad = rpf.rad
42+
meta = rpf.meta
4043
if conf.rprof.depth:
4144
rad = sdat.scale(rprofs.bounds[1], "m")[0] - rad
4245
axes[iplt].plot(rprof, rad, conf.rprof.style, label=meta.description)
@@ -73,16 +76,16 @@ def plot_grid(step: Step) -> None:
7376
step (:class:`~stagpy._step.Step`): a step of a StagyyData
7477
instance.
7578
"""
76-
drad, rad, _ = step.rprofs["dr"]
79+
drprof = step.rprofs["dr"]
7780
_, unit = step.sdat.scale(1, "m")
7881
if unit:
7982
unit = f" ({unit})"
8083
fig, (ax1, ax2) = plt.subplots(2, sharex=True)
81-
ax1.plot(rad, "-ko")
84+
ax1.plot(drprof.rad, "-ko")
8285
ax1.set_ylabel("$r$" + unit)
83-
ax2.plot(drad, "-ko")
86+
ax2.plot(drprof.values, "-ko")
8487
ax2.set_ylabel("$dr$" + unit)
85-
ax2.set_xlim([-0.5, len(rad) - 0.5])
88+
ax2.set_xlim([-0.5, len(drprof.rad) - 0.5])
8689
ax2.set_xlabel("Cell number")
8790
_helpers.saveplot(fig, "grid", step.istep)
8891

stagpy/stagyydata.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,16 @@ def __getitem__(self, name: str) -> Tseries:
267267
else:
268268
meta = Vart(name, "", "1")
269269
elif name in self._cached_extra:
270-
series, time, meta = self._cached_extra[name]
270+
tseries = self._cached_extra[name]
271+
series = tseries.values
272+
time = tseries.time
273+
meta = tseries.meta
271274
elif name in phyvars.TIME_EXTRA:
272275
self._cached_extra[name] = phyvars.TIME_EXTRA[name](self.sdat)
273-
series, time, meta = self._cached_extra[name]
276+
tseries = self._cached_extra[name]
277+
series = tseries.values
278+
time = tseries.time
279+
meta = tseries.meta
274280
else:
275281
raise error.UnknownTimeVarError(name)
276282
series, _ = self.sdat.scale(series, meta.dim)
@@ -289,14 +295,18 @@ def tslice(
289295
tend: ending time. Set to None to stop at the end of available
290296
data.
291297
"""
292-
data, time, meta = self[name]
298+
series = self[name]
293299
istart = 0
294-
iend = len(time)
300+
iend = len(series.time)
295301
if tstart is not None:
296-
istart = _helpers.find_in_sorted_arr(tstart, time)
302+
istart = _helpers.find_in_sorted_arr(tstart, series.time)
297303
if tend is not None:
298-
iend = _helpers.find_in_sorted_arr(tend, time, True) + 1
299-
return Tseries(data[istart:iend], time[istart:iend], meta)
304+
iend = _helpers.find_in_sorted_arr(tend, series.time, True) + 1
305+
return Tseries(
306+
series.values[istart:iend],
307+
series.time[istart:iend],
308+
series.meta,
309+
)
300310

301311
@property
302312
def time(self) -> ndarray:
@@ -342,14 +352,14 @@ def __getitem__(self, name: str) -> Rprof:
342352
if name in self._cached_data:
343353
return self._cached_data[name]
344354
steps_iter = iter(self.steps)
345-
rprof, rad, meta = next(steps_iter).rprofs[name]
346-
rprof = np.copy(rprof)
355+
rpf = next(steps_iter).rprofs[name]
356+
rprof = np.copy(rpf.values)
347357
nprofs = 1
348358
for step in steps_iter:
349359
nprofs += 1
350360
rprof += step.rprofs[name].values
351361
rprof /= nprofs
352-
self._cached_data[name] = Rprof(rprof, rad, meta)
362+
self._cached_data[name] = Rprof(rprof, rpf.rad, rpf.meta)
353363
return self._cached_data[name]
354364

355365
@property

stagpy/stagyyparsers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import re
1111
import typing
12+
from dataclasses import dataclass
1213
from functools import partial
1314
from itertools import product
1415
from operator import itemgetter
@@ -368,14 +369,15 @@ def _readbin(
368369
return elts
369370

370371

371-
class _HeaderInfo(typing.NamedTuple):
372+
@dataclass(frozen=True)
373+
class _HeaderInfo:
372374
"""Header information."""
373375

374376
magic: int
375377
nval: int
376378
sfield: bool
377379
readbin: Callable
378-
header: Dict[str, Any]
380+
header: dict[str, Any]
379381

380382

381383
def _legacy_header(

0 commit comments

Comments
 (0)