Skip to content

Commit 2a7af0d

Browse files
authored
array converter (#103)
More complete draft of an array conversion function to expand sparse dictionary representations to a full array. Also some housekeeping, move things into their own modules. Conversion function still needs testing, only the barest minimum here.
1 parent 56a0267 commit 2a7af0d

File tree

21 files changed

+552
-567
lines changed

21 files changed

+552
-567
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ jobs:
3030
- name: Run ruff
3131
run: pixi run lint
3232

33+
- name: Run mypy
34+
run: pixi run mypy flopy4
35+
3336
build:
3437
name: Build
3538
runs-on: ubuntu-latest

flopy4/mf6/__init__.py

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +0,0 @@
1-
from abc import ABC
2-
from datetime import datetime
3-
from pathlib import Path
4-
from typing import Optional
5-
6-
import numpy as np
7-
from attrs import define
8-
from numpy.typing import NDArray
9-
from xattree import ROOT, array, dim, field, xattree
10-
11-
__all__ = [
12-
"Component",
13-
"Package",
14-
"Model",
15-
"Simulation",
16-
"Solution",
17-
"Exchange",
18-
"COMPONENTS",
19-
]
20-
21-
COMPONENTS = {}
22-
"""MF6 component registry."""
23-
24-
25-
class Component(ABC):
26-
@classmethod
27-
def __attrs_init_subclass__(cls):
28-
COMPONENTS[cls.__name__.lower()] = cls
29-
30-
31-
@define
32-
class Package(Component):
33-
pass
34-
35-
36-
@define
37-
class Model(Component):
38-
pass
39-
40-
41-
@define
42-
class Solution(Package):
43-
pass
44-
45-
46-
@define
47-
class Exchange(Package):
48-
exgtype: type = field()
49-
exgfile: Path = field()
50-
exgmnamea: Optional[str] = field(default=None)
51-
exgmnameb: Optional[str] = field(default=None)
52-
53-
54-
@xattree
55-
class Tdis(Package):
56-
@define
57-
class PeriodData:
58-
perlen: float
59-
nstp: int
60-
tsmult: float
61-
62-
nper: int = dim(
63-
name="per",
64-
default=1,
65-
scope=ROOT,
66-
metadata={"block": "dimensions"},
67-
)
68-
time_units: Optional[str] = field(
69-
default=None, metadata={"block": "options"}
70-
)
71-
start_date_time: Optional[datetime] = field(
72-
default=None, metadata={"block": "options"}
73-
)
74-
# perioddata: NDArray[np.object_] = array(
75-
# PeriodData,
76-
# dims=("per",),
77-
# metadata={"block": "perioddata"},
78-
# )
79-
perlen: NDArray[np.floating] = array(
80-
default=1.0,
81-
dims=("per",),
82-
metadata={"block": "perioddata"},
83-
)
84-
nstp: NDArray[np.integer] = array(
85-
default=1,
86-
dims=("per",),
87-
metadata={"block": "perioddata"},
88-
)
89-
tsmult: NDArray[np.floating] = array(
90-
default=1.0,
91-
dims=("per",),
92-
metadata={"block": "perioddata"},
93-
)
94-
95-
96-
@xattree
97-
class Simulation(Component):
98-
models: dict[str, Model] = field()
99-
exchanges: dict[str, Exchange] = field()
100-
solutions: dict[str, Solution] = field()
101-
tdis: Tdis = field()

flopy4/mf6/component.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from abc import ABC
2+
3+
COMPONENTS = {}
4+
"""MF6 component registry."""
5+
6+
7+
class Component(ABC):
8+
@classmethod
9+
def __attrs_init_subclass__(cls):
10+
COMPONENTS[cls.__name__.lower()] = cls

flopy4/mf6/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import numpy as np
2+
3+
FILL_DEFAULT = np.nan
4+
FILL_DNODATA = 1e30

flopy4/mf6/converters.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
2+
import numpy as np
3+
from numpy.typing import NDArray
4+
from xattree import _get_xatspec
5+
6+
from flopy4.mf6.constants import FILL_DNODATA
7+
8+
9+
def convert_array(value, self_, field) -> NDArray:
10+
if not isinstance(value, dict):
11+
# if not a dict, assume it's a numpy array
12+
# and let xarray deal with it if it isn't
13+
return value
14+
15+
# get spec
16+
spec = _get_xatspec(type(self_))
17+
field = spec.arrays[field.name]
18+
if not field.dims:
19+
raise ValueError(f"Field {field} missing dims")
20+
21+
# resolve dims
22+
explicit_dims = self_.__dict__.get("dims", {})
23+
inherited_dims = self_.parent.data.dims if self_.parent else {}
24+
dims = inherited_dims | explicit_dims
25+
shape = [dims.get(d, d) for d in field.dims]
26+
unresolved = [d for d in shape if isinstance(d, str)]
27+
if any(unresolved):
28+
raise ValueError(f"Couldn't resolve dims: {unresolved}")
29+
30+
# create array
31+
a = np.full(shape, fill_value=FILL_DNODATA, dtype=field.dtype)
32+
33+
def _get_nn(cellid):
34+
match len(cellid):
35+
case 1:
36+
return cellid[0]
37+
case 2:
38+
k, j = cellid
39+
return k * dims["ncpl"] + j
40+
case 3:
41+
k, i, j = cellid
42+
return k * dims["row"] * dims["col"] + i * dims["col"] + j
43+
case _:
44+
raise ValueError(f"Invalid cellid: {cellid}")
45+
46+
# populate array. TODO: is there a way to do this
47+
# without hardcoding awareness of kper and cellid?
48+
if "per" in dims:
49+
for kper, period in value.items():
50+
if kper == "*":
51+
kper = 0
52+
match len(shape):
53+
case 1:
54+
a[kper] = value
55+
case _:
56+
for cellid, v in period.items():
57+
nn = _get_nn(cellid)
58+
a[kper, nn] = v
59+
if kper == "*":
60+
break
61+
else:
62+
for cellid, v in value.items():
63+
nn = _get_nn(cellid)
64+
a[nn] = v
65+
66+
return a

flopy4/mf6/exchange.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from pathlib import Path
2+
from typing import Optional
3+
4+
from attrs import define
5+
from xattree import field
6+
7+
from flopy4.mf6.package import Package
8+
9+
10+
@define
11+
class Exchange(Package):
12+
exgtype: type = field()
13+
exgfile: Path = field()
14+
exgmnamea: Optional[str] = field(default=None)
15+
exgmnameb: Optional[str] = field(default=None)

flopy4/mf6/gwf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from attrs import define
55
from xattree import field, xattree
66

7-
from flopy4.mf6 import Model
87
from flopy4.mf6.gwf.chd import Chd
98
from flopy4.mf6.gwf.dis import Dis
109
from flopy4.mf6.gwf.ic import Ic
1110
from flopy4.mf6.gwf.npf import Npf
1211
from flopy4.mf6.gwf.oc import Oc
12+
from flopy4.mf6.model import Model
1313

1414
__all__ = ["Gwf", "Chd", "Dis", "Ic", "Npf", "Oc"]
1515

flopy4/mf6/gwf/chd.py

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,16 @@
11
from pathlib import Path
22
from typing import Optional
33

4-
import attrs
54
import numpy as np
6-
from attrs import define
5+
from attrs import Converter, define
76
from numpy.typing import NDArray
8-
from xattree import _get_xatspec, array, field, xattree
7+
from xattree import array, field, xattree
98

10-
from flopy4.mf6 import Package
9+
from flopy4.mf6.converters import convert_array
10+
from flopy4.mf6.package import Package
1111

12-
dnodata = 1e30
1312

14-
15-
def _get_nn(ncol, nrow, k, i, j):
16-
return k * nrow * ncol + i * ncol + j
17-
18-
19-
def _convert_array(value, self_, field):
20-
if not isinstance(value, dict):
21-
return value
22-
23-
inherited_dims = self_.__dict__.get("dims", {})
24-
spec = _get_xatspec(type(self_))
25-
field = spec.arrays["head"]
26-
shape = field.dims
27-
if not shape:
28-
raise ValueError()
29-
dims = [inherited_dims.get(d, d) for d in shape]
30-
# TODO pull out dtype from annotation
31-
a = np.full(dims, fill_value=dnodata, dtype=np.float64)
32-
for kper, period in value.items():
33-
if kper == "*":
34-
kper = 0
35-
for cellid, v in period.items():
36-
nn = _get_nn(inherited_dims["col"], inherited_dims["row"], *cellid)
37-
a[kper, nn] = v
38-
return a
39-
40-
41-
@xattree(multi="list")
13+
@xattree
4214
class Chd(Package):
4315
@define(slots=False)
4416
class Steps:
@@ -75,9 +47,7 @@ class Steps:
7547
),
7648
default=None,
7749
metadata={"block": "period"},
78-
converter=attrs.Converter(
79-
_convert_array, takes_self=True, takes_field=True
80-
),
50+
converter=Converter(convert_array, takes_self=True, takes_field=True),
8151
)
8252
aux: Optional[NDArray[np.floating]] = array(
8353
dims=(
@@ -86,6 +56,7 @@ class Steps:
8656
),
8757
default=None,
8858
metadata={"block": "period"},
59+
converter=Converter(convert_array, takes_self=True, takes_field=True),
8960
)
9061
boundname: Optional[NDArray[np.str_]] = array(
9162
dims=(
@@ -94,7 +65,12 @@ class Steps:
9465
),
9566
default=None,
9667
metadata={"block": "period"},
68+
converter=Converter(convert_array, takes_self=True, takes_field=True),
9769
)
9870
steps: Optional[NDArray[np.object_]] = array(
99-
Steps, dims=("per", "node"), default=None, metadata={"block": "period"}
71+
Steps,
72+
dims=("per", "node"),
73+
default=None,
74+
metadata={"block": "period"},
75+
converter=Converter(convert_array, takes_self=True, takes_field=True),
10076
)

flopy4/mf6/gwf/dis.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from attrs import Converter
23
from numpy.typing import NDArray
34
from xattree import array, dim, field, xattree
45

5-
from flopy4.mf6 import Package
6+
from flopy4.mf6.converters import convert_array
7+
from flopy4.mf6.package import Package
68

79

810
@xattree
@@ -46,26 +48,31 @@ class Dis(Package):
4648
dims=("col",),
4749
default=1.0,
4850
metadata={"block": "griddata"},
51+
converter=Converter(convert_array, takes_self=True, takes_field=True),
4952
)
5053
delc: NDArray[np.floating] = array(
5154
dims=("row",),
5255
default=1.0,
5356
metadata={"block": "griddata"},
57+
converter=Converter(convert_array, takes_self=True, takes_field=True),
5458
)
5559
top: NDArray[np.floating] = array(
5660
dims=("col", "row"),
5761
default=1.0,
5862
metadata={"block": "griddata"},
63+
converter=Converter(convert_array, takes_self=True, takes_field=True),
5964
)
6065
botm: NDArray[np.floating] = array(
6166
dims=("col", "row", "lay"),
6267
default=0.0,
6368
metadata={"block": "griddata"},
69+
converter=Converter(convert_array, takes_self=True, takes_field=True),
6470
)
6571
idomain: NDArray[np.integer] = array(
6672
dims=("col", "row", "lay"),
6773
default=1,
6874
metadata={"block": "griddata"},
75+
converter=Converter(convert_array, takes_self=True, takes_field=True),
6976
)
7077
nnodes: int = dim(name="node", scope="gwf", init=False)
7178

flopy4/mf6/gwf/ic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
from attrs import Converter
23
from numpy.typing import NDArray
34
from xattree import array, field, xattree
45

5-
from flopy4.mf6 import Package
6+
from flopy4.mf6.converters import convert_array
7+
from flopy4.mf6.package import Package
68

79

810
@xattree
@@ -11,6 +13,7 @@ class Ic(Package):
1113
dims=("node",),
1214
default=1.0,
1315
metadata={"block": "packagedata"},
16+
converter=Converter(convert_array, takes_self=True, takes_field=True),
1417
)
1518
export_array_ascii: bool = field(
1619
default=False, metadata={"block": "options"}

0 commit comments

Comments
 (0)