Skip to content

Commit e695f51

Browse files
authored
writer child binding (#174)
maybe not the last piece of #109 but we can now successfully write/run the simulation! it just fails on postprocessing/plotting because we got some inputs wrong.
1 parent b2bdd95 commit e695f51

File tree

14 files changed

+249
-56
lines changed

14 files changed

+249
-56
lines changed

docs/examples/quickstart.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,32 @@
55
from flopy.discretization.modeltime import ModelTime
66
from flopy.discretization.structuredgrid import StructuredGrid
77

8-
from flopy4.mf6.gwf import Chd, Gwf, Npf, Oc
8+
from flopy4.mf6.gwf import Chd, Gwf, Ic, Npf, Oc
99
from flopy4.mf6.ims import Ims
1010
from flopy4.mf6.simulation import Simulation
1111

1212
name = "quickstart"
1313
workspace = Path(__file__).parent / name
1414
time = ModelTime(perlen=[1.0], nstp=[1])
15-
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
15+
grid = StructuredGrid(
16+
nlay=1,
17+
nrow=10,
18+
ncol=10,
19+
delr=1.0 * np.ones(10),
20+
delc=1.0 * np.ones(10),
21+
top=1.0 * np.ones((10, 10)),
22+
botm=0.0 * np.ones((1, 10, 10)),
23+
)
1624
sim = Simulation(name=name, workspace=workspace, tdis=time)
17-
ims = Ims(parent=sim)
1825
gwf_name = "mymodel"
26+
ims = Ims(parent=sim, models=[gwf_name]) # temporary hack
1927
gwf = Gwf(parent=sim, name=gwf_name, save_flows=True, dis=grid)
2028
npf = Npf(parent=gwf, save_specific_discharge=True)
2129
chd = Chd(
2230
parent=gwf,
2331
head={0: {(0, 0, 0): 1.0, (0, 9, 9): 0.0}},
2432
)
33+
ic = Ic(parent=gwf, strt=1.0)
2534
oc = Oc(
2635
parent=gwf,
2736
budget_file=f"{gwf.name}.bud",
@@ -30,18 +39,15 @@
3039
save_budget={0: "all"},
3140
)
3241

33-
# sim.write()
42+
sim.write()
3443
sim.run(verbose=True)
3544

36-
# check CHD
3745
assert chd.data["head"][0, 0] == 1.0
3846
assert chd.data.head.sel(per=0)[99] == 0.0
3947
assert np.allclose(chd.data.head[:, 1:99], np.full(98, 1e30))
4048

41-
# check DIS
4249
assert gwf.dis.data.botm.sel(lay=0, col=0, row=0) == 0.0
4350

44-
# check OC
4551
assert oc.data["save_head"][0] == "all"
4652
assert oc.data.save_head.sel(per=0) == "all"
4753

flopy4/mf6/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from flopy4.mf6.codec import dump, load
1+
from flopy4.mf6.codec import dump
22
from flopy4.mf6.component import Component
3+
from flopy4.mf6.converter import COMPONENT_CONVERTER
34
from flopy4.uio import DEFAULT_REGISTRY
45

56
# register io methods
6-
# TODO: call this "mf6" or something? since it might include binary files
7-
DEFAULT_REGISTRY.register_loader(Component, "ascii", lambda c: load(c.path))
8-
DEFAULT_REGISTRY.register_writer(Component, "ascii", lambda c: dump(c, c.path))
7+
# TODO: call format "mf6" or something? since it might include binary files
8+
DEFAULT_REGISTRY.register_writer(
9+
Component, "ascii", lambda c: dump(COMPONENT_CONVERTER.unstructure(c), c.path)
10+
)

flopy4/mf6/codec/writer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from jinja2 import Environment, PackageLoader
66

7-
from flopy4.mf6 import filters
7+
from flopy4.mf6.codec.writer import filters
88

99
_JINJA_ENV = Environment(
1010
loader=PackageLoader("flopy4.mf6.codec.writer"),
File renamed without changes.

flopy4/mf6/context.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC
22
from pathlib import Path
33

4+
from modflow_devtools.misc import cd
45
from xattree import xattree
56

67
from flopy4.mf6.component import Component
@@ -14,8 +15,20 @@ class Context(Component, ABC):
1415
def __attrs_post_init__(self):
1516
super().__attrs_post_init__()
1617
if self.workspace is None:
17-
self.workspace = Path.cwd()
18+
self.workspace = (
19+
self.parent.workspace
20+
if self.parent and hasattr(self.parent, "workspace")
21+
else Path.cwd()
22+
)
1823

1924
@property
2025
def path(self) -> Path:
2126
return self.workspace / self.filename
27+
28+
def load(self, format="ascii"):
29+
with cd(self.workspace):
30+
super().load(format=format)
31+
32+
def write(self, format="ascii"):
33+
with cd(self.workspace):
34+
super().write(format=format)

flopy4/mf6/converter.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,64 @@
1+
from collections.abc import MutableMapping
12
from datetime import datetime
23
from pathlib import Path
34
from typing import Any
45

56
import xarray as xr
67
import xattree
8+
from attrs import define
79
from cattrs import Converter
810

911
from flopy4.mf6.component import Component
12+
from flopy4.mf6.context import Context
13+
from flopy4.mf6.exchange import Exchange
14+
from flopy4.mf6.model import Model
15+
from flopy4.mf6.package import Package
16+
from flopy4.mf6.solution import Solution
1017
from flopy4.mf6.spec import fields_dict, get_blocks
1118

1219

20+
@define
21+
class _Binding:
22+
"""
23+
An MF6 component binding: a record representation of the
24+
component for writing to a parent component's name file.
25+
"""
26+
27+
type: str
28+
fname: str
29+
terms: tuple[str, ...] | None = None
30+
31+
def to_tuple(self):
32+
if self.terms and any(self.terms):
33+
return (self.type, self.fname, *self.terms)
34+
else:
35+
return (self.type, self.fname)
36+
37+
@classmethod
38+
def from_component(cls, component: Component) -> "_Binding":
39+
def _get_binding_type(component: Component) -> str:
40+
cls_name = component.__class__.__name__
41+
if isinstance(component, Exchange):
42+
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
43+
else:
44+
return f"{cls_name.upper()}6"
45+
46+
def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
47+
if isinstance(component, Exchange):
48+
return (component.exgmnamea, component.exgmnameb) # type: ignore
49+
elif isinstance(component, Solution):
50+
return tuple(component.models)
51+
elif isinstance(component, (Model, Package)):
52+
return (component.name,) # type: ignore
53+
return None
54+
55+
return cls(
56+
type=_get_binding_type(component),
57+
fname=component.filename or component.default_filename(),
58+
terms=_get_binding_terms(component),
59+
)
60+
61+
1362
def _attach_field_metadata(
1463
dataset: xr.Dataset, component_type: type, field_names: list[str]
1564
) -> None:
@@ -29,14 +78,56 @@ def _path_to_record(field_name: str, path_value: Path) -> tuple:
2978

3079

3180
def unstructure_component(value: Component) -> dict[str, Any]:
32-
data = xattree.asdict(value)
3381
blockspec = get_blocks(value.dfn)
3482
blocks: dict[str, dict[str, Any]] = {}
83+
xatspec = xattree.get_xatspec(type(value))
84+
85+
# Handle child component bindings before converting to dict
86+
if isinstance(value, Context):
87+
for field_name, child_spec in xatspec.children.items():
88+
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
89+
block_name = child_spec.metadata["block"] # type: ignore
90+
field_value = getattr(value, field_name, None)
91+
92+
if block_name not in blocks:
93+
blocks[block_name] = {}
94+
95+
if isinstance(field_value, Component):
96+
components = [_Binding.from_component(field_value).to_tuple()]
97+
elif isinstance(field_value, MutableMapping):
98+
components = [
99+
_Binding.from_component(comp).to_tuple()
100+
for comp in field_value.values()
101+
if comp is not None
102+
]
103+
elif isinstance(field_value, (list, tuple)):
104+
components = [
105+
_Binding.from_component(comp).to_tuple()
106+
for comp in field_value
107+
if comp is not None
108+
]
109+
else:
110+
continue
111+
112+
if components:
113+
blocks[block_name][field_name] = components
114+
115+
data = xattree.asdict(value)
116+
35117
for block_name, block in blockspec.items():
36-
blocks[block_name] = {}
118+
if block_name not in blocks:
119+
blocks[block_name] = {}
37120
period_data = {}
38121
period_blocks = {} # type: ignore
122+
39123
for field_name in block.keys():
124+
# Skip child components that have been processed as bindings
125+
if isinstance(value, Context) and field_name in xatspec.children:
126+
child_spec = xatspec.children[field_name]
127+
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
128+
if child_spec.metadata["block"] == block_name: # type: ignore
129+
continue
130+
40131
field_value = data[field_name]
41132
# convert:
42133
# - paths to records
@@ -89,7 +180,18 @@ def unstructure_component(value: Component) -> dict[str, Any]:
89180
_attach_field_metadata(dataset, type(value), list(block.keys()))
90181
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}
91182

92-
return {name: block for name, block in blocks.items() if block}
183+
# make sure options block always comes first
184+
if "options" in blocks:
185+
options_block = blocks.pop("options")
186+
blocks = {"options": options_block, **blocks}
187+
188+
# total temporary hack! manually set solutiongroup 1. still need to support multiple..
189+
if "solutiongroup" in blocks:
190+
sg = blocks["solutiongroup"]
191+
blocks["solutiongroup 1"] = sg
192+
del blocks["solutiongroup"]
193+
194+
return {name: block for name, block in blocks.items() if name != "period"}
93195

94196

95197
def _make_converter() -> Converter:

flopy4/mf6/gwf/__init__.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def convert_grid(value):
3131

3232
@xattree
3333
class Gwf(Model):
34+
@define
35+
class NewtonOptions:
36+
newton: bool = field()
37+
under_relaxation: bool = field()
38+
3439
@define
3540
class Output:
3641
parent: "Gwf" = attrs.field(repr=False)
@@ -51,30 +56,24 @@ def budget(self):
5156
self.parent.parent.workspace / f"{self.parent.name}.dis.grb",
5257
)
5358

54-
dis: Dis = field(converter=convert_grid)
55-
ic: Ic = field()
56-
oc: Oc = field()
57-
npf: Npf = field()
58-
chd: list[Chd] = field()
59-
wel: list[Wel] = field()
60-
drn: list[Drn] = field()
61-
output: Output = attrs.field(
62-
default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True)
63-
)
64-
65-
@define
66-
class NewtonOptions:
67-
newton: bool = field()
68-
under_relaxation: bool = field()
69-
70-
list: Optional[str] = field(block="options", default=None)
59+
_list: Optional[str] = field(block="options", default=None)
7160
print_input: bool = field(block="options", default=False)
7261
print_flows: bool = field(block="options", default=False)
7362
save_flows: bool = field(block="options", default=False)
7463
newtonoptions: Optional[NewtonOptions] = field(block="options", default=None)
7564
nc_mesh2d_filerecord: Optional[Path] = field(block="options", default=None)
7665
nc_structured_filerecord: Optional[Path] = field(block="options", default=None)
7766
nc_filerecord: Optional[Path] = field(block="options", default=None)
67+
dis: Dis = field(converter=convert_grid, block="packages")
68+
ic: Ic = field(block="packages")
69+
oc: Oc = field(block="packages")
70+
npf: Npf = field(block="packages")
71+
chd: list[Chd] = field(block="packages")
72+
wel: list[Wel] = field(block="packages")
73+
drn: list[Drn] = field(block="packages")
74+
output: Output = attrs.field(
75+
default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True)
76+
)
7877

7978
@property
8079
def grid(self) -> Grid:

flopy4/mf6/gwf/ic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from flopy4.mf6.spec import array, field
99

1010

11-
@xattree
11+
@xattree(kw_only=True)
1212
class Ic(Package):
13+
export_array_ascii: bool = field(block="options", default=False)
14+
export_array_netcdf: bool = field(block="options", default=False)
1315
strt: NDArray[np.float64] = array(
14-
block="packagedata",
16+
block="griddata",
1517
dims=("nnodes",),
1618
default=1.0,
1719
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
1820
)
19-
export_array_ascii: bool = field(block="options", default=False)
20-
export_array_netcdf: bool = field(block="options", default=False)

flopy4/mf6/ims.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class Ims(Solution):
1313
solution_package: ClassVar[Sln] = Sln(abbr="ims", pattern="*")
1414

15-
print_option: bool = field(block="options", default=False)
15+
print_option: Optional[str] = field(block="options", default=None)
1616
complexity: str = field(block="options", default="simple")
1717
csv_outer_output_file: Optional[Path] = field(default=None, block="options")
1818
csv_inner_output_file: Optional[Path] = field(block="options", default=None)

flopy4/mf6/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from xattree import xattree
44

5-
from flopy4.mf6.component import Component
5+
from flopy4.mf6.context import Context
66

77

88
@xattree
9-
class Model(Component, ABC):
9+
class Model(Context, ABC):
1010
def default_filename(self) -> str:
1111
return f"{self.name}.nam" # type: ignore

0 commit comments

Comments
 (0)