Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions docs/examples/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,32 @@
from flopy.discretization.modeltime import ModelTime
from flopy.discretization.structuredgrid import StructuredGrid

from flopy4.mf6.gwf import Chd, Gwf, Npf, Oc
from flopy4.mf6.gwf import Chd, Gwf, Ic, Npf, Oc
from flopy4.mf6.ims import Ims
from flopy4.mf6.simulation import Simulation

name = "quickstart"
workspace = Path(__file__).parent / name
time = ModelTime(perlen=[1.0], nstp=[1])
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
grid = StructuredGrid(
nlay=1,
nrow=10,
ncol=10,
delr=1.0 * np.ones(10),
delc=1.0 * np.ones(10),
top=1.0 * np.ones((10, 10)),
botm=0.0 * np.ones((1, 10, 10)),
)
sim = Simulation(name=name, workspace=workspace, tdis=time)
ims = Ims(parent=sim)
gwf_name = "mymodel"
ims = Ims(parent=sim, models=[gwf_name]) # temporary hack
gwf = Gwf(parent=sim, name=gwf_name, save_flows=True, dis=grid)
npf = Npf(parent=gwf, save_specific_discharge=True)
chd = Chd(
parent=gwf,
head={0: {(0, 0, 0): 1.0, (0, 9, 9): 0.0}},
)
ic = Ic(parent=gwf, strt=1.0)
oc = Oc(
parent=gwf,
budget_file=f"{gwf.name}.bud",
Expand All @@ -30,18 +39,15 @@
save_budget={0: "all"},
)

# sim.write()
sim.write()
sim.run(verbose=True)

# check CHD
assert chd.data["head"][0, 0] == 1.0
assert chd.data.head.sel(per=0)[99] == 0.0
assert np.allclose(chd.data.head[:, 1:99], np.full(98, 1e30))

# check DIS
assert gwf.dis.data.botm.sel(lay=0, col=0, row=0) == 0.0

# check OC
assert oc.data["save_head"][0] == "all"
assert oc.data.save_head.sel(per=0) == "all"

Expand Down
10 changes: 6 additions & 4 deletions flopy4/mf6/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from flopy4.mf6.codec import dump, load
from flopy4.mf6.codec import dump
from flopy4.mf6.component import Component
from flopy4.mf6.converter import COMPONENT_CONVERTER
from flopy4.uio import DEFAULT_REGISTRY

# register io methods
# TODO: call this "mf6" or something? since it might include binary files
DEFAULT_REGISTRY.register_loader(Component, "ascii", lambda c: load(c.path))
DEFAULT_REGISTRY.register_writer(Component, "ascii", lambda c: dump(c, c.path))
# TODO: call format "mf6" or something? since it might include binary files
DEFAULT_REGISTRY.register_writer(
Component, "ascii", lambda c: dump(COMPONENT_CONVERTER.unstructure(c), c.path)
)
2 changes: 1 addition & 1 deletion flopy4/mf6/codec/writer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from jinja2 import Environment, PackageLoader

from flopy4.mf6 import filters
from flopy4.mf6.codec.writer import filters

_JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6.codec.writer"),
Expand Down
File renamed without changes.
15 changes: 14 additions & 1 deletion flopy4/mf6/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC
from pathlib import Path

from modflow_devtools.misc import cd
from xattree import xattree

from flopy4.mf6.component import Component
Expand All @@ -14,8 +15,20 @@ class Context(Component, ABC):
def __attrs_post_init__(self):
super().__attrs_post_init__()
if self.workspace is None:
self.workspace = Path.cwd()
self.workspace = (
self.parent.workspace
if self.parent and hasattr(self.parent, "workspace")
else Path.cwd()
)

@property
def path(self) -> Path:
return self.workspace / self.filename

def load(self, format="ascii"):
with cd(self.workspace):
super().load(format=format)

def write(self, format="ascii"):
with cd(self.workspace):
super().write(format=format)
108 changes: 105 additions & 3 deletions flopy4/mf6/converter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,64 @@
from collections.abc import MutableMapping
from datetime import datetime
from pathlib import Path
from typing import Any

import xarray as xr
import xattree
from attrs import define
from cattrs import Converter

from flopy4.mf6.component import Component
from flopy4.mf6.context import Context
from flopy4.mf6.exchange import Exchange
from flopy4.mf6.model import Model
from flopy4.mf6.package import Package
from flopy4.mf6.solution import Solution
from flopy4.mf6.spec import fields_dict, get_blocks


@define
class _Binding:
"""
An MF6 component binding: a record representation of the
component for writing to a parent component's name file.
"""

type: str
fname: str
terms: tuple[str, ...] | None = None

def to_tuple(self):
if self.terms and any(self.terms):
return (self.type, self.fname, *self.terms)
else:
return (self.type, self.fname)

@classmethod
def from_component(cls, component: Component) -> "_Binding":
def _get_binding_type(component: Component) -> str:
cls_name = component.__class__.__name__
if isinstance(component, Exchange):
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
else:
return f"{cls_name.upper()}6"

def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
if isinstance(component, Exchange):
return (component.exgmnamea, component.exgmnameb) # type: ignore
elif isinstance(component, Solution):
return tuple(component.models)
elif isinstance(component, (Model, Package)):
return (component.name,) # type: ignore
return None

return cls(
type=_get_binding_type(component),
fname=component.filename or component.default_filename(),
terms=_get_binding_terms(component),
)


def _attach_field_metadata(
dataset: xr.Dataset, component_type: type, field_names: list[str]
) -> None:
Expand All @@ -29,14 +78,56 @@ def _path_to_record(field_name: str, path_value: Path) -> tuple:


def unstructure_component(value: Component) -> dict[str, Any]:
data = xattree.asdict(value)
blockspec = get_blocks(value.dfn)
blocks: dict[str, dict[str, Any]] = {}
xatspec = xattree.get_xatspec(type(value))

# Handle child component bindings before converting to dict
if isinstance(value, Context):
for field_name, child_spec in xatspec.children.items():
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
block_name = child_spec.metadata["block"] # type: ignore
field_value = getattr(value, field_name, None)

if block_name not in blocks:
blocks[block_name] = {}

if isinstance(field_value, Component):
components = [_Binding.from_component(field_value).to_tuple()]
elif isinstance(field_value, MutableMapping):
components = [
_Binding.from_component(comp).to_tuple()
for comp in field_value.values()
if comp is not None
]
elif isinstance(field_value, (list, tuple)):
components = [
_Binding.from_component(comp).to_tuple()
for comp in field_value
if comp is not None
]
else:
continue

if components:
blocks[block_name][field_name] = components

data = xattree.asdict(value)

for block_name, block in blockspec.items():
blocks[block_name] = {}
if block_name not in blocks:
blocks[block_name] = {}
period_data = {}
period_blocks = {} # type: ignore

for field_name in block.keys():
# Skip child components that have been processed as bindings
if isinstance(value, Context) and field_name in xatspec.children:
child_spec = xatspec.children[field_name]
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
if child_spec.metadata["block"] == block_name: # type: ignore
continue

field_value = data[field_name]
# convert:
# - paths to records
Expand Down Expand Up @@ -89,7 +180,18 @@ def unstructure_component(value: Component) -> dict[str, Any]:
_attach_field_metadata(dataset, type(value), list(block.keys()))
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}

return {name: block for name, block in blocks.items() if block}
# make sure options block always comes first
if "options" in blocks:
options_block = blocks.pop("options")
blocks = {"options": options_block, **blocks}

# total temporary hack! manually set solutiongroup 1. still need to support multiple..
if "solutiongroup" in blocks:
sg = blocks["solutiongroup"]
blocks["solutiongroup 1"] = sg
del blocks["solutiongroup"]

return {name: block for name, block in blocks.items() if name != "period"}


def _make_converter() -> Converter:
Expand Down
33 changes: 16 additions & 17 deletions flopy4/mf6/gwf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def convert_grid(value):

@xattree
class Gwf(Model):
@define
class NewtonOptions:
newton: bool = field()
under_relaxation: bool = field()

@define
class Output:
parent: "Gwf" = attrs.field(repr=False)
Expand All @@ -51,30 +56,24 @@ def budget(self):
self.parent.parent.workspace / f"{self.parent.name}.dis.grb",
)

dis: Dis = field(converter=convert_grid)
ic: Ic = field()
oc: Oc = field()
npf: Npf = field()
chd: list[Chd] = field()
wel: list[Wel] = field()
drn: list[Drn] = field()
output: Output = attrs.field(
default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True)
)

@define
class NewtonOptions:
newton: bool = field()
under_relaxation: bool = field()

list: Optional[str] = field(block="options", default=None)
_list: Optional[str] = field(block="options", default=None)
print_input: bool = field(block="options", default=False)
print_flows: bool = field(block="options", default=False)
save_flows: bool = field(block="options", default=False)
newtonoptions: Optional[NewtonOptions] = field(block="options", default=None)
nc_mesh2d_filerecord: Optional[Path] = field(block="options", default=None)
nc_structured_filerecord: Optional[Path] = field(block="options", default=None)
nc_filerecord: Optional[Path] = field(block="options", default=None)
dis: Dis = field(converter=convert_grid, block="packages")
ic: Ic = field(block="packages")
oc: Oc = field(block="packages")
npf: Npf = field(block="packages")
chd: list[Chd] = field(block="packages")
wel: list[Wel] = field(block="packages")
drn: list[Drn] = field(block="packages")
output: Output = attrs.field(
default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True)
)

@property
def grid(self) -> Grid:
Expand Down
8 changes: 4 additions & 4 deletions flopy4/mf6/gwf/ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from flopy4.mf6.spec import array, field


@xattree
@xattree(kw_only=True)
class Ic(Package):
export_array_ascii: bool = field(block="options", default=False)
export_array_netcdf: bool = field(block="options", default=False)
strt: NDArray[np.float64] = array(
block="packagedata",
block="griddata",
dims=("nnodes",),
default=1.0,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
)
export_array_ascii: bool = field(block="options", default=False)
export_array_netcdf: bool = field(block="options", default=False)
2 changes: 1 addition & 1 deletion flopy4/mf6/ims.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class Ims(Solution):
solution_package: ClassVar[Sln] = Sln(abbr="ims", pattern="*")

print_option: bool = field(block="options", default=False)
print_option: Optional[str] = field(block="options", default=None)
complexity: str = field(block="options", default="simple")
csv_outer_output_file: Optional[Path] = field(default=None, block="options")
csv_inner_output_file: Optional[Path] = field(block="options", default=None)
Expand Down
4 changes: 2 additions & 2 deletions flopy4/mf6/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from xattree import xattree

from flopy4.mf6.component import Component
from flopy4.mf6.context import Context


@xattree
class Model(Component, ABC):
class Model(Context, ABC):
def default_filename(self) -> str:
return f"{self.name}.nam" # type: ignore
18 changes: 4 additions & 14 deletions flopy4/mf6/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def convert_time(value):

@xattree
class Simulation(Context):
models: dict[str, Model] = field()
exchanges: dict[str, Exchange] = field()
solutions: dict[str, Solution] = field()
tdis: Tdis = field(converter=convert_time)
tdis: Tdis = field(converter=convert_time, block="timing")
models: dict[str, Model] = field(block="models")
exchanges: dict[str, Exchange] = field(block="exchanges")
solutions: dict[str, Solution] = field(block="solutiongroup")
filename: str = field(default="mfsim.nam", init=False)

def __attrs_post_init__(self):
Expand All @@ -52,13 +52,3 @@ def run(self, exe: str | PathLike = "mf6", verbose: bool = False) -> None:
f"Simulation {self.name}: {exe} failed with " # type: ignore
f"return code {ret}, output:\n\n{out + err} "
)

def load(self, format="ascii"):
"""Load the simulation."""
with cd(self.workspace):
super().load(format=format)

def write(self, format="ascii"):
"""Write the simulation."""
with cd(self.workspace):
super().write(format=format)
3 changes: 2 additions & 1 deletion flopy4/mf6/solution.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from abc import ABC

import attrs
from xattree import xattree

from flopy4.mf6.package import Package


@xattree
class Solution(Package, ABC):
pass
models: list[str] = attrs.field(default=attrs.Factory(list))
2 changes: 1 addition & 1 deletion flopy4/mf6/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def get_dfn_field_type(attribute: Attribute) -> FieldType:
kind = xatmeta["kind"]
match kind:
case "child":
raise ValueError(f"Top-level field should not be a child: {attribute.name}")
return "recarray" # Child components become tabular bindings
case "array":
return "recarray"
case "coord":
Expand Down
Loading
Loading