Skip to content

Commit d493924

Browse files
authored
chd serialization (#156)
1 parent 6bc8fe7 commit d493924

File tree

10 files changed

+266
-110
lines changed

10 files changed

+266
-110
lines changed

docs/examples/quickstart.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
save_budget={"*": "all"},
3131
)
3232

33+
# sim.write()
3334
sim.run(verbose=True)
3435

3536
# check CHD

flopy4/mf6/codec/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from flopy4.mf6.codec.converter import (
1212
structure_array,
1313
unstructure_array,
14+
unstructure_chd,
1415
unstructure_component,
1516
unstructure_oc,
1617
unstructure_tdis,
@@ -40,24 +41,22 @@
4041

4142
def _make_converter() -> Converter:
4243
from flopy4.mf6.component import Component
44+
from flopy4.mf6.gwf.chd import Chd
4345
from flopy4.mf6.gwf.oc import Oc
4446
from flopy4.mf6.tdis import Tdis
4547

4648
converter = Converter()
4749
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
4850
converter.register_unstructure_hook(Component, unstructure_component)
4951
converter.register_unstructure_hook(Tdis, unstructure_tdis)
52+
converter.register_unstructure_hook(Chd, unstructure_chd)
5053
converter.register_unstructure_hook(Oc, unstructure_oc)
5154
return converter
5255

5356

5457
_CONVERTER = _make_converter()
5558

5659

57-
# TODO unstructure arrays into sparse dicts
58-
# TODO combine OC fields into list input as defined in the MF6 dfn
59-
60-
6160
def loads(data: str) -> Any:
6261
# TODO
6362
pass

flopy4/mf6/codec/converter.py

Lines changed: 119 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
import numpy as np
44
import sparse
55
import xattree
6+
from flopy.discretization.grid import Grid
7+
from flopy.discretization.structuredgrid import StructuredGrid
8+
from flopy.discretization.unstructuredgrid import UnstructuredGrid
9+
from flopy.discretization.vertexgrid import VertexGrid
610
from numpy.typing import NDArray
711
from xarray import DataArray
812
from xattree import get_xatspec
913

1014
from flopy4.mf6.component import Component
1115
from flopy4.mf6.config import SPARSE_THRESHOLD
1216
from flopy4.mf6.constants import FILL_DNODATA
13-
from flopy4.mf6.spec import get_blocks
17+
from flopy4.mf6.spec import get_blocks, is_list_field
1418

1519

1620
# TODO: convert to a cattrs structuring hook so we don't have to
@@ -87,20 +91,16 @@ def _get_nn(cellid):
8791
match len(shape):
8892
case 1:
8993
set_(a, period, kper)
90-
# a[(kper,)] = period
9194
case _:
9295
for cellid, v in period.items():
9396
nn = _get_nn(cellid)
9497
set_(a, v, kper, nn)
95-
# a[(kper, nn)] = v
9698
if kper == "*":
9799
break
98100
else:
99101
for cellid, v in value.items():
100102
nn = _get_nn(cellid)
101103
set_(a, v, nn)
102-
# a[(nn,)] = v
103-
104104
return final(a)
105105

106106

@@ -109,36 +109,47 @@ def unstructure_array(value: DataArray) -> dict:
109109
Convert a dense numpy array or a sparse COO array to a sparse
110110
dictionary representation suitable for serialization into the
111111
MF6 list-based input format.
112+
113+
The input array must have a time dimension named 'nper', i.e.
114+
it must be stress period data for some MODFLOW 6 component.
115+
116+
Returns:
117+
dict: {kper: {spatial indices: value, ...}, ...}
112118
"""
113-
# make sure dim 'kper' is present
114-
time_dim = "nper"
115-
if time_dim not in value.dims:
119+
if (time_dim := "nper") not in value.dims:
116120
raise ValueError(f"Array must have dimension '{time_dim}'")
117-
118121
if isinstance(value.data, sparse.COO):
119122
coords = value.coords
120123
data = value.data
121124
else:
122-
coords = np.array(np.nonzero(value.data)).T # type: ignore
125+
coords = np.array(np.where(value.data != FILL_DNODATA)).T # type: ignore
123126
data = value.data[tuple(coords.T)] # type: ignore
124127
if not coords.size: # type: ignore
125128
return {}
129+
result = {}
126130
match value.ndim:
127131
case 1:
128-
return {int(k): v for k, v in zip(coords[:, 0], data)} # type: ignore
129-
case 2:
130-
return {(int(k), int(j)): v for (k, j), v in zip(coords, data)} # type: ignore
131-
case 3:
132-
return {(int(k), int(i), int(j)): v for (k, i, j), v in zip(coords, data)} # type: ignore
133-
return {}
132+
# Only kper, no spatial dims
133+
for kper, v in zip(coords[:, 0], data):
134+
result[int(kper)] = v
135+
case _:
136+
# kper + spatial dims
137+
for row, v in zip(coords, data):
138+
kper = int(row[0]) # type: ignore
139+
spatial = tuple(int(x) for x in row[1:]) # type: ignore
140+
if kper not in result:
141+
result[kper] = {}
142+
# flatten spatial index if only one spatial dim
143+
key = spatial[0] if len(spatial) == 1 else spatial
144+
result[kper][key] = v
145+
return result
134146

135147

136148
def unstructure_component(value: Component) -> dict[str, Any]:
137149
data = xattree.asdict(value)
138150
for block in get_blocks(value.dfn).values():
139151
for field_name, field in block.items():
140-
# unstructure arrays destined for list-based input
141-
if field["type"] == "recarray" and field["reader"] != "readarray":
152+
if is_list_field(field):
142153
data[field_name] = unstructure_array(data[field_name])
143154
return data
144155

@@ -148,63 +159,119 @@ def unstructure_tdis(value: Any) -> dict[str, Any]:
148159
blocks = get_blocks(value.dfn)
149160
for block_name, block in blocks.items():
150161
if block_name == "perioddata":
151-
array_fields = list(block.keys())
152-
153-
# Unstructure all arrays and collect all unique periods
154-
arrays = {}
162+
arrs_d = {}
155163
periods = set() # type: ignore
156-
for field_name in array_fields:
157-
arr = unstructure_array(data.get(field_name, {}))
158-
arrays[field_name] = arr
159-
periods.update(arr.keys())
164+
for field_name in block.keys():
165+
arr = data.get(field_name, None)
166+
arr_d = {} if arr is None else unstructure_array(arr)
167+
arrs_d[field_name] = arr_d
168+
periods.update(arr_d.keys())
160169
periods = sorted(periods) # type: ignore
161-
162170
perioddata = {} # type: ignore
163171
for kper in periods:
164172
line = []
165-
for arr in arrays.values():
166-
if kper not in perioddata:
167-
perioddata[kper] = [] # type: ignore
168-
line.append(arr[kper])
173+
if kper not in perioddata:
174+
perioddata[kper] = [] # type: ignore
175+
for arr_d in arrs_d.values():
176+
if val := arr_d.get(kper, None):
177+
line.append(val)
169178
perioddata[kper] = tuple(line)
170-
171179
data["perioddata"] = perioddata
172180
return data
173181

174182

183+
def get_kij(nn: int, nlay: int, nrow: int, ncol: int) -> tuple[int, int, int]:
184+
nodes = nlay * nrow * ncol
185+
if nn < 0 or nn >= nodes:
186+
raise ValueError(f"Node number {nn} is out of bounds (1 to {nodes})")
187+
k = (nn - 1) / (ncol * nrow) + 1
188+
ij = nn - (k - 1) * ncol * nrow
189+
i = (ij - 1) / ncol + 1
190+
j = ij - (i - 1) * ncol
191+
return int(k), int(i), int(j)
192+
193+
194+
def get_jk(nn: int, ncpl: int) -> tuple[int, int]:
195+
if nn < 0 or nn >= ncpl:
196+
raise ValueError(f"Node number {nn} is out of bounds (1 to {ncpl})")
197+
k = (nn - 1) / ncpl + 1
198+
j = nn - (k - 1) * ncpl
199+
return int(j), int(k)
200+
201+
202+
def get_cellid(nn: int, grid: Grid) -> tuple[int, ...]:
203+
match grid:
204+
case StructuredGrid():
205+
return get_kij(nn, *grid.shape)
206+
case VertexGrid():
207+
return get_jk(nn, grid.ncpl)
208+
case UnstructuredGrid():
209+
return (nn,)
210+
case _:
211+
raise TypeError(f"Unsupported grid type: {type(grid)}")
212+
213+
214+
def unstructure_chd(value: Any) -> dict[str, Any]:
215+
if (parent := value.parent) is None:
216+
raise ValueError(
217+
"CHD cannot be unstructured without a parent "
218+
"model and corresponding grid discretization."
219+
)
220+
grid = parent.grid
221+
data = xattree.asdict(value)
222+
blocks = get_blocks(value.dfn)
223+
for block_name, block in blocks.items():
224+
if block_name == "period":
225+
arrs_d = {}
226+
periods = set() # type: ignore
227+
for field_name in block.keys():
228+
arr = data.get(field_name, None)
229+
arr_d = {} if arr is None else unstructure_array(arr)
230+
arrs_d[field_name] = arr_d
231+
periods.update(arr_d.keys())
232+
periods = sorted(periods) # type: ignore
233+
perioddata = {} # type: ignore
234+
for kper in periods:
235+
line = []
236+
if kper not in perioddata:
237+
perioddata[kper] = [] # type: ignore
238+
for arr_d in arrs_d.values():
239+
if val := arr_d.get(kper, None):
240+
for nn, v in val.items():
241+
cellid = get_cellid(nn, grid)
242+
line.append((*cellid, v))
243+
perioddata[kper] = tuple(line)
244+
data["period"] = perioddata
245+
return data
246+
247+
175248
def unstructure_oc(value: Any) -> dict[str, Any]:
176249
data = xattree.asdict(value)
177250
blocks = get_blocks(value.dfn)
178251
for block_name, block in blocks.items():
179252
if block_name == "period":
180-
# Dynamically collect all recarray fields in perioddata block
181-
array_fields = []
253+
fields = []
182254
for field_name, field in block.items():
183-
# Try to split field_name into action and kind, e.g. save_head -> ("save", "head")
184255
action, rtype = field_name.split("_")
185-
array_fields.append((action, rtype, field_name))
186-
187-
# Unstructure all arrays and collect all unique periods
188-
arrays = {}
256+
fields.append((action, rtype, field_name))
257+
arrs_d = {}
189258
periods = set() # type: ignore
190-
for action, rtype, field_name in array_fields:
191-
arr = unstructure_array(data.get(field_name, {}))
192-
arrays[(action, rtype)] = arr
193-
periods.update(arr.keys())
259+
for action, rtype, field_name in fields:
260+
arr = data.get(field_name, None)
261+
arr_d = {} if arr is None else unstructure_array(arr)
262+
arrs_d[(action, rtype)] = arr_d
263+
periods.update(arr_d.keys())
194264
periods = sorted(periods) # type: ignore
195-
196265
perioddata = {} # type: ignore
197266
for kper in periods:
198-
for (action, rtype), arr in arrays.items():
199-
if kper in arr:
200-
if kper not in perioddata:
201-
perioddata[kper] = []
202-
perioddata[kper].append((action, rtype, arr[kper]))
203-
267+
if kper not in perioddata:
268+
perioddata[kper] = []
269+
for (action, rtype), arr_d in arrs_d.items():
270+
if arr := arr_d.get(kper, None):
271+
perioddata[kper].append((action, rtype, arr))
204272
data["period"] = perioddata
205273
else:
206274
for field_name, field in block.items():
207-
# unstructure arrays destined for list-based input
208-
if field["type"] == "recarray" and field["reader"] != "readarray":
275+
if is_list_field(field):
209276
data[field_name] = unstructure_array(data[field_name])
210277
return data

flopy4/mf6/filters.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,7 @@
77
from modflow_devtools.dfn import Dfn, Field
88
from numpy.typing import NDArray
99

10-
from flopy4.mf6.spec import get_blocks
11-
12-
13-
def _is_list_block(block: dict) -> bool:
14-
return (
15-
len(block) == 1
16-
and (field := next(iter(block.values())))["type"] == "recarray"
17-
and field["reader"] != "readarray"
18-
) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values()))
10+
from flopy4.mf6.spec import get_blocks, is_list_block
1911

2012

2113
def dict_blocks(dfn: Dfn) -> dict:
@@ -28,13 +20,13 @@ def dict_blocks(dfn: Dfn) -> dict:
2820
return {
2921
block_name: block
3022
for block_name, block in get_blocks(dfn).items()
31-
if not _is_list_block(block)
23+
if not is_list_block(block)
3224
}
3325

3426

3527
def list_blocks(dfn: Dfn) -> dict:
3628
return {
37-
block_name: block for block_name, block in get_blocks(dfn).items() if _is_list_block(block)
29+
block_name: block for block_name, block in get_blocks(dfn).items() if is_list_block(block)
3830
}
3931

4032

flopy4/mf6/gwf/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@ class Output:
3737
def head(self) -> xr.DataArray:
3838
# TODO support other extensions than .hds (e.g. .hed)
3939
return open_hds(
40-
self.parent.parent.path / f"{self.parent.name}.hds", # type: ignore
41-
self.parent.parent.path / f"{self.parent.name}.dis.grb", # type: ignore
40+
self.parent.parent.workspace / f"{self.parent.name}.hds", # type: ignore
41+
self.parent.parent.workspace / f"{self.parent.name}.dis.grb", # type: ignore
4242
)
4343

4444
@property
4545
def budget(self):
4646
# TODO support other extensions than .bud (e.g. .cbc)
4747
return open_cbc(
48-
self.parent.parent.path / f"{self.parent.name}.bud",
49-
self.parent.parent.path / f"{self.parent.name}.dis.grb",
48+
self.parent.parent.workspace / f"{self.parent.name}.bud",
49+
self.parent.parent.workspace / f"{self.parent.name}.dis.grb",
5050
)
5151

5252
dis: Dis = field(converter=convert_grid)

0 commit comments

Comments
 (0)