Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,18 @@ class Component(ABC, MutableMapping):

@property
def path(self) -> Path:
"""Get the path to the component's input file."""
return Path.cwd() / self.filename

def _default_filename(self) -> str:
def default_filename(self) -> str:
"""
Generate a default filename for the component.
By default, this is the component's name then
the class name in lowercase, separated by dot.

Override this method in subclasses to provide
a custom default filename.
"""
name = self.name # type: ignore
cls_name = self.__class__.__name__.lower()
return f"{name}.{cls_name}"
Expand All @@ -50,10 +59,6 @@ def __attrs_init_subclass__(cls):
COMPONENTS[cls.__name__.lower()] = cls
cls.dfn = cls.get_dfn()

def __attrs_post_init__(self):
if not self.filename:
self.filename = self._default_filename()

def __getitem__(self, key):
return self.children[key] # type: ignore

Expand Down Expand Up @@ -89,12 +94,21 @@ def get_dfn(cls) -> Dfn:
**blocks,
)

def _preio(self, format: str) -> None:
"""Place for any pre-IO setup"""
if not self.filename:
self.filename = self.default_filename()

def load(self, format: str) -> None:
"""Load the component from an input file."""
self._preio(format=format)
self._load(format=format)
for child in self.children.values(): # type: ignore
child.load(format)
child.load(format=format)

def write(self, format: str) -> None:
"""Write the component to an input file."""
self._preio(format=format)
self._write(format=format)
for child in self.children.values(): # type: ignore
child.write(format)
child.write(format=format)
1 change: 1 addition & 0 deletions flopy4/mf6/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class Context(Component, ABC):
workspace: Path = field(default=None)

def __attrs_post_init__(self):
super().__attrs_post_init__()
if self.workspace is None:
self.workspace = Path.cwd()

Expand Down
3 changes: 2 additions & 1 deletion flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def array_chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = No
}
value = value.chunk(chunks)
for chunk in value.data.blocks:
yield chunk.compute()
yield np.squeeze(chunk.compute())


def array2string(value: NDArray) -> str:
Expand All @@ -112,6 +112,7 @@ def array2string(value: NDArray) -> str:
if value.ndim == 1:
# add an axis to 1d arrays so np.savetxt writes elements on 1 line
value = value[None]
value = np.atleast_1d(value)
format = (
"%d"
if np.issubdtype(value.dtype, np.integer)
Expand Down
1 change: 1 addition & 0 deletions flopy4/mf6/gwf/dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class Dis(Package):

def __attrs_post_init__(self):
self.nnodes = self.ncol * self.nrow * self.nlay
super().__attrs_post_init__()

def to_grid(self) -> StructuredGrid:
"""
Expand Down
3 changes: 2 additions & 1 deletion flopy4/mf6/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@

@xattree
class Model(Component, ABC):
pass
def default_filename(self) -> str:
return f"{self.name}.nam" # type: ignore
5 changes: 4 additions & 1 deletion flopy4/mf6/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@

@xattree
class Package(Component, ABC):
pass
def default_filename(self) -> str:
name = self.parent.name if self.parent else self.name # type: ignore
cls_name = self.__class__.__name__.lower()
return f"{name}.{cls_name}"
1 change: 1 addition & 0 deletions flopy4/mf6/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Simulation(Context):
filename: str = field(default="mfsim.nam", init=False)

def __attrs_post_init__(self):
super().__attrs_post_init__()
if self.filename != "mfsim.nam":
warn(
"Simulation filename must be 'mfsim.nam'.",
Expand Down
2 changes: 1 addition & 1 deletion flopy4/mf6/templates/blocks.jinja
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% import 'macros.jinja' as macros with context %}
{% for block_name, block_ in (dfn|dict_blocks).items() %}
BEGIN {{ block_name.upper() }}
{% for field in block_.values() -%}
{% for field in block_.values() if (field|field_value) is not none -%}
{{ macros.field(field) }}
{%- endfor %}
END {{ block_name.upper() }}
Expand Down
2 changes: 1 addition & 1 deletion flopy4/mf6/templates/macros.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ this macro receives the block definition. from that
it looks up the value of the one variable with the
same name as the block, which custom converter has
made sure exists in a sparse dict representation of
an array. we need to spin this out into a block for
an array. we need to expand this into a block for
each stress period.
#}
{% set dict = data[block_name] %}
Expand Down
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path

pytest_plugins = ["modflow_devtools.fixtures"]

PROJ_ROOT_PATH = Path(__file__).parents[1]
DOCS_PATH = PROJ_ROOT_PATH / "docs"
EXAMPLES_PATH = DOCS_PATH / "examples"
Expand Down
3 changes: 0 additions & 3 deletions test/test_codec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from flopy4.mf6.codec import dumps


Expand Down Expand Up @@ -37,7 +35,6 @@ def test_dumps_oc():
assert result


@pytest.mark.xfail(reason="TODO 3D arrays")
def test_dumps_dis():
from flopy4.mf6.gwf import Dis

Expand Down
22 changes: 19 additions & 3 deletions test/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,27 @@ def test_ims_dfn():
assert "inner_maximum" in set(dfn["linear"].keys())


def test_write_ascii(tmp_path):
def test_write_ascii(function_tmpdir):
sim_name = "sim"
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
sim = Simulation(tdis=time, workspace=tmp_path)
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
sim = Simulation(tdis=time, workspace=function_tmpdir, name=sim_name)
gwf_name = "gwf"
gwf = Gwf(parent=sim, dis=grid, name=gwf_name)
ic = Ic(parent=gwf)
oc = Oc(parent=gwf)
npf = Npf(parent=gwf)
chd = Chd(parent=gwf, head={"*": {(0, 0, 0): 1.0, (0, 9, 9): 0.0}})

sim.write()

files = list(Path(tmp_path).glob("*"))
files = list(Path(function_tmpdir).glob("*"))
file_names = [f.name for f in files]
assert "mfsim.nam" in file_names
assert f"{sim_name}.tdis" in file_names
assert f"{gwf_name}.nam" in file_names
assert f"{gwf_name}.dis" in file_names
assert f"{gwf_name}.ic" in file_names
assert f"{gwf_name}.oc" in file_names
assert f"{gwf_name}.npf" in file_names
assert f"{gwf_name}.chd" in file_names
Loading