Skip to content

Commit fcc0928

Browse files
authored
consolidate converter modules, fix dim name nnodes -> nodes (#185)
and clean up some crufty comments
1 parent 15e61c8 commit fcc0928

File tree

17 files changed

+127
-134
lines changed

17 files changed

+127
-134
lines changed

flopy4/mf6/adapters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,15 +319,15 @@ def data_type(self):
319319
case "ndarray":
320320
if "nper" in self._data.dims:
321321
if self._data.ndim == 2:
322-
if "nnodes" in self._data.dims:
322+
if "nodes" in self._data.dims:
323323
return DataType.transient2d # nodes?
324324
if self._data.ndim == 3:
325325
return DataType.transient3d # ncpl?
326326
if self._data.ndim == 4:
327327
return DataType.transient2d # nodes?
328328
else:
329329
if self._data.ndim == 1:
330-
if "nnodes" in self._data.dims:
330+
if "nodes" in self._data.dims:
331331
return DataType.array3d
332332
if self._data.ndim == 2:
333333
return DataType.array2d
@@ -351,7 +351,7 @@ def dtype(self):
351351
@property
352352
def array(self):
353353
if self._spec.type.__name__ == "ndarray":
354-
if "nnodes" in self._data.dims:
354+
if "nodes" in self._data.dims:
355355
if "nper" in self._data.dims:
356356
shape = (
357357
self._time.nper,

flopy4/mf6/codec/writer/filters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
176176
yield (value.item(),)
177177
return
178178

179-
spatial_dims = [d for d in value.dims if d in ("nlay", "nrow", "ncol", "nnodes")]
179+
spatial_dims = [d for d in value.dims if d in ("nlay", "nrow", "ncol", "nodes")]
180180
has_spatial_dims = len(spatial_dims) > 0
181181
mask = nonempty(value)
182182
indices = np.where(mask)
@@ -223,7 +223,7 @@ def dataset2list(value: xr.Dataset):
223223
if combined_mask is None or not np.any(combined_mask):
224224
return
225225

226-
spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nnodes")]
226+
spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nodes")]
227227
has_spatial_dims = len(spatial_dims) > 0
228228
indices = np.where(combined_mask)
229229
for i in range(len(indices[0])):

flopy4/mf6/converter.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@
33
from pathlib import Path
44
from typing import Any
55

6+
import numpy as np
7+
import sparse
68
import xarray as xr
79
import xattree
810
from attrs import define
911
from cattrs import Converter
12+
from numpy.typing import NDArray
13+
from xattree import get_xatspec
1014

15+
from flopy4.adapters import get_nn
1116
from flopy4.mf6.component import Component
17+
from flopy4.mf6.config import SPARSE_THRESHOLD
18+
from flopy4.mf6.constants import FILL_DNODATA
1219
from flopy4.mf6.context import Context
1320
from flopy4.mf6.exchange import Exchange
1421
from flopy4.mf6.model import Model
@@ -148,15 +155,15 @@ def unstructure_component(value: Component) -> dict[str, Any]:
148155
blocks[block_name][field_name] = tuple(field_value.values.tolist())
149156
elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims:
150157
has_spatial_dims = any(
151-
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nnodes"]
158+
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
152159
)
153160
if has_spatial_dims:
154161
# terrible hack to convert flat nodes dimension to 3d structured dims.
155162
# long term solution for this is to use a custom xarray index. filters
156163
# should then have access to all dimensions needed.
157164
dims_ = set(field_value.dims).copy()
158165
dims_.remove("nper")
159-
if dims_ == {"nnodes"}:
166+
if dims_ == {"nodes"}:
160167
parent = value.parent # type: ignore
161168
field_value = xr.DataArray(
162169
field_value.data.reshape(
@@ -228,3 +235,73 @@ def _make_converter() -> Converter:
228235

229236

230237
COMPONENT_CONVERTER = _make_converter()
238+
239+
240+
def dict_to_array(value, self_, field) -> NDArray:
241+
"""
242+
Convert a sparse dictionary representation of an array to a
243+
dense numpy array or a sparse COO array.
244+
"""
245+
246+
if not isinstance(value, dict):
247+
# if not a dict, assume it's a numpy array
248+
# and let xarray deal with it if it isn't
249+
return value
250+
251+
spec = get_xatspec(type(self_)).flat
252+
field = spec[field.name]
253+
if not field.dims:
254+
raise ValueError(f"Field {field} missing dims")
255+
256+
# resolve dims
257+
explicit_dims = self_.__dict__.get("dims", {})
258+
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
259+
dims = inherited_dims | explicit_dims
260+
shape = [dims.get(d, d) for d in field.dims]
261+
unresolved = [d for d in shape if isinstance(d, str)]
262+
if any(unresolved):
263+
raise ValueError(f"Couldn't resolve dims: {unresolved}")
264+
265+
if np.prod(shape) > SPARSE_THRESHOLD:
266+
a: dict[tuple[Any, ...], Any] = dict()
267+
268+
def set_(arr, val, *ind):
269+
arr[tuple(ind)] = val
270+
271+
def final(arr):
272+
coords = np.array(list(map(list, zip(*arr.keys()))))
273+
return sparse.COO(
274+
coords,
275+
list(arr.values()),
276+
shape=shape,
277+
fill_value=field.default or FILL_DNODATA,
278+
)
279+
else:
280+
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
281+
282+
def set_(arr, val, *ind):
283+
arr[ind] = val
284+
285+
def final(arr):
286+
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
287+
return arr
288+
289+
if "nper" in dims:
290+
for kper, period in value.items():
291+
if kper == "*":
292+
kper = 0
293+
match len(shape):
294+
case 1:
295+
set_(a, period, kper)
296+
case _:
297+
for cellid, v in period.items():
298+
nn = get_nn(cellid, **dims)
299+
set_(a, v, kper, nn)
300+
if kper == "*":
301+
break
302+
else:
303+
for cellid, v in value.items():
304+
nn = get_nn(cellid, **dims)
305+
set_(a, v, nn)
306+
307+
return final(a)

flopy4/mf6/converters.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

flopy4/mf6/gwf/chd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from xattree import xattree
88

99
from flopy4.mf6.component import update_maxbound
10-
from flopy4.mf6.converters import dict_to_array
10+
from flopy4.mf6.converter import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
1313

@@ -29,7 +29,7 @@ class Chd(Package):
2929
block="period",
3030
dims=(
3131
"nper",
32-
"nnodes",
32+
"nodes",
3333
),
3434
default=None,
3535
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
@@ -40,7 +40,7 @@ class Chd(Package):
4040
block="period",
4141
dims=(
4242
"nper",
43-
"nnodes",
43+
"nodes",
4444
),
4545
default=None,
4646
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
@@ -51,7 +51,7 @@ class Chd(Package):
5151
block="period",
5252
dims=(
5353
"nper",
54-
"nnodes",
54+
"nodes",
5555
),
5656
default=None,
5757
converter=Converter(dict_to_array, takes_self=True, takes_field=True),

flopy4/mf6/gwf/dis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.converters import dict_to_array
9+
from flopy4.mf6.converter import dict_to_array
1010
from flopy4.mf6.package import Package
1111
from flopy4.mf6.spec import array, dim, field
1212

@@ -70,14 +70,14 @@ class Dis(Package):
7070
dims=("nlay", "nrow", "ncol"),
7171
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
7272
)
73-
nnodes: int = dim(
73+
nodes: int = dim(
7474
coord="node",
7575
scope="gwf",
7676
init=False,
7777
)
7878

7979
def __attrs_post_init__(self):
80-
self.nnodes = self.ncol * self.nrow * self.nlay
80+
self.nodes = self.ncol * self.nrow * self.nlay
8181
super().__attrs_post_init__()
8282

8383
def to_grid(self) -> StructuredGrid:

flopy4/mf6/gwf/drn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from xattree import xattree
88

99
from flopy4.mf6.component import update_maxbound
10-
from flopy4.mf6.converters import dict_to_array
10+
from flopy4.mf6.converter import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
1313

@@ -29,15 +29,15 @@ class Drn(Package):
2929
maxbound: Optional[int] = field(block="dimensions", default=None, init=False)
3030
elev: Optional[NDArray[np.float64]] = array(
3131
block="period",
32-
dims=("nper", "nnodes"),
32+
dims=("nper", "nodes"),
3333
default=None,
3434
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
3535
reader="urword",
3636
on_setattr=update_maxbound,
3737
)
3838
cond: Optional[NDArray[np.float64]] = array(
3939
block="period",
40-
dims=("nper", "nnodes"),
40+
dims=("nper", "nodes"),
4141
default=None,
4242
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
4343
reader="urword",
@@ -47,7 +47,7 @@ class Drn(Package):
4747
block="period",
4848
dims=(
4949
"nper",
50-
"nnodes",
50+
"nodes",
5151
),
5252
default=None,
5353
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
@@ -58,7 +58,7 @@ class Drn(Package):
5858
block="period",
5959
dims=(
6060
"nper",
61-
"nnodes",
61+
"nodes",
6262
),
6363
default=None,
6464
converter=Converter(dict_to_array, takes_self=True, takes_field=True),

flopy4/mf6/gwf/ic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numpy.typing import NDArray
44
from xattree import xattree
55

6-
from flopy4.mf6.converters import dict_to_array
6+
from flopy4.mf6.converter import dict_to_array
77
from flopy4.mf6.package import Package
88
from flopy4.mf6.spec import array, field
99

@@ -14,7 +14,7 @@ class Ic(Package):
1414
export_array_netcdf: bool = field(block="options", default=False)
1515
strt: NDArray[np.float64] = array(
1616
block="griddata",
17-
dims=("nnodes",),
17+
dims=("nodes",),
1818
default=1.0,
1919
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
2020
)

0 commit comments

Comments
 (0)