Skip to content

Commit f8956cf

Browse files
committed
draft array converter
1 parent 56a0267 commit f8956cf

File tree

15 files changed

+234
-138
lines changed

15 files changed

+234
-138
lines changed

flopy4/mf6/__init__.py

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +0,0 @@
1-
from abc import ABC
2-
from datetime import datetime
3-
from pathlib import Path
4-
from typing import Optional
5-
6-
import numpy as np
7-
from attrs import define
8-
from numpy.typing import NDArray
9-
from xattree import ROOT, array, dim, field, xattree
10-
11-
__all__ = [
12-
"Component",
13-
"Package",
14-
"Model",
15-
"Simulation",
16-
"Solution",
17-
"Exchange",
18-
"COMPONENTS",
19-
]
20-
21-
COMPONENTS = {}
22-
"""MF6 component registry."""
23-
24-
25-
class Component(ABC):
26-
@classmethod
27-
def __attrs_init_subclass__(cls):
28-
COMPONENTS[cls.__name__.lower()] = cls
29-
30-
31-
@define
32-
class Package(Component):
33-
pass
34-
35-
36-
@define
37-
class Model(Component):
38-
pass
39-
40-
41-
@define
42-
class Solution(Package):
43-
pass
44-
45-
46-
@define
47-
class Exchange(Package):
48-
exgtype: type = field()
49-
exgfile: Path = field()
50-
exgmnamea: Optional[str] = field(default=None)
51-
exgmnameb: Optional[str] = field(default=None)
52-
53-
54-
@xattree
55-
class Tdis(Package):
56-
@define
57-
class PeriodData:
58-
perlen: float
59-
nstp: int
60-
tsmult: float
61-
62-
nper: int = dim(
63-
name="per",
64-
default=1,
65-
scope=ROOT,
66-
metadata={"block": "dimensions"},
67-
)
68-
time_units: Optional[str] = field(
69-
default=None, metadata={"block": "options"}
70-
)
71-
start_date_time: Optional[datetime] = field(
72-
default=None, metadata={"block": "options"}
73-
)
74-
# perioddata: NDArray[np.object_] = array(
75-
# PeriodData,
76-
# dims=("per",),
77-
# metadata={"block": "perioddata"},
78-
# )
79-
perlen: NDArray[np.floating] = array(
80-
default=1.0,
81-
dims=("per",),
82-
metadata={"block": "perioddata"},
83-
)
84-
nstp: NDArray[np.integer] = array(
85-
default=1,
86-
dims=("per",),
87-
metadata={"block": "perioddata"},
88-
)
89-
tsmult: NDArray[np.floating] = array(
90-
default=1.0,
91-
dims=("per",),
92-
metadata={"block": "perioddata"},
93-
)
94-
95-
96-
@xattree
97-
class Simulation(Component):
98-
models: dict[str, Model] = field()
99-
exchanges: dict[str, Exchange] = field()
100-
solutions: dict[str, Solution] = field()
101-
tdis: Tdis = field()

flopy4/mf6/component.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from abc import ABC
2+
3+
COMPONENTS = {}
4+
"""MF6 component registry."""
5+
6+
7+
class Component(ABC):
8+
@classmethod
9+
def __attrs_init_subclass__(cls):
10+
COMPONENTS[cls.__name__.lower()] = cls

flopy4/mf6/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import numpy as np
2+
3+
FILL_DEFAULT = np.nan
4+
FILL_DNODATA = 1e30

flopy4/mf6/converters.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import types
2+
from typing import Union, get_args, get_origin
3+
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
from xattree import _get_xatspec
7+
8+
from flopy4.mf6.constants import FILL_DNODATA
9+
10+
11+
def convert_array(value, self_, field):
12+
if not isinstance(value, dict):
13+
# if not a dict, assume it's a numpy array
14+
# and let xarray deal with it if it isn't
15+
return value
16+
17+
# get spec
18+
spec = _get_xatspec(type(self_))
19+
field = spec.arrays[field.name]
20+
if not field.dims:
21+
raise ValueError(f"Field {field} missing dims")
22+
23+
# resolve dims
24+
explicit_dims = self_.__dict__.get("dims", {})
25+
inherited_dims = self_.parent.data.dims if self_.parent else {}
26+
dims = inherited_dims | explicit_dims
27+
shape = [dims.get(d, d) for d in field.dims]
28+
unresolved = [d for d in shape if isinstance(d, str)]
29+
if any(unresolved):
30+
raise ValueError(f"Couldn't resolve dims: {unresolved}")
31+
32+
# extract dtype
33+
origin = get_origin(field.type)
34+
args = get_args(field.type)
35+
if origin in (Union, types.UnionType) and args[-1] is types.NoneType:
36+
origin = args[0] # Optional
37+
if origin is NDArray:
38+
dtype = args[1]
39+
elif origin is np.ndarray:
40+
dtype = args[0]
41+
else:
42+
raise ValueError(f"Expected NDArray, got {origin}")
43+
44+
# create array
45+
a = np.full(shape, fill_value=FILL_DNODATA, dtype=dtype)
46+
47+
def _get_nn(cellid):
48+
match len(cellid):
49+
case 1:
50+
return cellid[0]
51+
case 2:
52+
k, j = cellid
53+
return k * dims["ncpl"] + j
54+
case 3:
55+
k, i, j = cellid
56+
return k * dims["row"] * dims["col"] + i * dims["col"] + j
57+
case _:
58+
raise ValueError(f"Invalid cellid: {cellid}")
59+
60+
# populate array. TODO: is there a way to do this
61+
# without hardcoding awareness of kper and cellid?
62+
if "kper" in dims:
63+
for kper, period in value.items():
64+
if kper == "*":
65+
kper = 0
66+
match len(shape):
67+
case 1:
68+
a[kper] = v
69+
case _:
70+
for cellid, v in period.items():
71+
nn = _get_nn(cellid)
72+
a[kper, nn] = v
73+
if kper == "*":
74+
break
75+
else:
76+
for cellid, v in value.items():
77+
nn = _get_nn(cellid)
78+
a[kper, nn] = v
79+
80+
return a

flopy4/mf6/exchange.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from pathlib import Path
2+
from typing import Optional
3+
4+
from attrs import define
5+
from xattree import field
6+
7+
from flopy4.mf6.package import Package
8+
9+
10+
@define
11+
class Exchange(Package):
12+
exgtype: type = field()
13+
exgfile: Path = field()
14+
exgmnamea: Optional[str] = field(default=None)
15+
exgmnameb: Optional[str] = field(default=None)

flopy4/mf6/gwf/chd.py

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,16 @@
11
from pathlib import Path
22
from typing import Optional
33

4-
import attrs
54
import numpy as np
6-
from attrs import define
5+
from attrs import Converter, define
76
from numpy.typing import NDArray
8-
from xattree import _get_xatspec, array, field, xattree
7+
from xattree import array, field, xattree
98

109
from flopy4.mf6 import Package
11-
12-
dnodata = 1e30
13-
14-
15-
def _get_nn(ncol, nrow, k, i, j):
16-
return k * nrow * ncol + i * ncol + j
17-
18-
19-
def _convert_array(value, self_, field):
20-
if not isinstance(value, dict):
21-
return value
22-
23-
inherited_dims = self_.__dict__.get("dims", {})
24-
spec = _get_xatspec(type(self_))
25-
field = spec.arrays["head"]
26-
shape = field.dims
27-
if not shape:
28-
raise ValueError()
29-
dims = [inherited_dims.get(d, d) for d in shape]
30-
# TODO pull out dtype from annotation
31-
a = np.full(dims, fill_value=dnodata, dtype=np.float64)
32-
for kper, period in value.items():
33-
if kper == "*":
34-
kper = 0
35-
for cellid, v in period.items():
36-
nn = _get_nn(inherited_dims["col"], inherited_dims["row"], *cellid)
37-
a[kper, nn] = v
38-
return a
10+
from flopy4.mf6.converters import convert_array
3911

4012

13+
# TODO get rid of multi, just infer from parent?
4114
@xattree(multi="list")
4215
class Chd(Package):
4316
@define(slots=False)
@@ -75,9 +48,7 @@ class Steps:
7548
),
7649
default=None,
7750
metadata={"block": "period"},
78-
converter=attrs.Converter(
79-
_convert_array, takes_self=True, takes_field=True
80-
),
51+
converter=Converter(convert_array, takes_self=True, takes_field=True),
8152
)
8253
aux: Optional[NDArray[np.floating]] = array(
8354
dims=(
@@ -86,6 +57,7 @@ class Steps:
8657
),
8758
default=None,
8859
metadata={"block": "period"},
60+
converter=Converter(convert_array, takes_self=True, takes_field=True),
8961
)
9062
boundname: Optional[NDArray[np.str_]] = array(
9163
dims=(
@@ -94,7 +66,12 @@ class Steps:
9466
),
9567
default=None,
9668
metadata={"block": "period"},
69+
converter=Converter(convert_array, takes_self=True, takes_field=True),
9770
)
9871
steps: Optional[NDArray[np.object_]] = array(
99-
Steps, dims=("per", "node"), default=None, metadata={"block": "period"}
72+
Steps,
73+
dims=("per", "node"),
74+
default=None,
75+
metadata={"block": "period"},
76+
converter=Converter(convert_array, takes_self=True, takes_field=True),
10077
)

flopy4/mf6/gwf/dis.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from attrs import Converter
23
from numpy.typing import NDArray
34
from xattree import array, dim, field, xattree
45

56
from flopy4.mf6 import Package
7+
from flopy4.mf6.converters import convert_array
68

79

810
@xattree
@@ -46,26 +48,31 @@ class Dis(Package):
4648
dims=("col",),
4749
default=1.0,
4850
metadata={"block": "griddata"},
51+
converter=Converter(convert_array, takes_self=True, takes_field=True),
4952
)
5053
delc: NDArray[np.floating] = array(
5154
dims=("row",),
5255
default=1.0,
5356
metadata={"block": "griddata"},
57+
converter=Converter(convert_array, takes_self=True, takes_field=True),
5458
)
5559
top: NDArray[np.floating] = array(
5660
dims=("col", "row"),
5761
default=1.0,
5862
metadata={"block": "griddata"},
63+
converter=Converter(convert_array, takes_self=True, takes_field=True),
5964
)
6065
botm: NDArray[np.floating] = array(
6166
dims=("col", "row", "lay"),
6267
default=0.0,
6368
metadata={"block": "griddata"},
69+
converter=Converter(convert_array, takes_self=True, takes_field=True),
6470
)
6571
idomain: NDArray[np.integer] = array(
6672
dims=("col", "row", "lay"),
6773
default=1,
6874
metadata={"block": "griddata"},
75+
converter=Converter(convert_array, takes_self=True, takes_field=True),
6976
)
7077
nnodes: int = dim(name="node", scope="gwf", init=False)
7178

flopy4/mf6/gwf/ic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from attrs import Converter
23
from numpy.typing import NDArray
34
from xattree import array, field, xattree
45

56
from flopy4.mf6 import Package
7+
from flopy4.mf6.converters import convert_array
68

79

810
@xattree
@@ -11,6 +13,7 @@ class Ic(Package):
1113
dims=("node",),
1214
default=1.0,
1315
metadata={"block": "packagedata"},
16+
converter=Converter(convert_array, takes_self=True, takes_field=True),
1417
)
1518
export_array_ascii: bool = field(
1619
default=False, metadata={"block": "options"}

0 commit comments

Comments
 (0)