Skip to content

Commit 62aa491

Browse files
authored
Xattree update (#164)
- use new xattree builtin dict to array converter, remove custom impl (todo: support tabular data too... xattree should provide a single converter that supports either dict or table input) - get rid of star syntax for sparse dict representations of arrays. this was more trouble than it was worth to generalize, and the caller can always use a comprehension to easily generate a still-pretty-sparse dict - accommodate get_xatspec refactor (it returns XatSpec now, not the flattened chainmap) - miscelleanous other cleanup and notes - todo: xattree needs to distinguish array fill value from default. can't always interpret default unambiguously for optional arrays. e.g., when an optional array should be None (i.e. not present) by default, but should have a fill value when populated from a sparse dict or table representation
1 parent 6588b09 commit 62aa491

File tree

18 files changed

+768
-908
lines changed

18 files changed

+768
-908
lines changed

docs/dev/sdd.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ import numpy as np
8585
@define
8686
class Ic(Package):
8787
"""Initial conditions package"""
88-
strt: NDArray[np.floating] = field(...)
88+
strt: NDArray[np.float64] = field(...)
8989
export_array_ascii: bool = field(...)
9090
export_array_netcdf: bool = field(...)
9191
```

docs/examples/quickstart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
npf = Npf(parent=gwf, save_specific_discharge=True)
2121
chd = Chd(
2222
parent=gwf,
23-
head={"*": {(0, 0, 0): 1.0, (0, 9, 9): 0.0}},
23+
head={0: {(0, 0, 0): 1.0, (0, 9, 9): 0.0}},
2424
)
2525
oc = Oc(
2626
parent=gwf,
2727
budget_file=f"{gwf.name}.bud",
2828
head_file=f"{gwf.name}.hds",
29-
save_head={"*": "all"},
30-
save_budget={"*": "all"},
29+
save_head={0: "all"},
30+
save_budget={0: "all"},
3131
)
3232

3333
# sim.write()

flopy4/mf6/adapters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def __init__(
185185
self._data = package.data
186186
else:
187187
raise Exception("Input package has no data")
188-
self._spec = get_xatspec(type(package))
188+
self._spec = get_xatspec(type(package)).flat
189189
if modelgrid:
190190
self._grid = modelgrid
191191
elif model:

flopy4/mf6/codec/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from flopy4.mf6 import filters
1111
from flopy4.mf6.codec.converter import (
12-
structure_array,
1312
unstructure_array,
1413
unstructure_chd,
1514
unstructure_component,
@@ -85,7 +84,6 @@ def dump(data, path: str | PathLike) -> None:
8584

8685

8786
__all__ = [
88-
"structure_array",
8987
"unstructure_array",
9088
"loads",
9189
"load",

flopy4/mf6/codec/converter.py

Lines changed: 2 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,16 @@
1-
from typing import Any, Tuple
1+
from typing import Any
22

33
import numpy as np
44
import sparse
55
import xattree
6-
from numpy.typing import NDArray
76
from xarray import DataArray
8-
from xattree import get_xatspec
97

10-
from flopy4.adapters import get_cellid, get_nn
8+
from flopy4.adapters import get_cellid
119
from flopy4.mf6.component import Component
12-
from flopy4.mf6.config import SPARSE_THRESHOLD
1310
from flopy4.mf6.constants import FILL_DNODATA
1411
from flopy4.mf6.spec import get_blocks, is_list_field
1512

1613

17-
# TODO: convert to a cattrs structuring hook so we don't have to
18-
# apply separately to all array fields?
19-
def structure_array(value, self_, field) -> NDArray:
20-
"""
21-
Convert a sparse dictionary representation of an array to a
22-
dense numpy array or a sparse COO array.
23-
"""
24-
25-
if not isinstance(value, dict):
26-
# if not a dict, assume it's a numpy array
27-
# and let xarray deal with it if it isn't
28-
return value
29-
30-
# get spec
31-
spec = get_xatspec(type(self_))
32-
field = spec[field.name]
33-
if not field.dims:
34-
raise ValueError(f"Field {field} missing dims")
35-
36-
# resolve dims
37-
explicit_dims = self_.__dict__.get("dims", {})
38-
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
39-
dims = inherited_dims | explicit_dims
40-
shape = [dims.get(d, d) for d in field.dims]
41-
unresolved = [d for d in shape if isinstance(d, str)]
42-
if any(unresolved):
43-
raise ValueError(f"Couldn't resolve dims: {unresolved}")
44-
45-
if np.prod(shape) > SPARSE_THRESHOLD:
46-
a: dict[Tuple[Any, ...], Any] = dict()
47-
48-
def set_(arr, val, *ind):
49-
arr[tuple(ind)] = val
50-
51-
def final(arr):
52-
coords = np.array(list(map(list, zip(*arr.keys()))))
53-
return sparse.COO(
54-
coords,
55-
list(arr.values()),
56-
shape=shape,
57-
fill_value=field.default or FILL_DNODATA,
58-
)
59-
else:
60-
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
61-
62-
def set_(arr, val, *ind):
63-
arr[ind] = val
64-
65-
def final(arr):
66-
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
67-
return arr
68-
69-
# populate array. TODO: is there a way to do this
70-
# without hardcoding awareness of kper and cellid?
71-
if "nper" in dims:
72-
for kper, period in value.items():
73-
if kper == "*":
74-
kper = 0
75-
match len(shape):
76-
case 1:
77-
set_(a, period, kper)
78-
case _:
79-
for cellid, v in period.items():
80-
nn = get_nn(cellid, **dims)
81-
set_(a, v, kper, nn)
82-
if kper == "*":
83-
break
84-
else:
85-
for cellid, v in value.items():
86-
nn = get_nn(cellid, **dims)
87-
set_(a, v, nn)
88-
return final(a)
89-
90-
9114
def unstructure_array(value: DataArray) -> dict:
9215
"""
9316
Convert a dense numpy array or a sparse COO array to a sparse

flopy4/mf6/gwf/chd.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
from typing import ClassVar, Optional
33

44
import numpy as np
5-
from attrs import Converter, define
5+
from attrs import define
66
from numpy.typing import NDArray
7-
from xattree import xattree
7+
from xattree import dict_to_array_converter, xattree
88

9-
from flopy4.mf6.codec import structure_array
109
from flopy4.mf6.constants import FILL_DNODATA
1110
from flopy4.mf6.package import Package
1211
from flopy4.mf6.spec import array, field
@@ -34,24 +33,24 @@ class Steps:
3433
obs_filerecord: Optional[Path] = field(block="options", default=None)
3534
dev_no_newton: bool = field(default=False, metadata={"block": "options"})
3635
maxbound: Optional[int] = field(block="dimensions", default=None)
37-
head: Optional[NDArray[np.floating]] = array(
36+
head: Optional[NDArray[np.float64]] = array(
3837
block="period",
3938
dims=(
4039
"nper",
4140
"nnodes",
4241
),
4342
default=None,
44-
converter=Converter(structure_array, takes_self=True, takes_field=True),
43+
converter=dict_to_array_converter,
4544
reader="urword",
4645
)
47-
aux: Optional[NDArray[np.floating]] = array(
46+
aux: Optional[NDArray[np.float64]] = array(
4847
block="period",
4948
dims=(
5049
"nper",
5150
"nnodes",
5251
),
5352
default=None,
54-
converter=Converter(structure_array, takes_self=True, takes_field=True),
53+
converter=dict_to_array_converter,
5554
reader="urword",
5655
)
5756
boundname: Optional[NDArray[np.str_]] = array(
@@ -61,15 +60,15 @@ class Steps:
6160
"nnodes",
6261
),
6362
default=None,
64-
converter=Converter(structure_array, takes_self=True, takes_field=True),
63+
converter=dict_to_array_converter,
6564
reader="urword",
6665
)
6766
steps: Optional[NDArray[np.object_]] = array(
6867
Steps,
6968
block="period",
7069
dims=("nper", "nnodes"),
7170
default=None,
72-
converter=Converter(structure_array, takes_self=True, takes_field=True),
71+
converter=dict_to_array_converter,
7372
reader="urword",
7473
)
7574

@@ -79,8 +78,25 @@ def __attrs_post_init__(self):
7978
# in post init. but this only works when values
8079
# are set in the initializer, not when they are
8180
# set later.
82-
maxhead = len(np.where(self.head != FILL_DNODATA)) if self.head is not None else 0
83-
maxaux = len(np.where(self.aux != FILL_DNODATA)) if self.aux is not None else 0
84-
maxboundname = len(np.where(self.boundname != "")) if self.boundname is not None else 0
81+
if self.head is None:
82+
maxhead = 0
83+
else:
84+
head = self.head if self.head.data.shape == self.head.shape else self.head.todense()
85+
maxhead = len(np.where(head != FILL_DNODATA))
86+
if self.aux is None:
87+
maxaux = 0
88+
else:
89+
aux = self.aux if self.aux.data.shape == self.aux.shape else self.aux.todense()
90+
maxaux = len(np.where(aux != FILL_DNODATA))
91+
if self.boundname is None:
92+
maxboundname = 0
93+
else:
94+
boundname = (
95+
self.boundname
96+
if self.boundname.data.shape == self.boundname.shape
97+
else self.boundname.todense()
98+
)
99+
maxboundname = len(np.where(boundname != ""))
100+
85101
# maxsteps = len(np.where(self.steps != None)) if self.steps is not None else 0
86102
self.maxbound = max(maxhead, maxaux, maxboundname)

flopy4/mf6/gwf/dis.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import numpy as np
2-
from attrs import Converter
32
from flopy.discretization.structuredgrid import StructuredGrid
43
from numpy.typing import NDArray
5-
from xattree import xattree
4+
from xattree import dict_to_array_converter, xattree
65

7-
from flopy4.mf6.codec import structure_array
86
from flopy4.mf6.package import Package
97
from flopy4.mf6.spec import array, dim, field
108

@@ -25,48 +23,51 @@ class Dis(Package):
2523
coord="lay",
2624
scope="gwf",
2725
default=1,
26+
group="grid",
2827
)
2928
ncol: int = dim(
3029
block="dimensions",
3130
coord="col",
3231
scope="gwf",
3332
default=2,
33+
group="grid",
3434
)
3535
nrow: int = dim(
3636
block="dimensions",
3737
coord="row",
3838
scope="gwf",
3939
default=2,
40+
group="grid",
4041
)
41-
delr: NDArray[np.floating] = array(
42+
delr: NDArray[np.float64] = array(
4243
block="griddata",
4344
default=1.0,
4445
dims=("ncol",),
45-
converter=Converter(structure_array, takes_self=True, takes_field=True),
46+
converter=dict_to_array_converter,
4647
)
47-
delc: NDArray[np.floating] = array(
48+
delc: NDArray[np.float64] = array(
4849
block="griddata",
4950
default=1.0,
5051
dims=("nrow",),
51-
converter=Converter(structure_array, takes_self=True, takes_field=True),
52+
converter=dict_to_array_converter,
5253
)
53-
top: NDArray[np.floating] = array(
54+
top: NDArray[np.float64] = array(
5455
block="griddata",
5556
default=1.0,
5657
dims=("nrow", "ncol"),
57-
converter=Converter(structure_array, takes_self=True, takes_field=True),
58+
converter=dict_to_array_converter,
5859
)
59-
botm: NDArray[np.floating] = array(
60+
botm: NDArray[np.float64] = array(
6061
block="griddata",
6162
default=0.0,
6263
dims=("nlay", "nrow", "ncol"),
63-
converter=Converter(structure_array, takes_self=True, takes_field=True),
64+
converter=dict_to_array_converter,
6465
)
65-
idomain: NDArray[np.integer] = array(
66+
idomain: NDArray[np.int32] = array(
6667
block="griddata",
6768
default=1,
6869
dims=("nlay", "nrow", "ncol"),
69-
converter=Converter(structure_array, takes_self=True, takes_field=True),
70+
converter=dict_to_array_converter,
7071
)
7172
nnodes: int = dim(
7273
coord="node",

flopy4/mf6/gwf/ic.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
import numpy as np
2-
from attrs import Converter
32
from numpy.typing import NDArray
4-
from xattree import xattree
3+
from xattree import dict_to_array_converter, xattree
54

6-
from flopy4.mf6.codec import structure_array
75
from flopy4.mf6.package import Package
86
from flopy4.mf6.spec import array, field
97

108

119
@xattree
1210
class Ic(Package):
13-
strt: NDArray[np.floating] = array(
11+
strt: NDArray[np.float64] = array(
1412
block="packagedata",
1513
dims=("nnodes",),
1614
default=1.0,
17-
converter=Converter(structure_array, takes_self=True, takes_field=True),
15+
converter=dict_to_array_converter,
1816
)
1917
export_array_ascii: bool = field(block="options", default=False)
2018
export_array_netcdf: bool = field(block="options", default=False)

0 commit comments

Comments
 (0)