Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions flopy4/mf6/codec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any

import numpy as np
import xattree
from cattrs import Converter
from jinja2 import Environment, PackageLoader

from flopy4.mf6 import filters
Expand All @@ -29,13 +31,13 @@
"threshold": sys.maxsize,
}

_CONVERTER = Converter()
_CONVERTER.register_unstructure_hook_factory(
lambda cls: xattree.has(cls), lambda cls: xattree.asdict
)

def unstructure(data):
# TODO unstructure arrays into sparse dicts
# TODO combine OC fields into list input as defined in the MF6 dfn
# TODO return a dictionary instead of the component itself, then
# update filters to use dictinoary access instead of getattr()
return data
# TODO unstructure arrays into sparse dicts
# TODO combine OC fields into list input as defined in the MF6 dfn


def loads(data: str) -> Any:
Expand All @@ -51,12 +53,12 @@ def load(path: str | PathLike) -> Any:
def dumps(data) -> str:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
return template.render(dfn=type(data).dfn, data=unstructure(data))
return template.render(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))


def dump(data, path: str | PathLike) -> None:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
iterator = template.generate(dfn=type(data).dfn, data=unstructure(data))
iterator = template.generate(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
f.writelines(iterator)

Expand Down
2 changes: 1 addition & 1 deletion flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def field_type(field: Field) -> str:
@pass_context
def field_value(ctx, field: Field):
"""Get a field's value via the template context."""
return getattr(ctx["data"], field["name"])
return ctx["data"][field["name"]]


def array_how(value: xr.DataArray) -> str:
Expand Down
16 changes: 8 additions & 8 deletions flopy4/mf6/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ def time(self) -> ModelTime:

def run(self, exe: str | PathLike = "mf6", verbose: bool = False) -> None:
"""Run the simulation using the given executable."""
if self.workspace is None:
raise ValueError(f"Simulation {self.name} has no workspace path.")
with cd(self.workspace):
stdout, stderr, retcode = run_cmd(exe, verbose=verbose)
if retcode != 0:
out, err, ret = run_cmd(exe, verbose=verbose)
if ret != 0:
raise RuntimeError(
f"Simulation {self.name}: {exe} failed to run with returncode " # type: ignore
f"{retcode}, and error message:\n\n{stdout + stderr} "
f"Simulation {self.name}: {exe} failed with " # type: ignore
f"return code {ret}, output:\n\n{out + err} "
)

def load(self, format="ascii"):
"""Load the simulation in the specified format."""
super().load(format)
with cd(self.workspace):
super().load(format)

def write(self, format="ascii"):
"""Write the simulation in the specified format."""
super().write(format)
with cd(self.workspace):
super().write(format)
9 changes: 5 additions & 4 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from modflow_devtools.markers import requires_exe
from modflow_devtools.misc import run_cmd
from modflow_devtools.misc import cd, run_cmd


@pytest.mark.slow
Expand All @@ -16,7 +16,7 @@ def test_scripts(example_script):

@pytest.mark.slow
@requires_exe("jupytext")
def test_notebooks(example_script):
def test_notebooks(example_script, tmp_path):
args = [
"jupytext",
"--from",
Expand All @@ -26,5 +26,6 @@ def test_notebooks(example_script):
"--execute",
example_script,
]
stdout, stderr, retcode = run_cmd(*args, verbose=True)
assert not retcode, stdout + stderr
with cd(tmp_path):
out, err, ret = run_cmd(*args, verbose=True)
assert not ret, out + err
42 changes: 21 additions & 21 deletions test/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def quickstart_model():
return gwf


def test_flopy3_model():
def test_flopy3_model(tmp_path):
from flopy.mbase import ModelInterface
from flopy.pakbase import PackageInterface

Expand Down Expand Up @@ -114,12 +114,12 @@ def test_flopy3_model():
print(f"data_type: {d.data_type}")
print(f"array: {d.array}\n")

bpth = Path("output/flopy3_model/flopy3_model")
Path("output/flopy3_model").mkdir(parents=True, exist_ok=True)
bpth = Path(tmp_path) / "flopy3_model" / "flopy3_model"
(Path(tmp_path) / "flopy3_model").mkdir(parents=True, exist_ok=True)
gwf3.plot(filename_base=bpth)


def test_flopy3_package():
def test_flopy3_package(tmp_path):
from flopy.mbase import ModelInterface
from flopy.pakbase import PackageInterface

Expand Down Expand Up @@ -236,12 +236,12 @@ def test_flopy3_package():
if di.name == k:
assert np.all(np.equal(di.array, v))

bpth = Path("output/flopy3_package/flopy3_package")
Path("output/flopy3_package").mkdir(parents=True, exist_ok=True)
bpth = Path(tmp_path) / "flopy3_package" / "flopy3_package"
(Path(tmp_path) / "flopy3_package").mkdir(parents=True, exist_ok=True)
dis3.plot(filename_base=bpth)


def norun_test_flopy3_cbd_small():
def norun_test_flopy3_cbd_small(tmp_path):
import sys

sys.path.append("/home/mjreno/.clone/usgs/flopy/autotest")
Expand All @@ -262,13 +262,13 @@ def norun_test_flopy3_cbd_small():
dis=dis,
dims=dims,
)
bpth = Path("output/flopy3_cbd_small/flopy3_cbd_small")
Path("output/flopy3_cbd_small").mkdir(parents=True, exist_ok=True)
bpth = Path(tmp_path) / "flopy3_cbd_small" / "flopy3_cbd_small"
(Path(tmp_path) / "flopy3_cbd_small").mkdir(parents=True, exist_ok=True)
gwf3 = Flopy3Model(model=gwf, modelgrid=cbd_small, modeltime=time)
gwf3.plot(filename_base=bpth)


def test_flopy3_grid2():
def test_flopy3_grid2(tmp_path):
lx = 5.0
lz = 1.0
nlay = 1
Expand Down Expand Up @@ -303,18 +303,18 @@ def test_flopy3_grid2():
dis=dis,
dims=dims,
)
bpth = Path("output/flopy3_grid2/flopy3_grid2")
Path("output/flopy3_grid2").mkdir(parents=True, exist_ok=True)
bpth = Path(tmp_path) / "flopy3_grid2" / "flopy3_grid2"
(Path(tmp_path) / "flopy3_grid2").mkdir(parents=True, exist_ok=True)
gwf3 = Flopy3Model(model=gwf, modeltime=time)
gwf3.plot(filename_base=bpth)


def test_flopy3_export():
def test_flopy3_export(tmp_path):
# see flopy test_export.py test_export_output()
Path("output/flopy3_model/shape").mkdir(parents=True, exist_ok=True)
Path("output/flopy3_package/shape").mkdir(parents=True, exist_ok=True)
Path("output/flopy3_model/netcdf").mkdir(parents=True, exist_ok=True)
Path("output/flopy3_package/netcdf").mkdir(parents=True, exist_ok=True)
(Path(tmp_path) / "flopy3_model" / "shape").mkdir(parents=True, exist_ok=True)
(Path(tmp_path) / "flopy3_package/shape").mkdir(parents=True, exist_ok=True)
(Path(tmp_path) / "flopy3_model/netcdf").mkdir(parents=True, exist_ok=True)
(Path(tmp_path) / "flopy3_package/netcdf").mkdir(parents=True, exist_ok=True)

time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])

Expand All @@ -323,20 +323,20 @@ def test_flopy3_export():
dis3 = Flopy3Package(gwf.dis, model=gwf3)

# model shapefile export
shp_mpth = Path("output/flopy3_model/shape/flopy3_model.shp")
shp_mpth = Path(tmp_path / "flopy3_model" / "shape" / "flopy3_model.shp")
gwf3.export(f=shp_mpth)

# package shapefile export
shp_ppth = Path("output/flopy3_package/shape/flopy3_package.shp")
shp_ppth = Path(tmp_path / "flopy3_package" / "shape" / "flopy3_package.shp")
dis3.export(f=shp_ppth)

# model netcdf export
nc_mpth = Path("output/flopy3_model/netcdf/flopy3_model.nc")
nc_mpth = Path(tmp_path / "flopy3_model" / "netcdf" / "flopy3_model.nc")
# TODO: needs flopy3 fix
# gwf3.export(f=nc_mpth)

# package netcdf export
nc_ppth = Path("output/flopy3_package/netcdf/flopy3_package.nc")
nc_ppth = Path(tmp_path / "flopy3_package" / "netcdf" / "flopy3_package.nc")
# TODO: needs flopy3 fix
# dis3.export(f=nc_ppth)

Expand Down
Loading