Skip to content

Commit 37670e4

Browse files
author
wpbonelli
committed
writer child binding
1 parent b2bdd95 commit 37670e4

File tree

8 files changed

+204
-30
lines changed

8 files changed

+204
-30
lines changed

flopy4/mf6/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from flopy4.mf6.codec import dump, load
22
from flopy4.mf6.component import Component
3+
from flopy4.mf6.converter import COMPONENT_CONVERTER
34
from flopy4.uio import DEFAULT_REGISTRY
45

56
# register io methods
6-
# TODO: call this "mf6" or something? since it might include binary files
7-
DEFAULT_REGISTRY.register_loader(Component, "ascii", lambda c: load(c.path))
8-
DEFAULT_REGISTRY.register_writer(Component, "ascii", lambda c: dump(c, c.path))
7+
# TODO: call format "mf6" or something? since it might include binary files
8+
DEFAULT_REGISTRY.register_writer(
9+
Component, "ascii", lambda c: dump(COMPONENT_CONVERTER.unstructure(c), c.path)
10+
)

flopy4/mf6/context.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC
22
from pathlib import Path
33

4+
from modflow_devtools.misc import cd
45
from xattree import xattree
56

67
from flopy4.mf6.component import Component
@@ -14,8 +15,20 @@ class Context(Component, ABC):
1415
def __attrs_post_init__(self):
1516
super().__attrs_post_init__()
1617
if self.workspace is None:
17-
self.workspace = Path.cwd()
18+
self.workspace = (
19+
self.parent.workspace
20+
if self.parent and hasattr(self.parent, "workspace")
21+
else Path.cwd()
22+
)
1823

1924
@property
2025
def path(self) -> Path:
2126
return self.workspace / self.filename
27+
28+
def load(self, format="ascii"):
29+
with cd(self.workspace):
30+
super().load(format=format)
31+
32+
def write(self, format="ascii"):
33+
with cd(self.workspace):
34+
super().write(format=format)

flopy4/mf6/converter.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,62 @@
1+
from collections.abc import MutableMapping
12
from datetime import datetime
23
from pathlib import Path
34
from typing import Any
45

56
import xarray as xr
67
import xattree
8+
from attrs import define
79
from cattrs import Converter
810

911
from flopy4.mf6.component import Component
12+
from flopy4.mf6.context import Context
13+
from flopy4.mf6.exchange import Exchange
14+
from flopy4.mf6.model import Model
15+
from flopy4.mf6.package import Package
1016
from flopy4.mf6.spec import fields_dict, get_blocks
1117

1218

19+
@define
20+
class _Binding:
21+
"""
22+
An MF6 component binding: a record representation of the
23+
component for writing to a parent component's name file.
24+
"""
25+
26+
type: str
27+
fname: str
28+
terms: tuple[str, ...] | None = None
29+
30+
def to_tuple(self):
31+
if self.terms and any(self.terms):
32+
return (self.type, self.fname, *self.terms)
33+
else:
34+
return (self.type, self.fname)
35+
36+
@classmethod
37+
def from_component(cls, component: Component) -> "_Binding":
38+
def _get_binding_type(component: Component) -> str:
39+
cls_name = component.__class__.__name__
40+
if isinstance(component, Exchange):
41+
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
42+
else:
43+
return f"{cls_name.upper()}6"
44+
45+
def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
46+
if isinstance(component, Exchange):
47+
return (component.exgmnamea, component.exgmnameb) # type: ignore
48+
elif isinstance(component, (Model, Package)):
49+
return (component.name,) # type: ignore
50+
# TODO solutions
51+
return None
52+
53+
return cls(
54+
type=_get_binding_type(component),
55+
fname=component.filename or component.default_filename(),
56+
terms=_get_binding_terms(component),
57+
)
58+
59+
1360
def _attach_field_metadata(
1461
dataset: xr.Dataset, component_type: type, field_names: list[str]
1562
) -> None:
@@ -29,14 +76,56 @@ def _path_to_record(field_name: str, path_value: Path) -> tuple:
2976

3077

3178
def unstructure_component(value: Component) -> dict[str, Any]:
32-
data = xattree.asdict(value)
3379
blockspec = get_blocks(value.dfn)
3480
blocks: dict[str, dict[str, Any]] = {}
81+
xatspec = xattree.get_xatspec(type(value))
82+
83+
# Handle child component bindings before converting to dict
84+
if isinstance(value, Context):
85+
for field_name, child_spec in xatspec.children.items():
86+
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
87+
block_name = child_spec.metadata["block"] # type: ignore
88+
field_value = getattr(value, field_name, None)
89+
90+
if block_name not in blocks:
91+
blocks[block_name] = {}
92+
93+
if isinstance(field_value, Component):
94+
components = [_Binding.from_component(field_value).to_tuple()]
95+
elif isinstance(field_value, MutableMapping):
96+
components = [
97+
_Binding.from_component(comp).to_tuple()
98+
for comp in field_value.values()
99+
if comp is not None
100+
]
101+
elif isinstance(field_value, (list, tuple)):
102+
components = [
103+
_Binding.from_component(comp).to_tuple()
104+
for comp in field_value
105+
if comp is not None
106+
]
107+
else:
108+
continue
109+
110+
if components:
111+
blocks[block_name][field_name] = components
112+
113+
data = xattree.asdict(value)
114+
35115
for block_name, block in blockspec.items():
36-
blocks[block_name] = {}
116+
if block_name not in blocks:
117+
blocks[block_name] = {}
37118
period_data = {}
38119
period_blocks = {} # type: ignore
120+
39121
for field_name in block.keys():
122+
# Skip child components that have been processed as bindings
123+
if isinstance(value, Context) and field_name in xatspec.children:
124+
child_spec = xatspec.children[field_name]
125+
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
126+
if child_spec.metadata["block"] == block_name: # type: ignore
127+
continue
128+
40129
field_value = data[field_name]
41130
# convert:
42131
# - paths to records

flopy4/mf6/gwf/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ def budget(self):
5151
self.parent.parent.workspace / f"{self.parent.name}.dis.grb",
5252
)
5353

54-
dis: Dis = field(converter=convert_grid)
55-
ic: Ic = field()
56-
oc: Oc = field()
57-
npf: Npf = field()
58-
chd: list[Chd] = field()
59-
wel: list[Wel] = field()
60-
drn: list[Drn] = field()
54+
dis: Dis = field(converter=convert_grid, block="packages")
55+
ic: Ic = field(block="packages")
56+
oc: Oc = field(block="packages")
57+
npf: Npf = field(block="packages")
58+
chd: list[Chd] = field(block="packages")
59+
wel: list[Wel] = field(block="packages")
60+
drn: list[Drn] = field(block="packages")
6161
output: Output = attrs.field(
6262
default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True)
6363
)

flopy4/mf6/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from xattree import xattree
44

5-
from flopy4.mf6.component import Component
5+
from flopy4.mf6.context import Context
66

77

88
@xattree
9-
class Model(Component, ABC):
9+
class Model(Context, ABC):
1010
def default_filename(self) -> str:
1111
return f"{self.name}.nam" # type: ignore

flopy4/mf6/simulation.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ def convert_time(value):
2323

2424
@xattree
2525
class Simulation(Context):
26-
models: dict[str, Model] = field()
27-
exchanges: dict[str, Exchange] = field()
28-
solutions: dict[str, Solution] = field()
29-
tdis: Tdis = field(converter=convert_time)
26+
models: dict[str, Model] = field(block="models")
27+
exchanges: dict[str, Exchange] = field(block="exchanges")
28+
solutions: dict[str, Solution] = field(block="solutiongroup")
29+
tdis: Tdis = field(converter=convert_time, block="timing")
3030
filename: str = field(default="mfsim.nam", init=False)
3131

3232
def __attrs_post_init__(self):
@@ -52,13 +52,3 @@ def run(self, exe: str | PathLike = "mf6", verbose: bool = False) -> None:
5252
f"Simulation {self.name}: {exe} failed with " # type: ignore
5353
f"return code {ret}, output:\n\n{out + err} "
5454
)
55-
56-
def load(self, format="ascii"):
57-
"""Load the simulation."""
58-
with cd(self.workspace):
59-
super().load(format=format)
60-
61-
def write(self, format="ascii"):
62-
"""Write the simulation."""
63-
with cd(self.workspace):
64-
super().write(format=format)

flopy4/mf6/spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def get_dfn_field_type(attribute: Attribute) -> FieldType:
233233
kind = xatmeta["kind"]
234234
match kind:
235235
case "child":
236-
raise ValueError(f"Top-level field should not be a child: {attribute.name}")
236+
return "recarray" # Child components become tabular bindings
237237
case "array":
238238
return "recarray"
239239
case "coord":

test/test_codec.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,83 @@ def test_dumps_wel_with_auxiliary():
264264
assert "45 -25.0 2.0" in result # (1,3,4) -> node 45, q=-25.0, aux=2.0
265265
assert "1e+30" not in result
266266
assert "1.0e+30" not in result
267+
268+
269+
def test_dumps_gwf():
270+
from flopy4.mf6.gwf import Chd, Dis, Gwf, Ic, Npf, Oc
271+
272+
dis = Dis(nlay=1, nrow=10, ncol=10, delr=100.0, delc=100.0)
273+
gwf = Gwf(name="test_model", dis=dis)
274+
ic = Ic(parent=gwf, strt=1.0)
275+
npf = Npf(parent=gwf, k=1.0)
276+
oc = Oc(parent=gwf, head_file="test.hds", budget_file="test.bud", dims={"nper": 1})
277+
chd = Chd(parent=gwf, head={0: {(0, 0, 0): 10.0}}, dims={"nper": 1})
278+
279+
gwf = Gwf(
280+
name="test_model",
281+
dis=dis,
282+
ic=ic,
283+
npf=npf,
284+
oc=oc,
285+
chd=[chd],
286+
)
287+
288+
result = dumps(COMPONENT_CONVERTER.unstructure(gwf))
289+
print("GWF model result:")
290+
print(result)
291+
292+
# Check that child component bindings are included
293+
assert "DIS6" in result
294+
assert "IC6" in result
295+
assert "NPF6" in result
296+
assert "OC6" in result
297+
assert "test_model.dis" in result
298+
assert "test_model.ic" in result
299+
assert "test_model.npf" in result
300+
assert "test_model.oc" in result
301+
302+
303+
def test_dumps_simulation():
304+
from flopy.discretization.modeltime import ModelTime
305+
306+
from flopy4.mf6.gwf import Dis, Gwf, Ic, Npf, Oc
307+
from flopy4.mf6.simulation import Simulation
308+
from flopy4.mf6.tdis import Tdis
309+
310+
# Create model components
311+
dis = Dis(nlay=1, nrow=5, ncol=5, delr=100.0, delc=100.0)
312+
gwf = Gwf(name="model1", dis=dis)
313+
ic = Ic(parent=gwf, strt=1.0)
314+
npf = Npf(parent=gwf, k=1.0)
315+
oc = Oc(parent=gwf, head_file="model1.hds", budget_file="model1.bud", dims={"nper": 1})
316+
317+
# Create model
318+
gwf = Gwf(
319+
name="model1",
320+
dis=dis,
321+
ic=ic,
322+
npf=npf,
323+
oc=oc,
324+
)
325+
326+
# Create time discretization
327+
time = ModelTime(perlen=[1.0], nstp=[1])
328+
tdis = Tdis.from_time(time)
329+
330+
# Create simulation
331+
sim = Simulation(
332+
name="test_sim",
333+
models={"model1": gwf},
334+
exchanges={},
335+
solutions={},
336+
tdis=tdis,
337+
)
338+
339+
result = dumps(COMPONENT_CONVERTER.unstructure(sim))
340+
print("Simulation result:")
341+
print(result)
342+
343+
# Check that model bindings are included
344+
assert "GWF6" in result
345+
assert "model1" in result
346+
assert "TDIS6" in result

0 commit comments

Comments
 (0)