Skip to content

Commit 582cf28

Browse files
committed
field: reduce reliance on global conf
1 parent 131b7e4 commit 582cf28

File tree

3 files changed

+47
-23
lines changed

3 files changed

+47
-23
lines changed

stagpy/field.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import numpy as np
1212
from mpl_toolkits.axes_grid1 import make_axes_locatable
1313

14-
from . import _helpers, conf, phyvars
14+
from . import _helpers, phyvars
15+
from .config import Config
1516
from .error import NotAvailableError
1617
from .stagyydata import StagyyData
1718

@@ -33,7 +34,7 @@
3334

3435

3536
def _threed_extract(
36-
step: Step, var: str, walls: bool = False
37+
conf: Config, step: Step, var: str, walls: bool = False
3738
) -> Tuple[Tuple[ndarray, ndarray], ndarray]:
3839
"""Return suitable slices and coords for 3D fields."""
3940
is_vector = not valid_field_var(var)
@@ -86,7 +87,7 @@ def valid_field_var(var: str) -> bool:
8687

8788

8889
def get_meshes_fld(
89-
step: Step, var: str, walls: bool = False
90+
conf: Config, step: Step, var: str, walls: bool = False
9091
) -> Tuple[ndarray, ndarray, ndarray, Varf]:
9192
"""Return scalar field along with coordinates meshes.
9293
@@ -108,7 +109,7 @@ def get_meshes_fld(
108109
or fld.values.shape[1] != step.geom.nytot
109110
)
110111
if step.geom.threed and step.geom.cartesian:
111-
(xcoord, ycoord), vals = _threed_extract(step, var, walls)
112+
(xcoord, ycoord), vals = _threed_extract(conf, step, var, walls)
112113
elif step.geom.twod_xz:
113114
xcoord = step.geom.x_walls if hwalls else step.geom.x_centers
114115
ycoord = step.geom.z_walls if walls else step.geom.z_centers
@@ -125,7 +126,9 @@ def get_meshes_fld(
125126
return xmesh, ymesh, vals, fld.meta
126127

127128

128-
def get_meshes_vec(step: Step, var: str) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
129+
def get_meshes_vec(
130+
conf: Config, step: Step, var: str
131+
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
129132
"""Return vector field components along with coordinates meshes.
130133
131134
Only works properly in 2D geometry and 3D cartesian.
@@ -139,7 +142,7 @@ def get_meshes_vec(step: Step, var: str) -> Tuple[ndarray, ndarray, ndarray, nda
139142
requested vector field.
140143
"""
141144
if step.geom.threed and step.geom.cartesian:
142-
(xcoord, ycoord), (vec1, vec2) = _threed_extract(step, var)
145+
(xcoord, ycoord), (vec1, vec2) = _threed_extract(conf, step, var)
143146
elif step.geom.twod_xz:
144147
xcoord, ycoord = step.geom.x_walls, step.geom.z_centers
145148
vec1 = step.fields[var + "1"].values[:, 0, :, 0]
@@ -168,6 +171,7 @@ def plot_scalar(
168171
var: str,
169172
field: Optional[ndarray] = None,
170173
axis: Optional[Axes] = None,
174+
conf: Optional[Config] = None,
171175
**extra: Any,
172176
) -> Tuple[Figure, Axes, QuadMesh, Optional[Colorbar]]:
173177
"""Plot scalar field.
@@ -191,10 +195,12 @@ def plot_scalar(
191195
:func:`~matplotlib.axes.Axes.pcolormesh`, and the colorbar returned
192196
by :func:`matplotlib.pyplot.colorbar`.
193197
"""
198+
if conf is None:
199+
conf = Config.default_()
194200
if step.geom.threed and step.geom.spherical:
195201
raise NotAvailableError("plot_scalar not implemented for 3D spherical geometry")
196202

197-
xmesh, ymesh, fld, meta = get_meshes_fld(step, var, walls=True)
203+
xmesh, ymesh, fld, meta = get_meshes_fld(conf, step, var, walls=True)
198204
# interpolate at cell centers, this should be abstracted by field objects
199205
# via an "at_cell_centers" method or similar
200206
if fld.shape[0] > max(step.geom.nxtot, step.geom.nytot):
@@ -262,7 +268,12 @@ def plot_scalar(
262268

263269

264270
def plot_iso(
265-
axis: Axes, step: Step, var: str, field: Optional[ndarray] = None, **extra: Any
271+
axis: Axes,
272+
step: Step,
273+
var: str,
274+
field: Optional[ndarray] = None,
275+
conf: Optional[Config] = None,
276+
**extra: Any,
266277
) -> None:
267278
"""Plot isocontours of scalar field.
268279
@@ -278,7 +289,9 @@ def plot_iso(
278289
extra: options that will be passed on to
279290
:func:`matplotlib.axes.Axes.contour`.
280291
"""
281-
xmesh, ymesh, fld, _ = get_meshes_fld(step, var)
292+
if conf is None:
293+
conf = Config.default_()
294+
xmesh, ymesh, fld, _ = get_meshes_fld(conf, step, var)
282295

283296
if field is not None:
284297
fld = field
@@ -296,7 +309,12 @@ def plot_iso(
296309
axis.contour(xmesh, ymesh, fld, **extra_opts)
297310

298311

299-
def plot_vec(axis: Axes, step: Step, var: str) -> None:
312+
def plot_vec(
313+
axis: Axes,
314+
step: Step,
315+
var: str,
316+
conf: Optional[Config] = None,
317+
) -> None:
300318
"""Plot vector field.
301319
302320
Args:
@@ -305,7 +323,9 @@ def plot_vec(axis: Axes, step: Step, var: str) -> None:
305323
step: a :class:`~stagpy._step.Step` of a StagyyData instance.
306324
var: the vector field name.
307325
"""
308-
xmesh, ymesh, vec1, vec2 = get_meshes_vec(step, var)
326+
if conf is None:
327+
conf = Config.default_()
328+
xmesh, ymesh, vec1, vec2 = get_meshes_vec(conf, step, var)
309329
dipz = step.geom.nztot // 10
310330
if conf.field.shift:
311331
vec1 = np.roll(vec1, conf.field.shift, axis=0)
@@ -351,6 +371,8 @@ def cmd() -> None:
351371
conf.field
352372
conf.core
353373
"""
374+
from . import conf
375+
354376
sdat = StagyyData(conf.core.path)
355377
# no more than two fields in a subplot
356378
lovs = [[slov[:2] for slov in plov] for plov in conf.field.plot]
@@ -372,12 +394,12 @@ def cmd() -> None:
372394
opts: Dict[str, Any] = {}
373395
if var[0] in minmax:
374396
opts = dict(vmin=minmax[var[0]][0], vmax=minmax[var[0]][1])
375-
plot_scalar(step, var[0], axis=axis, **opts)
397+
plot_scalar(step, var[0], axis=axis, conf=conf, **opts)
376398
if len(var) == 2:
377399
if valid_field_var(var[1]):
378-
plot_iso(axis, step, var[1])
400+
plot_iso(axis, step, var[1], conf=conf)
379401
elif valid_field_var(var[1] + "1"):
380-
plot_vec(axis, step, var[1])
402+
plot_vec(axis, step, var[1], conf=conf)
381403
if conf.field.timelabel:
382404
time, unit = sdat.scale(step.timeinfo["t"], "s")
383405
time = _helpers.scilabel(time)

stagpy/plates.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def plot_scalar_field(
342342
"""
343343
if conf is None:
344344
conf = Config.default_()
345-
fig, axis, _, _ = field.plot_scalar(snap, fieldname)
345+
fig, axis, _, _ = field.plot_scalar(snap, fieldname, conf=conf)
346346

347347
if conf.plates.continents:
348348
c_field = np.ma.masked_where(
@@ -356,12 +356,13 @@ def plot_scalar_field(
356356
"c",
357357
c_field,
358358
axis,
359+
conf=conf,
359360
cmap=cmap,
360361
norm=colors.BoundaryNorm([2, 3, 4, 5], cmap.N),
361362
)
362363

363364
# plotting velocity vectors
364-
field.plot_vec(axis, snap, "sx" if conf.plates.stress else "v")
365+
field.plot_vec(axis, snap, "sx" if conf.plates.stress else "v", conf=conf)
365366

366367
# Put arrow where ridges and trenches are
367368
_plot_plate_limits_field(axis, snap, conf)

tests/test_field.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import pytest
22

33
import stagpy.error
4-
import stagpy.field
54
import stagpy.phyvars
5+
from stagpy.config import Config
6+
from stagpy.field import get_meshes_fld, get_meshes_vec, valid_field_var
67
from stagpy.stagyydata import Step
78

89

@@ -20,33 +21,33 @@ def test_field_missing(step: Step) -> None:
2021

2122
def test_valid_field_var() -> None:
2223
for var in stagpy.phyvars.FIELD:
23-
assert stagpy.field.valid_field_var(var)
24+
assert valid_field_var(var)
2425
for var in stagpy.phyvars.FIELD_EXTRA:
25-
assert stagpy.field.valid_field_var(var)
26+
assert valid_field_var(var)
2627

2728

2829
def test_valid_field_var_invalid() -> None:
29-
assert not stagpy.field.valid_field_var("dummyfieldvar")
30+
assert not valid_field_var("dummyfieldvar")
3031

3132

3233
def test_get_meshes_fld_no_walls(step: Step) -> None:
33-
xmesh, ymesh, fld, meta = stagpy.field.get_meshes_fld(step, "T", walls=False)
34+
xmesh, ymesh, fld, meta = get_meshes_fld(Config.default_(), step, "T", walls=False)
3435
assert len(fld.shape) == 2
3536
assert xmesh.shape[0] == ymesh.shape[0] == fld.shape[0]
3637
assert xmesh.shape[1] == ymesh.shape[1] == fld.shape[1]
3738
assert meta.description == "Temperature"
3839

3940

4041
def test_get_meshes_fld_walls(step: Step) -> None:
41-
xmesh, ymesh, fld, meta = stagpy.field.get_meshes_fld(step, "T", walls=True)
42+
xmesh, ymesh, fld, meta = get_meshes_fld(Config.default_(), step, "T", walls=True)
4243
assert len(fld.shape) == 2
4344
assert xmesh.shape[0] == ymesh.shape[0] == fld.shape[0] + 1
4445
assert xmesh.shape[1] == ymesh.shape[1] == fld.shape[1] + 1
4546
assert meta.description == "Temperature"
4647

4748

4849
def test_get_meshes_vec(step: Step) -> None:
49-
xmesh, ymesh, vec1, vec2 = stagpy.field.get_meshes_vec(step, "v")
50+
xmesh, ymesh, vec1, vec2 = get_meshes_vec(Config.default_(), step, "v")
5051
assert len(vec1.shape) == 2
5152
assert xmesh.shape[0] == ymesh.shape[0] == vec1.shape[0] == vec2.shape[0]
5253
assert xmesh.shape[1] == ymesh.shape[1] == vec1.shape[1] == vec2.shape[1]

0 commit comments

Comments
 (0)