Skip to content

Commit a7b4097

Browse files
committed
work on custom unstructuring oc perioddata
1 parent 164bcea commit a7b4097

File tree

10 files changed

+238
-75
lines changed

10 files changed

+238
-75
lines changed

flopy4/mf6/codec/__init__.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,22 @@
88
from jinja2 import Environment, PackageLoader
99

1010
from flopy4.mf6 import filters
11-
from flopy4.mf6.codec.converter import structure_array, unstructure_array
11+
from flopy4.mf6.codec.converter import (
12+
structure_array,
13+
unstructure_array,
14+
unstructure_component,
15+
unstructure_oc,
16+
)
17+
from flopy4.mf6.spec import get_blocks
1218

1319
_JINJA_ENV = Environment(
1420
loader=PackageLoader("flopy4.mf6"),
1521
trim_blocks=True,
1622
lstrip_blocks=True,
1723
)
18-
_JINJA_ENV.filters["blocks"] = filters.blocks
24+
_JINJA_ENV.filters["blocks"] = get_blocks
1925
_JINJA_ENV.filters["field_type"] = filters.field_type
2026
_JINJA_ENV.filters["field_value"] = filters.field_value
21-
_JINJA_ENV.filters["is_list"] = filters.is_list
2227
_JINJA_ENV.filters["array_how"] = filters.array_how
2328
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
2429
_JINJA_ENV.filters["array2string"] = filters.array2string
@@ -31,10 +36,20 @@
3136
"threshold": sys.maxsize,
3237
}
3338

34-
_CONVERTER = Converter()
35-
_CONVERTER.register_unstructure_hook_factory(
36-
lambda cls: xattree.has(cls), lambda cls: xattree.asdict
37-
)
39+
40+
def _make_converter() -> Converter:
41+
from flopy4.mf6.component import Component
42+
from flopy4.mf6.gwf.oc import Oc
43+
44+
converter = Converter()
45+
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
46+
converter.register_unstructure_hook(Component, unstructure_component)
47+
converter.register_unstructure_hook(Oc, unstructure_oc)
48+
return converter
49+
50+
51+
_CONVERTER = _make_converter()
52+
3853

3954
# TODO unstructure arrays into sparse dicts
4055
# TODO combine OC fields into list input as defined in the MF6 dfn

flopy4/mf6/codec/converter.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22

33
import numpy as np
44
import sparse
5+
import xattree
56
from numpy.typing import NDArray
67
from xarray import DataArray
78
from xattree import get_xatspec
89

10+
from flopy4.mf6.component import Component
911
from flopy4.mf6.config import SPARSE_THRESHOLD
1012
from flopy4.mf6.constants import FILL_DNODATA
13+
from flopy4.mf6.spec import get_blocks
1114

1215

1316
# TODO: convert to a cattrs structuring hook so we don't have to
@@ -108,22 +111,90 @@ def unstructure_array(value: DataArray) -> dict:
108111
MF6 list-based input format.
109112
"""
110113
# make sure dim 'kper' is present
111-
if "kper" not in value.dims:
112-
raise ValueError("array must have 'kper' dimension")
114+
time_dim = "nper"
115+
if time_dim not in value.dims:
116+
raise ValueError(f"Array must have dimension '{time_dim}'")
113117

114118
if isinstance(value.data, sparse.COO):
115119
coords = value.coords
116120
data = value.data
117121
else:
118-
coords = np.array(np.nonzero(value)).T # type: ignore
119-
data = value[tuple(coords.T)] # type: ignore
122+
coords = np.array(np.nonzero(value.data)).T # type: ignore
123+
data = value.data[tuple(coords.T)] # type: ignore
120124
if not coords.size: # type: ignore
121125
return {}
122126
match value.ndim:
123127
case 1:
124-
return {k: v for k, v in zip(coords[:, 0], data)} # type: ignore
128+
return {int(k): v for k, v in zip(coords[:, 0], data)} # type: ignore
125129
case 2:
126-
return {(k, j): v for (k, j), v in zip(coords, data)} # type: ignore
130+
return {(int(k), int(j)): v for (k, j), v in zip(coords, data)} # type: ignore
127131
case 3:
128-
return {(k, i, j): v for (k, i, j), v in zip(coords, data)} # type: ignore
132+
return {(int(k), int(i), int(j)): v for (k, i, j), v in zip(coords, data)} # type: ignore
129133
return {}
134+
135+
136+
def unstructure_component(value: Component) -> dict[str, Any]:
137+
data = xattree.asdict(value)
138+
for block in get_blocks(value.dfn).values():
139+
for field_name, field in block.items():
140+
# unstructure arrays destined for list-based input
141+
if field["type"] == "recarray" and field["reader"] != "readarray":
142+
data[field_name] = unstructure_array(data[field_name])
143+
return data
144+
145+
146+
def unstructure_oc(value: Any) -> dict[str, Any]:
147+
data = xattree.asdict(value)
148+
for block_name, block in get_blocks(value.dfn).items():
149+
if block_name == "perioddata":
150+
# Unstructure all four arrays
151+
save_head = unstructure_array(data.get("save_head", {}))
152+
save_budget = unstructure_array(data.get("save_budget", {}))
153+
print_head = unstructure_array(data.get("print_head", {}))
154+
print_budget = unstructure_array(data.get("print_budget", {}))
155+
156+
# Collect all unique periods
157+
all_periods = set() # type: ignore
158+
for d in (save_head, save_budget, print_head, print_budget):
159+
if isinstance(d, dict):
160+
all_periods.update(d.keys())
161+
all_periods = sorted(all_periods) # type: ignore
162+
163+
saverecord = {} # type: ignore
164+
printrecord = {} # type: ignore
165+
for kper in all_periods:
166+
# Save head
167+
if kper in save_head:
168+
v = save_head[kper]
169+
if kper not in saverecord:
170+
saverecord[kper] = []
171+
saverecord[kper].append({"action": "save", "type": "head", "ocsetting": v})
172+
# Save budget
173+
if kper in save_budget:
174+
v = save_budget[kper]
175+
if kper not in saverecord:
176+
saverecord[kper] = []
177+
saverecord[kper].append({"action": "save", "type": "budget", "ocsetting": v})
178+
# Print head
179+
if kper in print_head:
180+
v = print_head[kper]
181+
if kper not in printrecord:
182+
printrecord[kper] = []
183+
printrecord[kper].append({"action": "print", "type": "head", "ocsetting": v})
184+
# Print budget
185+
if kper in print_budget:
186+
v = print_budget[kper]
187+
if kper not in printrecord:
188+
printrecord[kper] = []
189+
printrecord[kper].append({"action": "print", "type": "budget", "ocsetting": v})
190+
191+
data["saverecord"] = saverecord
192+
data["printrecord"] = printrecord
193+
data["save"] = "save"
194+
data["print"] = "print"
195+
else:
196+
for field_name, field in block.items():
197+
# unstructure arrays destined for list-based input
198+
if field["type"] == "recarray" and field["reader"] != "readarray":
199+
data[field_name] = unstructure_array(data[field_name])
200+
return data

flopy4/mf6/component.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC
22
from collections.abc import MutableMapping
33
from pathlib import Path
4+
from typing import ClassVar
45

56
from modflow_devtools.dfn import Dfn, Field
67
from xattree import xattree
@@ -33,6 +34,8 @@ class Component(ABC, MutableMapping):
3334

3435
filename: str = field(default=None)
3536

37+
dfn: ClassVar[Dfn]
38+
3639
@property
3740
def path(self) -> Path:
3841
return Path.cwd() / self.filename

flopy4/mf6/filters.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,12 @@
11
from collections.abc import Hashable, Mapping
22
from io import StringIO
3-
from typing import Any
43

54
import numpy as np
65
import xarray as xr
76
from jinja2 import pass_context
8-
from modflow_devtools.dfn import Dfn, Field
7+
from modflow_devtools.dfn import Field
98
from numpy.typing import NDArray
109

11-
from flopy4.mf6.spec import block_sort_key
12-
13-
14-
def blocks(dfn: Dfn) -> dict:
15-
"""
16-
Get blocks from an MF6 input definition. Anything not an
17-
explicitly defined key in the `Dfn` typed dict is a block.
18-
"""
19-
return dict(
20-
sorted(
21-
{k: v for k, v in dfn.items() if k not in Dfn.__annotations__}.items(),
22-
key=block_sort_key,
23-
)
24-
)
25-
2610

2711
def field_type(field: Field) -> str:
2812
"""
@@ -105,7 +89,3 @@ def array2string(value: NDArray) -> str:
10589
)
10690
np.savetxt(buffer, value, fmt=format, delimiter=" ")
10791
return buffer.getvalue().strip()
108-
109-
110-
def is_list(value: Any) -> bool:
111-
return isinstance(value, list)

flopy4/mf6/gwf/oc.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from attrs import Converter, define
6+
from modflow_devtools.dfn import Dfn, Field
67
from numpy.typing import NDArray
78
from xattree import xattree
89

@@ -11,6 +12,64 @@
1112
from flopy4.mf6.spec import array, field
1213
from flopy4.utils import to_path
1314

15+
_OCSETTING = Field(
16+
name="ocsetting",
17+
type="keystring",
18+
reader="urword",
19+
children={
20+
"all": Field(
21+
name="all",
22+
type="keyword",
23+
reader="urword",
24+
),
25+
"first": Field(
26+
name="first",
27+
type="keyword",
28+
reader="urword",
29+
),
30+
"last": Field(
31+
name="last",
32+
type="keyword",
33+
reader="urword",
34+
),
35+
"steps": Field(
36+
name="steps",
37+
type="integer",
38+
reader="urword",
39+
),
40+
"frequency": Field(
41+
name="frequency",
42+
type="integer",
43+
reader="urword",
44+
),
45+
},
46+
)
47+
48+
_RTYPE = Field(
49+
name="rtype",
50+
type="string",
51+
reader="urword",
52+
)
53+
54+
55+
def _oc_action_field(action: str) -> Field:
56+
return Field(
57+
name=f"{action}record",
58+
type="recarray",
59+
dims=("nper",),
60+
block="perioddata",
61+
reader="urword",
62+
children={
63+
action: Field(
64+
name=action,
65+
type="keyword",
66+
reader="urword",
67+
),
68+
"rtype": _RTYPE,
69+
"ocsetting": _OCSETTING,
70+
},
71+
)
72+
1473

1574
@xattree
1675
class Oc(Package):
@@ -56,25 +115,39 @@ class Period:
56115
default="all",
57116
dims=("nper",),
58117
converter=Converter(structure_array, takes_self=True, takes_field=True),
118+
reader="urword",
59119
)
60120
save_budget: Optional[NDArray[np.object_]] = array(
61121
Steps,
62122
block="perioddata",
63123
default="all",
64124
dims=("nper",),
65125
converter=Converter(structure_array, takes_self=True, takes_field=True),
126+
reader="urword",
66127
)
67128
print_head: Optional[NDArray[np.object_]] = array(
68129
Steps,
69130
block="perioddata",
70131
default="all",
71132
dims=("nper",),
72133
converter=Converter(structure_array, takes_self=True, takes_field=True),
134+
reader="urword",
73135
)
74136
print_budget: Optional[NDArray[np.object_]] = array(
75137
Steps,
76138
block="perioddata",
77139
default="all",
78140
dims=("nper",),
79141
converter=Converter(structure_array, takes_self=True, takes_field=True),
142+
reader="urword",
80143
)
144+
145+
@classmethod
146+
def get_dfn(cls) -> Dfn:
147+
"""Generate the component's MODFLOW 6 definition."""
148+
dfn = super().get_dfn()
149+
for field_name in list(dfn["perioddata"].keys()):
150+
dfn["perioddata"].pop(field_name)
151+
dfn["perioddata"]["saverecord"] = _oc_action_field("save")
152+
dfn["perioddata"]["printrecord"] = _oc_action_field("print")
153+
return dfn

flopy4/mf6/spec.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111
from attrs import NOTHING, Attribute
12-
from modflow_devtools.dfn import Field, FieldType
12+
from modflow_devtools.dfn import Dfn, Field, FieldType, Reader
1313

1414
from flopy4.spec import array as flopy_array
1515
from flopy4.spec import coord as flopy_coord
@@ -32,6 +32,7 @@ def field(
3232
if block:
3333
metadata = metadata or {}
3434
metadata["block"] = block
35+
metadata["reader"] = "urword"
3536
return flopy_field(
3637
default=default,
3738
validator=validator,
@@ -57,6 +58,7 @@ def dim(
5758
if block:
5859
metadata = metadata or {}
5960
metadata["block"] = block
61+
metadata["reader"] = "urword"
6062
return flopy_dim(
6163
scope=scope,
6264
coord=coord,
@@ -80,6 +82,7 @@ def coord(
8082
if block:
8183
metadata = metadata or {}
8284
metadata["block"] = block
85+
metadata["reader"] = "readarray"
8386
return flopy_coord(
8487
scope=scope,
8588
default=default,
@@ -99,11 +102,13 @@ def array(
99102
eq=None,
100103
metadata=None,
101104
block: str | None = None,
105+
reader: Reader = "readarray",
102106
):
103107
"""Define an array field."""
104108
if block:
105109
metadata = metadata or {}
106110
metadata["block"] = block
111+
metadata["reader"] = reader
107112
return flopy_array(
108113
cls=cls,
109114
dims=dims,
@@ -227,4 +232,18 @@ def to_dfn_field(attribute: Attribute) -> Field:
227232
children={k: to_dfn_field(v) for k, v in fields_dict(attribute.type)} # type: ignore
228233
if attribute.metadata.get("kind", None) == "child" # type: ignore
229234
else None, # type: ignore
235+
reader=attribute.metadata.get("reader", "urword"),
236+
)
237+
238+
239+
def get_blocks(dfn: Dfn) -> dict:
240+
"""
241+
Get blocks from an MF6 input definition. Anything not an
242+
explicitly defined key in the `Dfn` typed dict is a block.
243+
"""
244+
return dict(
245+
sorted(
246+
{k: v for k, v in dfn.items() if k not in Dfn.__annotations__}.items(),
247+
key=block_sort_key,
248+
)
230249
)

0 commit comments

Comments
 (0)