Skip to content

Commit a0206b5

Browse files
authored
back to original dict->array converter (#169)
this makes more sense to special case than try to generalize
1 parent 83bd261 commit a0206b5

File tree

13 files changed

+1775
-1530
lines changed

13 files changed

+1775
-1530
lines changed

flopy4/mf6/converters.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
import sparse
5+
from numpy.typing import NDArray
6+
from xattree import get_xatspec
7+
8+
from flopy4.adapters import get_nn
9+
from flopy4.mf6.config import SPARSE_THRESHOLD
10+
from flopy4.mf6.constants import FILL_DNODATA
11+
12+
13+
def dict_to_array(value, self_, field) -> NDArray:
14+
"""
15+
Convert a sparse dictionary representation of an array to a
16+
dense numpy array or a sparse COO array.
17+
"""
18+
19+
if not isinstance(value, dict):
20+
# if not a dict, assume it's a numpy array
21+
# and let xarray deal with it if it isn't
22+
return value
23+
24+
spec = get_xatspec(type(self_)).flat
25+
field = spec[field.name]
26+
if not field.dims:
27+
raise ValueError(f"Field {field} missing dims")
28+
29+
# resolve dims
30+
explicit_dims = self_.__dict__.get("dims", {})
31+
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
32+
dims = inherited_dims | explicit_dims
33+
shape = [dims.get(d, d) for d in field.dims]
34+
unresolved = [d for d in shape if isinstance(d, str)]
35+
if any(unresolved):
36+
raise ValueError(f"Couldn't resolve dims: {unresolved}")
37+
38+
if np.prod(shape) > SPARSE_THRESHOLD:
39+
a: dict[tuple[Any, ...], Any] = dict()
40+
41+
def set_(arr, val, *ind):
42+
arr[tuple(ind)] = val
43+
44+
def final(arr):
45+
coords = np.array(list(map(list, zip(*arr.keys()))))
46+
return sparse.COO(
47+
coords,
48+
list(arr.values()),
49+
shape=shape,
50+
fill_value=field.default or FILL_DNODATA,
51+
)
52+
else:
53+
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
54+
55+
def set_(arr, val, *ind):
56+
arr[ind] = val
57+
58+
def final(arr):
59+
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
60+
return arr
61+
62+
if "nper" in dims:
63+
for kper, period in value.items():
64+
if kper == "*":
65+
kper = 0
66+
match len(shape):
67+
case 1:
68+
set_(a, period, kper)
69+
case _:
70+
for cellid, v in period.items():
71+
nn = get_nn(cellid, **dims)
72+
set_(a, v, kper, nn)
73+
if kper == "*":
74+
break
75+
else:
76+
for cellid, v in value.items():
77+
nn = get_nn(cellid, **dims)
78+
set_(a, v, nn)
79+
80+
return final(a)

flopy4/mf6/gwf/chd.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from typing import ClassVar, Optional
33

44
import numpy as np
5+
from attrs import Converter
56
from numpy.typing import NDArray
6-
from xattree import dict_to_array_converter, xattree
7+
from xattree import xattree
78

89
from flopy4.mf6.constants import FILL_DNODATA
10+
from flopy4.mf6.converters import dict_to_array
911
from flopy4.mf6.package import Package
1012
from flopy4.mf6.spec import array, field
1113

@@ -31,7 +33,7 @@ class Chd(Package):
3133
"nnodes",
3234
),
3335
default=None,
34-
converter=dict_to_array_converter,
36+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
3537
reader="urword",
3638
)
3739
aux: Optional[NDArray[np.float64]] = array(
@@ -41,7 +43,7 @@ class Chd(Package):
4143
"nnodes",
4244
),
4345
default=None,
44-
converter=dict_to_array_converter,
46+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
4547
reader="urword",
4648
)
4749
boundname: Optional[NDArray[np.str_]] = array(
@@ -51,7 +53,7 @@ class Chd(Package):
5153
"nnodes",
5254
),
5355
default=None,
54-
converter=dict_to_array_converter,
56+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
5557
reader="urword",
5658
)
5759

flopy4/mf6/gwf/dis.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from attrs import Converter
23
from flopy.discretization.structuredgrid import StructuredGrid
34
from numpy.typing import NDArray
4-
from xattree import dict_to_array_converter, xattree
5+
from xattree import xattree
56

7+
from flopy4.mf6.converters import dict_to_array
68
from flopy4.mf6.package import Package
79
from flopy4.mf6.spec import array, dim, field
810

@@ -43,31 +45,31 @@ class Dis(Package):
4345
block="griddata",
4446
default=1.0,
4547
dims=("ncol",),
46-
converter=dict_to_array_converter,
48+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
4749
)
4850
delc: NDArray[np.float64] = array(
4951
block="griddata",
5052
default=1.0,
5153
dims=("nrow",),
52-
converter=dict_to_array_converter,
54+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
5355
)
5456
top: NDArray[np.float64] = array(
5557
block="griddata",
5658
default=1.0,
5759
dims=("nrow", "ncol"),
58-
converter=dict_to_array_converter,
60+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
5961
)
6062
botm: NDArray[np.float64] = array(
6163
block="griddata",
6264
default=0.0,
6365
dims=("nlay", "nrow", "ncol"),
64-
converter=dict_to_array_converter,
66+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
6567
)
6668
idomain: NDArray[np.int32] = array(
6769
block="griddata",
6870
default=1,
6971
dims=("nlay", "nrow", "ncol"),
70-
converter=dict_to_array_converter,
72+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
7173
)
7274
nnodes: int = dim(
7375
coord="node",

flopy4/mf6/gwf/drn.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22
from typing import ClassVar, Optional
33

44
import numpy as np
5+
from attrs import Converter
56
from numpy.typing import NDArray
6-
from xattree import dict_to_array_converter
7+
from xattree import xattree
78

89
from flopy4.mf6.constants import FILL_DNODATA
10+
from flopy4.mf6.converters import dict_to_array
911
from flopy4.mf6.package import Package
1012
from flopy4.mf6.spec import array, field
1113

1214

15+
@xattree
1316
class Drn(Package):
1417
multi_package: ClassVar[bool] = True
15-
1618
auxiliary: Optional[list[str]] = array(block="options", default=None)
1719
auxmultname: Optional[str] = field(block="options", default=None)
1820
auxdepthname: Optional[str] = field(block="options", default=None)
@@ -29,14 +31,14 @@ class Drn(Package):
2931
block="period",
3032
dims=("nper", "nnodes"),
3133
default=None,
32-
converter=dict_to_array_converter,
34+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
3335
reader="urword",
3436
)
3537
cond: Optional[NDArray[np.float64]] = array(
3638
block="period",
3739
dims=("nper", "nnodes"),
3840
default=None,
39-
converter=dict_to_array_converter,
41+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
4042
reader="urword",
4143
)
4244
aux: Optional[NDArray[np.float64]] = array(
@@ -46,7 +48,7 @@ class Drn(Package):
4648
"nnodes",
4749
),
4850
default=None,
49-
converter=dict_to_array_converter,
51+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
5052
reader="urword",
5153
)
5254
boundname: Optional[NDArray[np.str_]] = array(
@@ -56,7 +58,7 @@ class Drn(Package):
5658
"nnodes",
5759
),
5860
default=None,
59-
converter=dict_to_array_converter,
61+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
6062
reader="urword",
6163
)
6264

flopy4/mf6/gwf/ic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import numpy as np
2+
from attrs import Converter
23
from numpy.typing import NDArray
3-
from xattree import dict_to_array_converter, xattree
4+
from xattree import xattree
45

6+
from flopy4.mf6.converters import dict_to_array
57
from flopy4.mf6.package import Package
68
from flopy4.mf6.spec import array, field
79

@@ -12,7 +14,7 @@ class Ic(Package):
1214
block="packagedata",
1315
dims=("nnodes",),
1416
default=1.0,
15-
converter=dict_to_array_converter,
17+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
1618
)
1719
export_array_ascii: bool = field(block="options", default=False)
1820
export_array_netcdf: bool = field(block="options", default=False)

flopy4/mf6/gwf/npf.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from typing import Optional
33

44
import numpy as np
5-
from attrs import define
5+
from attrs import Converter, define
66
from numpy.typing import NDArray
7-
from xattree import dict_to_array_converter, xattree
7+
from xattree import xattree
88

9+
from flopy4.mf6.converters import dict_to_array
910
from flopy4.mf6.package import Package
1011
from flopy4.mf6.spec import array, field
1112

@@ -50,47 +51,47 @@ class Xt3dOptions:
5051
block="griddata",
5152
dims=("nnodes",),
5253
default=0,
53-
converter=dict_to_array_converter,
54+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
5455
)
5556
k: NDArray[np.float64] = array(
5657
block="griddata",
5758
dims=("nnodes",),
5859
default=1.0,
59-
converter=dict_to_array_converter,
60+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
6061
)
6162
k22: Optional[NDArray[np.float64]] = array(
6263
block="griddata",
6364
dims=("nnodes",),
6465
default=None,
65-
converter=dict_to_array_converter,
66+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
6667
)
6768
k33: Optional[NDArray[np.float64]] = array(
6869
block="griddata",
6970
dims=("nnodes",),
7071
default=None,
71-
converter=dict_to_array_converter,
72+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
7273
)
7374
angle1: Optional[NDArray[np.float64]] = array(
7475
block="griddata",
7576
dims=("nnodes",),
7677
default=None,
77-
converter=dict_to_array_converter,
78+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
7879
)
7980
angle2: Optional[NDArray[np.float64]] = array(
8081
block="griddata",
8182
dims=("nnodes",),
8283
default=None,
83-
converter=dict_to_array_converter,
84+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
8485
)
8586
angle3: Optional[NDArray[np.float64]] = array(
8687
block="griddata",
8788
dims=("nnodes",),
8889
default=None,
89-
converter=dict_to_array_converter,
90+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
9091
)
9192
wetdry: Optional[NDArray[np.float64]] = array(
9293
block="griddata",
9394
dims=("nnodes",),
9495
default=None,
95-
converter=dict_to_array_converter,
96+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
9697
)

flopy4/mf6/gwf/oc.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from typing import Literal, Optional
33

44
import numpy as np
5-
from attrs import define
5+
from attrs import Converter, define
66
from numpy.typing import NDArray
7-
from xattree import dict_to_array_converter, xattree
7+
from xattree import xattree
88

9+
from flopy4.mf6.converters import dict_to_array
910
from flopy4.mf6.package import Package
1011
from flopy4.mf6.spec import array, field
1112
from flopy4.utils import to_path
@@ -54,41 +55,30 @@ class Period:
5455
block="period",
5556
default="all",
5657
dims=("nper",),
57-
converter=dict_to_array_converter,
58+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
5859
reader="urword",
5960
)
6061
save_budget: Optional[NDArray[np.object_]] = array(
6162
Steps,
6263
block="period",
6364
default="all",
6465
dims=("nper",),
65-
converter=dict_to_array_converter,
66+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
6667
reader="urword",
6768
)
6869
print_head: Optional[NDArray[np.object_]] = array(
6970
Steps,
7071
block="period",
7172
default="all",
7273
dims=("nper",),
73-
converter=dict_to_array_converter,
74+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
7475
reader="urword",
7576
)
7677
print_budget: Optional[NDArray[np.object_]] = array(
7778
Steps,
7879
block="period",
7980
default="all",
8081
dims=("nper",),
81-
converter=dict_to_array_converter,
82+
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
8283
reader="urword",
8384
)
84-
85-
# original DFN
86-
# @classmethod
87-
# def get_dfn(cls) -> Dfn:
88-
# """Generate the component's MODFLOW 6 definition."""
89-
# dfn = super().get_dfn()
90-
# for field_name in list(dfn["perioddata"].keys()):
91-
# dfn["perioddata"].pop(field_name)
92-
# dfn["perioddata"]["saverecord"] = _oc_action_field("save")
93-
# dfn["perioddata"]["printrecord"] = _oc_action_field("print")
94-
# return dfn

0 commit comments

Comments
 (0)