Skip to content

Commit 164bcea

Browse files
authored
cattrs converter with hook factory using xattree.asdict (#151)
Set up the boilerplace needed to use cattrs. Send the component to jinja as a dictionary instead of the instance itself, so the filter for value lookup can just use dict access instead of getattr(). Still need to set up conversion for list input, etc. Also some tests weren't cleaning up after themselves, fix it.
1 parent 7c9d3d7 commit 164bcea

File tree

5 files changed

+45
-42
lines changed

5 files changed

+45
-42
lines changed

flopy4/mf6/codec/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import Any
44

55
import numpy as np
6+
import xattree
7+
from cattrs import Converter
68
from jinja2 import Environment, PackageLoader
79

810
from flopy4.mf6 import filters
@@ -29,13 +31,13 @@
2931
"threshold": sys.maxsize,
3032
}
3133

34+
_CONVERTER = Converter()
35+
_CONVERTER.register_unstructure_hook_factory(
36+
lambda cls: xattree.has(cls), lambda cls: xattree.asdict
37+
)
3238

33-
def unstructure(data):
34-
# TODO unstructure arrays into sparse dicts
35-
# TODO combine OC fields into list input as defined in the MF6 dfn
36-
# TODO return a dictionary instead of the component itself, then
37-
# update filters to use dictinoary access instead of getattr()
38-
return data
39+
# TODO unstructure arrays into sparse dicts
40+
# TODO combine OC fields into list input as defined in the MF6 dfn
3941

4042

4143
def loads(data: str) -> Any:
@@ -51,12 +53,12 @@ def load(path: str | PathLike) -> Any:
5153
def dumps(data) -> str:
5254
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
5355
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
54-
return template.render(dfn=type(data).dfn, data=unstructure(data))
56+
return template.render(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))
5557

5658

5759
def dump(data, path: str | PathLike) -> None:
5860
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
59-
iterator = template.generate(dfn=type(data).dfn, data=unstructure(data))
61+
iterator = template.generate(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))
6062
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
6163
f.writelines(iterator)
6264

flopy4/mf6/filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def field_type(field: Field) -> str:
3535
@pass_context
3636
def field_value(ctx, field: Field):
3737
"""Get a field's value via the template context."""
38-
return getattr(ctx["data"], field["name"])
38+
return ctx["data"][field["name"]]
3939

4040

4141
def array_how(value: xr.DataArray) -> str:

flopy4/mf6/simulation.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ def time(self) -> ModelTime:
4444

4545
def run(self, exe: str | PathLike = "mf6", verbose: bool = False) -> None:
4646
"""Run the simulation using the given executable."""
47-
if self.workspace is None:
48-
raise ValueError(f"Simulation {self.name} has no workspace path.")
4947
with cd(self.workspace):
50-
stdout, stderr, retcode = run_cmd(exe, verbose=verbose)
51-
if retcode != 0:
48+
out, err, ret = run_cmd(exe, verbose=verbose)
49+
if ret != 0:
5250
raise RuntimeError(
53-
f"Simulation {self.name}: {exe} failed to run with returncode " # type: ignore
54-
f"{retcode}, and error message:\n\n{stdout + stderr} "
51+
f"Simulation {self.name}: {exe} failed with " # type: ignore
52+
f"return code {ret}, output:\n\n{out + err} "
5553
)
5654

5755
def load(self, format="ascii"):
5856
"""Load the simulation in the specified format."""
59-
super().load(format)
57+
with cd(self.workspace):
58+
super().load(format)
6059

6160
def write(self, format="ascii"):
6261
"""Write the simulation in the specified format."""
63-
super().write(format)
62+
with cd(self.workspace):
63+
super().write(format)

test/test_examples.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66
from modflow_devtools.markers import requires_exe
7-
from modflow_devtools.misc import run_cmd
7+
from modflow_devtools.misc import cd, run_cmd
88

99

1010
@pytest.mark.slow
@@ -16,7 +16,7 @@ def test_scripts(example_script):
1616

1717
@pytest.mark.slow
1818
@requires_exe("jupytext")
19-
def test_notebooks(example_script):
19+
def test_notebooks(example_script, tmp_path):
2020
args = [
2121
"jupytext",
2222
"--from",
@@ -26,5 +26,6 @@ def test_notebooks(example_script):
2626
"--execute",
2727
example_script,
2828
]
29-
stdout, stderr, retcode = run_cmd(*args, verbose=True)
30-
assert not retcode, stdout + stderr
29+
with cd(tmp_path):
30+
out, err, ret = run_cmd(*args, verbose=True)
31+
assert not ret, out + err

test/test_interface.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def quickstart_model():
3737
return gwf
3838

3939

40-
def test_flopy3_model():
40+
def test_flopy3_model(tmp_path):
4141
from flopy.mbase import ModelInterface
4242
from flopy.pakbase import PackageInterface
4343

@@ -114,12 +114,12 @@ def test_flopy3_model():
114114
print(f"data_type: {d.data_type}")
115115
print(f"array: {d.array}\n")
116116

117-
bpth = Path("output/flopy3_model/flopy3_model")
118-
Path("output/flopy3_model").mkdir(parents=True, exist_ok=True)
117+
bpth = Path(tmp_path) / "flopy3_model" / "flopy3_model"
118+
(Path(tmp_path) / "flopy3_model").mkdir(parents=True, exist_ok=True)
119119
gwf3.plot(filename_base=bpth)
120120

121121

122-
def test_flopy3_package():
122+
def test_flopy3_package(tmp_path):
123123
from flopy.mbase import ModelInterface
124124
from flopy.pakbase import PackageInterface
125125

@@ -236,12 +236,12 @@ def test_flopy3_package():
236236
if di.name == k:
237237
assert np.all(np.equal(di.array, v))
238238

239-
bpth = Path("output/flopy3_package/flopy3_package")
240-
Path("output/flopy3_package").mkdir(parents=True, exist_ok=True)
239+
bpth = Path(tmp_path) / "flopy3_package" / "flopy3_package"
240+
(Path(tmp_path) / "flopy3_package").mkdir(parents=True, exist_ok=True)
241241
dis3.plot(filename_base=bpth)
242242

243243

244-
def norun_test_flopy3_cbd_small():
244+
def norun_test_flopy3_cbd_small(tmp_path):
245245
import sys
246246

247247
sys.path.append("/home/mjreno/.clone/usgs/flopy/autotest")
@@ -262,13 +262,13 @@ def norun_test_flopy3_cbd_small():
262262
dis=dis,
263263
dims=dims,
264264
)
265-
bpth = Path("output/flopy3_cbd_small/flopy3_cbd_small")
266-
Path("output/flopy3_cbd_small").mkdir(parents=True, exist_ok=True)
265+
bpth = Path(tmp_path) / "flopy3_cbd_small" / "flopy3_cbd_small"
266+
(Path(tmp_path) / "flopy3_cbd_small").mkdir(parents=True, exist_ok=True)
267267
gwf3 = Flopy3Model(model=gwf, modelgrid=cbd_small, modeltime=time)
268268
gwf3.plot(filename_base=bpth)
269269

270270

271-
def test_flopy3_grid2():
271+
def test_flopy3_grid2(tmp_path):
272272
lx = 5.0
273273
lz = 1.0
274274
nlay = 1
@@ -303,18 +303,18 @@ def test_flopy3_grid2():
303303
dis=dis,
304304
dims=dims,
305305
)
306-
bpth = Path("output/flopy3_grid2/flopy3_grid2")
307-
Path("output/flopy3_grid2").mkdir(parents=True, exist_ok=True)
306+
bpth = Path(tmp_path) / "flopy3_grid2" / "flopy3_grid2"
307+
(Path(tmp_path) / "flopy3_grid2").mkdir(parents=True, exist_ok=True)
308308
gwf3 = Flopy3Model(model=gwf, modeltime=time)
309309
gwf3.plot(filename_base=bpth)
310310

311311

312-
def test_flopy3_export():
312+
def test_flopy3_export(tmp_path):
313313
# see flopy test_export.py test_export_output()
314-
Path("output/flopy3_model/shape").mkdir(parents=True, exist_ok=True)
315-
Path("output/flopy3_package/shape").mkdir(parents=True, exist_ok=True)
316-
Path("output/flopy3_model/netcdf").mkdir(parents=True, exist_ok=True)
317-
Path("output/flopy3_package/netcdf").mkdir(parents=True, exist_ok=True)
314+
(Path(tmp_path) / "flopy3_model" / "shape").mkdir(parents=True, exist_ok=True)
315+
(Path(tmp_path) / "flopy3_package/shape").mkdir(parents=True, exist_ok=True)
316+
(Path(tmp_path) / "flopy3_model/netcdf").mkdir(parents=True, exist_ok=True)
317+
(Path(tmp_path) / "flopy3_package/netcdf").mkdir(parents=True, exist_ok=True)
318318

319319
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
320320

@@ -323,20 +323,20 @@ def test_flopy3_export():
323323
dis3 = Flopy3Package(gwf.dis, model=gwf3)
324324

325325
# model shapefile export
326-
shp_mpth = Path("output/flopy3_model/shape/flopy3_model.shp")
326+
shp_mpth = Path(tmp_path / "flopy3_model" / "shape" / "flopy3_model.shp")
327327
gwf3.export(f=shp_mpth)
328328

329329
# package shapefile export
330-
shp_ppth = Path("output/flopy3_package/shape/flopy3_package.shp")
330+
shp_ppth = Path(tmp_path / "flopy3_package" / "shape" / "flopy3_package.shp")
331331
dis3.export(f=shp_ppth)
332332

333333
# model netcdf export
334-
nc_mpth = Path("output/flopy3_model/netcdf/flopy3_model.nc")
334+
nc_mpth = Path(tmp_path / "flopy3_model" / "netcdf" / "flopy3_model.nc")
335335
# TODO: needs flopy3 fix
336336
# gwf3.export(f=nc_mpth)
337337

338338
# package netcdf export
339-
nc_ppth = Path("output/flopy3_package/netcdf/flopy3_package.nc")
339+
nc_ppth = Path(tmp_path / "flopy3_package" / "netcdf" / "flopy3_package.nc")
340340
# TODO: needs flopy3 fix
341341
# dis3.export(f=nc_ppth)
342342

0 commit comments

Comments
 (0)