diff --git a/docs/examples/quickstart.py b/docs/examples/quickstart.py index 3277c9c4..5740a481 100644 --- a/docs/examples/quickstart.py +++ b/docs/examples/quickstart.py @@ -5,23 +5,32 @@ from flopy.discretization.modeltime import ModelTime from flopy.discretization.structuredgrid import StructuredGrid -from flopy4.mf6.gwf import Chd, Gwf, Npf, Oc +from flopy4.mf6.gwf import Chd, Gwf, Ic, Npf, Oc from flopy4.mf6.ims import Ims from flopy4.mf6.simulation import Simulation name = "quickstart" workspace = Path(__file__).parent / name time = ModelTime(perlen=[1.0], nstp=[1]) -grid = StructuredGrid(nlay=1, nrow=10, ncol=10) +grid = StructuredGrid( + nlay=1, + nrow=10, + ncol=10, + delr=1.0 * np.ones(10), + delc=1.0 * np.ones(10), + top=1.0 * np.ones((10, 10)), + botm=0.0 * np.ones((1, 10, 10)), +) sim = Simulation(name=name, workspace=workspace, tdis=time) -ims = Ims(parent=sim) gwf_name = "mymodel" +ims = Ims(parent=sim, models=[gwf_name]) # temporary hack gwf = Gwf(parent=sim, name=gwf_name, save_flows=True, dis=grid) npf = Npf(parent=gwf, save_specific_discharge=True) chd = Chd( parent=gwf, head={0: {(0, 0, 0): 1.0, (0, 9, 9): 0.0}}, ) +ic = Ic(parent=gwf, strt=1.0) oc = Oc( parent=gwf, budget_file=f"{gwf.name}.bud", @@ -30,18 +39,15 @@ save_budget={0: "all"}, ) -# sim.write() +sim.write() sim.run(verbose=True) -# check CHD assert chd.data["head"][0, 0] == 1.0 assert chd.data.head.sel(per=0)[99] == 0.0 assert np.allclose(chd.data.head[:, 1:99], np.full(98, 1e30)) -# check DIS assert gwf.dis.data.botm.sel(lay=0, col=0, row=0) == 0.0 -# check OC assert oc.data["save_head"][0] == "all" assert oc.data.save_head.sel(per=0) == "all" diff --git a/flopy4/mf6/__init__.py b/flopy4/mf6/__init__.py index 760e61e5..dd4f15b1 100644 --- a/flopy4/mf6/__init__.py +++ b/flopy4/mf6/__init__.py @@ -1,8 +1,10 @@ -from flopy4.mf6.codec import dump, load +from flopy4.mf6.codec import dump from flopy4.mf6.component import Component +from flopy4.mf6.converter import COMPONENT_CONVERTER from flopy4.uio import DEFAULT_REGISTRY # register io methods -# TODO: call this "mf6" or something? since it might include binary files -DEFAULT_REGISTRY.register_loader(Component, "ascii", lambda c: load(c.path)) -DEFAULT_REGISTRY.register_writer(Component, "ascii", lambda c: dump(c, c.path)) +# TODO: call format "mf6" or something? since it might include binary files +DEFAULT_REGISTRY.register_writer( + Component, "ascii", lambda c: dump(COMPONENT_CONVERTER.unstructure(c), c.path) +) diff --git a/flopy4/mf6/codec/writer/__init__.py b/flopy4/mf6/codec/writer/__init__.py index 616c077d..f6af3ba4 100644 --- a/flopy4/mf6/codec/writer/__init__.py +++ b/flopy4/mf6/codec/writer/__init__.py @@ -4,7 +4,7 @@ import numpy as np from jinja2 import Environment, PackageLoader -from flopy4.mf6 import filters +from flopy4.mf6.codec.writer import filters _JINJA_ENV = Environment( loader=PackageLoader("flopy4.mf6.codec.writer"), diff --git a/flopy4/mf6/filters.py b/flopy4/mf6/codec/writer/filters.py similarity index 100% rename from flopy4/mf6/filters.py rename to flopy4/mf6/codec/writer/filters.py diff --git a/flopy4/mf6/context.py b/flopy4/mf6/context.py index 8320b77c..7e98598d 100644 --- a/flopy4/mf6/context.py +++ b/flopy4/mf6/context.py @@ -1,6 +1,7 @@ from abc import ABC from pathlib import Path +from modflow_devtools.misc import cd from xattree import xattree from flopy4.mf6.component import Component @@ -14,8 +15,20 @@ class Context(Component, ABC): def __attrs_post_init__(self): super().__attrs_post_init__() if self.workspace is None: - self.workspace = Path.cwd() + self.workspace = ( + self.parent.workspace + if self.parent and hasattr(self.parent, "workspace") + else Path.cwd() + ) @property def path(self) -> Path: return self.workspace / self.filename + + def load(self, format="ascii"): + with cd(self.workspace): + super().load(format=format) + + def write(self, format="ascii"): + with cd(self.workspace): + super().write(format=format) diff --git a/flopy4/mf6/converter.py b/flopy4/mf6/converter.py index a1808d53..4f566aee 100644 --- a/flopy4/mf6/converter.py +++ b/flopy4/mf6/converter.py @@ -1,15 +1,64 @@ +from collections.abc import MutableMapping from datetime import datetime from pathlib import Path from typing import Any import xarray as xr import xattree +from attrs import define from cattrs import Converter from flopy4.mf6.component import Component +from flopy4.mf6.context import Context +from flopy4.mf6.exchange import Exchange +from flopy4.mf6.model import Model +from flopy4.mf6.package import Package +from flopy4.mf6.solution import Solution from flopy4.mf6.spec import fields_dict, get_blocks +@define +class _Binding: + """ + An MF6 component binding: a record representation of the + component for writing to a parent component's name file. + """ + + type: str + fname: str + terms: tuple[str, ...] | None = None + + def to_tuple(self): + if self.terms and any(self.terms): + return (self.type, self.fname, *self.terms) + else: + return (self.type, self.fname) + + @classmethod + def from_component(cls, component: Component) -> "_Binding": + def _get_binding_type(component: Component) -> str: + cls_name = component.__class__.__name__ + if isinstance(component, Exchange): + return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6" + else: + return f"{cls_name.upper()}6" + + def _get_binding_terms(component: Component) -> tuple[str, ...] | None: + if isinstance(component, Exchange): + return (component.exgmnamea, component.exgmnameb) # type: ignore + elif isinstance(component, Solution): + return tuple(component.models) + elif isinstance(component, (Model, Package)): + return (component.name,) # type: ignore + return None + + return cls( + type=_get_binding_type(component), + fname=component.filename or component.default_filename(), + terms=_get_binding_terms(component), + ) + + def _attach_field_metadata( dataset: xr.Dataset, component_type: type, field_names: list[str] ) -> None: @@ -29,14 +78,56 @@ def _path_to_record(field_name: str, path_value: Path) -> tuple: def unstructure_component(value: Component) -> dict[str, Any]: - data = xattree.asdict(value) blockspec = get_blocks(value.dfn) blocks: dict[str, dict[str, Any]] = {} + xatspec = xattree.get_xatspec(type(value)) + + # Handle child component bindings before converting to dict + if isinstance(value, Context): + for field_name, child_spec in xatspec.children.items(): + if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore + block_name = child_spec.metadata["block"] # type: ignore + field_value = getattr(value, field_name, None) + + if block_name not in blocks: + blocks[block_name] = {} + + if isinstance(field_value, Component): + components = [_Binding.from_component(field_value).to_tuple()] + elif isinstance(field_value, MutableMapping): + components = [ + _Binding.from_component(comp).to_tuple() + for comp in field_value.values() + if comp is not None + ] + elif isinstance(field_value, (list, tuple)): + components = [ + _Binding.from_component(comp).to_tuple() + for comp in field_value + if comp is not None + ] + else: + continue + + if components: + blocks[block_name][field_name] = components + + data = xattree.asdict(value) + for block_name, block in blockspec.items(): - blocks[block_name] = {} + if block_name not in blocks: + blocks[block_name] = {} period_data = {} period_blocks = {} # type: ignore + for field_name in block.keys(): + # Skip child components that have been processed as bindings + if isinstance(value, Context) and field_name in xatspec.children: + child_spec = xatspec.children[field_name] + if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore + if child_spec.metadata["block"] == block_name: # type: ignore + continue + field_value = data[field_name] # convert: # - paths to records @@ -89,7 +180,18 @@ def unstructure_component(value: Component) -> dict[str, Any]: _attach_field_metadata(dataset, type(value), list(block.keys())) blocks[f"{block_name} {kper + 1}"] = {block_name: dataset} - return {name: block for name, block in blocks.items() if block} + # make sure options block always comes first + if "options" in blocks: + options_block = blocks.pop("options") + blocks = {"options": options_block, **blocks} + + # total temporary hack! manually set solutiongroup 1. still need to support multiple.. + if "solutiongroup" in blocks: + sg = blocks["solutiongroup"] + blocks["solutiongroup 1"] = sg + del blocks["solutiongroup"] + + return {name: block for name, block in blocks.items() if name != "period"} def _make_converter() -> Converter: diff --git a/flopy4/mf6/gwf/__init__.py b/flopy4/mf6/gwf/__init__.py index fa0697c2..70a64443 100644 --- a/flopy4/mf6/gwf/__init__.py +++ b/flopy4/mf6/gwf/__init__.py @@ -31,6 +31,11 @@ def convert_grid(value): @xattree class Gwf(Model): + @define + class NewtonOptions: + newton: bool = field() + under_relaxation: bool = field() + @define class Output: parent: "Gwf" = attrs.field(repr=False) @@ -51,23 +56,7 @@ def budget(self): self.parent.parent.workspace / f"{self.parent.name}.dis.grb", ) - dis: Dis = field(converter=convert_grid) - ic: Ic = field() - oc: Oc = field() - npf: Npf = field() - chd: list[Chd] = field() - wel: list[Wel] = field() - drn: list[Drn] = field() - output: Output = attrs.field( - default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True) - ) - - @define - class NewtonOptions: - newton: bool = field() - under_relaxation: bool = field() - - list: Optional[str] = field(block="options", default=None) + _list: Optional[str] = field(block="options", default=None) print_input: bool = field(block="options", default=False) print_flows: bool = field(block="options", default=False) save_flows: bool = field(block="options", default=False) @@ -75,6 +64,16 @@ class NewtonOptions: nc_mesh2d_filerecord: Optional[Path] = field(block="options", default=None) nc_structured_filerecord: Optional[Path] = field(block="options", default=None) nc_filerecord: Optional[Path] = field(block="options", default=None) + dis: Dis = field(converter=convert_grid, block="packages") + ic: Ic = field(block="packages") + oc: Oc = field(block="packages") + npf: Npf = field(block="packages") + chd: list[Chd] = field(block="packages") + wel: list[Wel] = field(block="packages") + drn: list[Drn] = field(block="packages") + output: Output = attrs.field( + default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True) + ) @property def grid(self) -> Grid: diff --git a/flopy4/mf6/gwf/ic.py b/flopy4/mf6/gwf/ic.py index 18cdb7c2..44a0f292 100644 --- a/flopy4/mf6/gwf/ic.py +++ b/flopy4/mf6/gwf/ic.py @@ -8,13 +8,13 @@ from flopy4.mf6.spec import array, field -@xattree +@xattree(kw_only=True) class Ic(Package): + export_array_ascii: bool = field(block="options", default=False) + export_array_netcdf: bool = field(block="options", default=False) strt: NDArray[np.float64] = array( - block="packagedata", + block="griddata", dims=("nnodes",), default=1.0, converter=Converter(dict_to_array, takes_self=True, takes_field=True), ) - export_array_ascii: bool = field(block="options", default=False) - export_array_netcdf: bool = field(block="options", default=False) diff --git a/flopy4/mf6/ims.py b/flopy4/mf6/ims.py index 0066d321..98392f19 100644 --- a/flopy4/mf6/ims.py +++ b/flopy4/mf6/ims.py @@ -12,7 +12,7 @@ class Ims(Solution): solution_package: ClassVar[Sln] = Sln(abbr="ims", pattern="*") - print_option: bool = field(block="options", default=False) + print_option: Optional[str] = field(block="options", default=None) complexity: str = field(block="options", default="simple") csv_outer_output_file: Optional[Path] = field(default=None, block="options") csv_inner_output_file: Optional[Path] = field(block="options", default=None) diff --git a/flopy4/mf6/model.py b/flopy4/mf6/model.py index ca0367c7..a3ebd97e 100644 --- a/flopy4/mf6/model.py +++ b/flopy4/mf6/model.py @@ -2,10 +2,10 @@ from xattree import xattree -from flopy4.mf6.component import Component +from flopy4.mf6.context import Context @xattree -class Model(Component, ABC): +class Model(Context, ABC): def default_filename(self) -> str: return f"{self.name}.nam" # type: ignore diff --git a/flopy4/mf6/simulation.py b/flopy4/mf6/simulation.py index 46484f01..8ee73d6e 100644 --- a/flopy4/mf6/simulation.py +++ b/flopy4/mf6/simulation.py @@ -23,10 +23,10 @@ def convert_time(value): @xattree class Simulation(Context): - models: dict[str, Model] = field() - exchanges: dict[str, Exchange] = field() - solutions: dict[str, Solution] = field() - tdis: Tdis = field(converter=convert_time) + tdis: Tdis = field(converter=convert_time, block="timing") + models: dict[str, Model] = field(block="models") + exchanges: dict[str, Exchange] = field(block="exchanges") + solutions: dict[str, Solution] = field(block="solutiongroup") filename: str = field(default="mfsim.nam", init=False) def __attrs_post_init__(self): @@ -52,13 +52,3 @@ def run(self, exe: str | PathLike = "mf6", verbose: bool = False) -> None: f"Simulation {self.name}: {exe} failed with " # type: ignore f"return code {ret}, output:\n\n{out + err} " ) - - def load(self, format="ascii"): - """Load the simulation.""" - with cd(self.workspace): - super().load(format=format) - - def write(self, format="ascii"): - """Write the simulation.""" - with cd(self.workspace): - super().write(format=format) diff --git a/flopy4/mf6/solution.py b/flopy4/mf6/solution.py index dfcafd73..d3b63b17 100644 --- a/flopy4/mf6/solution.py +++ b/flopy4/mf6/solution.py @@ -1,5 +1,6 @@ from abc import ABC +import attrs from xattree import xattree from flopy4.mf6.package import Package @@ -7,4 +8,4 @@ @xattree class Solution(Package, ABC): - pass + models: list[str] = attrs.field(default=attrs.Factory(list)) diff --git a/flopy4/mf6/spec.py b/flopy4/mf6/spec.py index 4153b9e2..8f0d566a 100644 --- a/flopy4/mf6/spec.py +++ b/flopy4/mf6/spec.py @@ -233,7 +233,7 @@ def get_dfn_field_type(attribute: Attribute) -> FieldType: kind = xatmeta["kind"] match kind: case "child": - raise ValueError(f"Top-level field should not be a child: {attribute.name}") + return "recarray" # Child components become tabular bindings case "array": return "recarray" case "coord": diff --git a/test/test_codec.py b/test/test_codec.py index 8d24a6e3..142c58e2 100644 --- a/test/test_codec.py +++ b/test/test_codec.py @@ -264,3 +264,83 @@ def test_dumps_wel_with_auxiliary(): assert "45 -25.0 2.0" in result # (1,3,4) -> node 45, q=-25.0, aux=2.0 assert "1e+30" not in result assert "1.0e+30" not in result + + +def test_dumps_gwf(): + from flopy4.mf6.gwf import Chd, Dis, Gwf, Ic, Npf, Oc + + dis = Dis(nlay=1, nrow=10, ncol=10, delr=100.0, delc=100.0) + gwf = Gwf(name="test_model", dis=dis) + ic = Ic(parent=gwf, strt=1.0) + npf = Npf(parent=gwf, k=1.0) + oc = Oc(parent=gwf, head_file="test.hds", budget_file="test.bud", dims={"nper": 1}) + chd = Chd(parent=gwf, head={0: {(0, 0, 0): 10.0}}, dims={"nper": 1}) + + gwf = Gwf( + name="test_model", + dis=dis, + ic=ic, + npf=npf, + oc=oc, + chd=[chd], + ) + + result = dumps(COMPONENT_CONVERTER.unstructure(gwf)) + print("GWF model result:") + print(result) + + # Check that child component bindings are included + assert "DIS6" in result + assert "IC6" in result + assert "NPF6" in result + assert "OC6" in result + assert "test_model.dis" in result + assert "test_model.ic" in result + assert "test_model.npf" in result + assert "test_model.oc" in result + + +def test_dumps_simulation(): + from flopy.discretization.modeltime import ModelTime + + from flopy4.mf6.gwf import Dis, Gwf, Ic, Npf, Oc + from flopy4.mf6.simulation import Simulation + from flopy4.mf6.tdis import Tdis + + # Create model components + dis = Dis(nlay=1, nrow=5, ncol=5, delr=100.0, delc=100.0) + gwf = Gwf(name="model1", dis=dis) + ic = Ic(parent=gwf, strt=1.0) + npf = Npf(parent=gwf, k=1.0) + oc = Oc(parent=gwf, head_file="model1.hds", budget_file="model1.bud", dims={"nper": 1}) + + # Create model + gwf = Gwf( + name="model1", + dis=dis, + ic=ic, + npf=npf, + oc=oc, + ) + + # Create time discretization + time = ModelTime(perlen=[1.0], nstp=[1]) + tdis = Tdis.from_time(time) + + # Create simulation + sim = Simulation( + name="test_sim", + models={"model1": gwf}, + exchanges={}, + solutions={}, + tdis=tdis, + ) + + result = dumps(COMPONENT_CONVERTER.unstructure(sim)) + print("Simulation result:") + print(result) + + # Check that model bindings are included + assert "GWF6" in result + assert "model1" in result + assert "TDIS6" in result