Skip to content

Commit c6e65dd

Browse files
committed
update xattree and accommodate recent changes
- use builtin dict to array converter, remove custom impl (todo: support tabular data too) - get rid of star syntax for sparse dict representations of arrays. this was more trouble than it was worth to generalize - miscelleanous other cleanup and notes
1 parent 9ef014c commit c6e65dd

File tree

15 files changed

+129
-236
lines changed

15 files changed

+129
-236
lines changed

docs/dev/sdd.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ import numpy as np
8585
@define
8686
class Ic(Package):
8787
"""Initial conditions package"""
88-
strt: NDArray[np.floating] = field(...)
88+
strt: NDArray[np.float64] = field(...)
8989
export_array_ascii: bool = field(...)
9090
export_array_netcdf: bool = field(...)
9191
```

docs/examples/quickstart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
npf = Npf(parent=gwf, save_specific_discharge=True)
2121
chd = Chd(
2222
parent=gwf,
23-
head={"*": {(0, 0, 0): 1.0, (0, 9, 9): 0.0}},
23+
head={0: {(0, 0, 0): 1.0, (0, 9, 9): 0.0}},
2424
)
2525
oc = Oc(
2626
parent=gwf,
2727
budget_file=f"{gwf.name}.bud",
2828
head_file=f"{gwf.name}.hds",
29-
save_head={"*": "all"},
30-
save_budget={"*": "all"},
29+
save_head={0: "all"},
30+
save_budget={0: "all"},
3131
)
3232

3333
# sim.write()

flopy4/mf6/codec/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from flopy4.mf6 import filters
1111
from flopy4.mf6.codec.converter import (
12-
structure_array,
1312
unstructure_array,
1413
unstructure_chd,
1514
unstructure_component,
@@ -85,7 +84,6 @@ def dump(data, path: str | PathLike) -> None:
8584

8685

8786
__all__ = [
88-
"structure_array",
8987
"unstructure_array",
9088
"loads",
9189
"load",

flopy4/mf6/codec/converter.py

Lines changed: 2 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,16 @@
1-
from typing import Any, Tuple
1+
from typing import Any
22

33
import numpy as np
44
import sparse
55
import xattree
6-
from numpy.typing import NDArray
76
from xarray import DataArray
8-
from xattree import get_xatspec
97

10-
from flopy4.adapters import get_cellid, get_nn
8+
from flopy4.adapters import get_cellid
119
from flopy4.mf6.component import Component
12-
from flopy4.mf6.config import SPARSE_THRESHOLD
1310
from flopy4.mf6.constants import FILL_DNODATA
1411
from flopy4.mf6.spec import get_blocks, is_list_field
1512

1613

17-
# TODO: convert to a cattrs structuring hook so we don't have to
18-
# apply separately to all array fields?
19-
def structure_array(value, self_, field) -> NDArray:
20-
"""
21-
Convert a sparse dictionary representation of an array to a
22-
dense numpy array or a sparse COO array.
23-
"""
24-
25-
if not isinstance(value, dict):
26-
# if not a dict, assume it's a numpy array
27-
# and let xarray deal with it if it isn't
28-
return value
29-
30-
# get spec
31-
spec = get_xatspec(type(self_))
32-
field = spec.flat[field.name]
33-
if not field.dims:
34-
raise ValueError(f"Field {field} missing dims")
35-
36-
# resolve dims
37-
explicit_dims = self_.__dict__.get("dims", {})
38-
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
39-
dims = inherited_dims | explicit_dims
40-
shape = [dims.get(d, d) for d in field.dims]
41-
unresolved = [d for d in shape if isinstance(d, str)]
42-
if any(unresolved):
43-
raise ValueError(f"Couldn't resolve dims: {unresolved}")
44-
45-
if np.prod(shape) > SPARSE_THRESHOLD:
46-
a: dict[Tuple[Any, ...], Any] = dict()
47-
48-
def set_(arr, val, *ind):
49-
arr[tuple(ind)] = val
50-
51-
def final(arr):
52-
coords = np.array(list(map(list, zip(*arr.keys()))))
53-
return sparse.COO(
54-
coords,
55-
list(arr.values()),
56-
shape=shape,
57-
fill_value=field.default or FILL_DNODATA,
58-
)
59-
else:
60-
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
61-
62-
def set_(arr, val, *ind):
63-
arr[ind] = val
64-
65-
def final(arr):
66-
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
67-
return arr
68-
69-
# populate array. TODO: is there a way to do this
70-
# without hardcoding awareness of kper and cellid?
71-
if "nper" in dims:
72-
for kper, period in value.items():
73-
if kper == "*":
74-
kper = 0
75-
match len(shape):
76-
case 1:
77-
set_(a, period, kper)
78-
case _:
79-
for cellid, v in period.items():
80-
nn = get_nn(cellid, **dims)
81-
set_(a, v, kper, nn)
82-
if kper == "*":
83-
break
84-
else:
85-
for cellid, v in value.items():
86-
nn = get_nn(cellid, **dims)
87-
set_(a, v, nn)
88-
return final(a)
89-
90-
9114
def unstructure_array(value: DataArray) -> dict:
9215
"""
9316
Convert a dense numpy array or a sparse COO array to a sparse

flopy4/mf6/gwf/chd.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
from typing import ClassVar, Optional
33

44
import numpy as np
5-
from attrs import Converter, define
5+
from attrs import define
66
from numpy.typing import NDArray
7-
from xattree import xattree
7+
from xattree import dict_to_array_converter, xattree
88

9-
from flopy4.mf6.codec import structure_array
109
from flopy4.mf6.constants import FILL_DNODATA
1110
from flopy4.mf6.package import Package
1211
from flopy4.mf6.spec import array, field
@@ -34,24 +33,24 @@ class Steps:
3433
obs_filerecord: Optional[Path] = field(block="options", default=None)
3534
dev_no_newton: bool = field(default=False, metadata={"block": "options"})
3635
maxbound: Optional[int] = field(block="dimensions", default=None)
37-
head: Optional[NDArray[np.floating]] = array(
36+
head: Optional[NDArray[np.float64]] = array(
3837
block="period",
3938
dims=(
4039
"nper",
4140
"nnodes",
4241
),
4342
default=None,
44-
converter=Converter(structure_array, takes_self=True, takes_field=True),
43+
converter=dict_to_array_converter,
4544
reader="urword",
4645
)
47-
aux: Optional[NDArray[np.floating]] = array(
46+
aux: Optional[NDArray[np.float64]] = array(
4847
block="period",
4948
dims=(
5049
"nper",
5150
"nnodes",
5251
),
5352
default=None,
54-
converter=Converter(structure_array, takes_self=True, takes_field=True),
53+
converter=dict_to_array_converter,
5554
reader="urword",
5655
)
5756
boundname: Optional[NDArray[np.str_]] = array(
@@ -61,15 +60,15 @@ class Steps:
6160
"nnodes",
6261
),
6362
default=None,
64-
converter=Converter(structure_array, takes_self=True, takes_field=True),
63+
converter=dict_to_array_converter,
6564
reader="urword",
6665
)
6766
steps: Optional[NDArray[np.object_]] = array(
6867
Steps,
6968
block="period",
7069
dims=("nper", "nnodes"),
7170
default=None,
72-
converter=Converter(structure_array, takes_self=True, takes_field=True),
71+
converter=dict_to_array_converter,
7372
reader="urword",
7473
)
7574

@@ -79,8 +78,25 @@ def __attrs_post_init__(self):
7978
# in post init. but this only works when values
8079
# are set in the initializer, not when they are
8180
# set later.
82-
maxhead = len(np.where(self.head != FILL_DNODATA)) if self.head is not None else 0
83-
maxaux = len(np.where(self.aux != FILL_DNODATA)) if self.aux is not None else 0
84-
maxboundname = len(np.where(self.boundname != "")) if self.boundname is not None else 0
81+
if self.head is None:
82+
maxhead = 0
83+
else:
84+
head = self.head if self.head.data.shape == self.head.shape else self.head.todense()
85+
maxhead = len(np.where(head != FILL_DNODATA))
86+
if self.aux is None:
87+
maxaux = 0
88+
else:
89+
aux = self.aux if self.aux.data.shape == self.aux.shape else self.aux.todense()
90+
maxaux = len(np.where(aux != FILL_DNODATA))
91+
if self.boundname is None:
92+
maxboundname = 0
93+
else:
94+
boundname = (
95+
self.boundname
96+
if self.boundname.data.shape == self.boundname.shape
97+
else self.boundname.todense()
98+
)
99+
maxboundname = len(np.where(boundname != ""))
100+
85101
# maxsteps = len(np.where(self.steps != None)) if self.steps is not None else 0
86102
self.maxbound = max(maxhead, maxaux, maxboundname)

flopy4/mf6/gwf/dis.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import numpy as np
2-
from attrs import Converter
32
from flopy.discretization.structuredgrid import StructuredGrid
43
from numpy.typing import NDArray
5-
from xattree import xattree
4+
from xattree import dict_to_array_converter, xattree
65

7-
from flopy4.mf6.codec import structure_array
86
from flopy4.mf6.package import Package
97
from flopy4.mf6.spec import array, dim, field
108

@@ -25,48 +23,51 @@ class Dis(Package):
2523
coord="lay",
2624
scope="gwf",
2725
default=1,
26+
group="grid",
2827
)
2928
ncol: int = dim(
3029
block="dimensions",
3130
coord="col",
3231
scope="gwf",
3332
default=2,
33+
group="grid",
3434
)
3535
nrow: int = dim(
3636
block="dimensions",
3737
coord="row",
3838
scope="gwf",
3939
default=2,
40+
group="grid",
4041
)
41-
delr: NDArray[np.floating] = array(
42+
delr: NDArray[np.float64] = array(
4243
block="griddata",
4344
default=1.0,
4445
dims=("ncol",),
45-
converter=Converter(structure_array, takes_self=True, takes_field=True),
46+
converter=dict_to_array_converter,
4647
)
47-
delc: NDArray[np.floating] = array(
48+
delc: NDArray[np.float64] = array(
4849
block="griddata",
4950
default=1.0,
5051
dims=("nrow",),
51-
converter=Converter(structure_array, takes_self=True, takes_field=True),
52+
converter=dict_to_array_converter,
5253
)
53-
top: NDArray[np.floating] = array(
54+
top: NDArray[np.float64] = array(
5455
block="griddata",
5556
default=1.0,
5657
dims=("nrow", "ncol"),
57-
converter=Converter(structure_array, takes_self=True, takes_field=True),
58+
converter=dict_to_array_converter,
5859
)
59-
botm: NDArray[np.floating] = array(
60+
botm: NDArray[np.float64] = array(
6061
block="griddata",
6162
default=0.0,
6263
dims=("nlay", "nrow", "ncol"),
63-
converter=Converter(structure_array, takes_self=True, takes_field=True),
64+
converter=dict_to_array_converter,
6465
)
65-
idomain: NDArray[np.integer] = array(
66+
idomain: NDArray[np.int32] = array(
6667
block="griddata",
6768
default=1,
6869
dims=("nlay", "nrow", "ncol"),
69-
converter=Converter(structure_array, takes_self=True, takes_field=True),
70+
converter=dict_to_array_converter,
7071
)
7172
nnodes: int = dim(
7273
coord="node",

flopy4/mf6/gwf/ic.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
import numpy as np
2-
from attrs import Converter
32
from numpy.typing import NDArray
4-
from xattree import xattree
3+
from xattree import dict_to_array_converter, xattree
54

6-
from flopy4.mf6.codec import structure_array
75
from flopy4.mf6.package import Package
86
from flopy4.mf6.spec import array, field
97

108

119
@xattree
1210
class Ic(Package):
13-
strt: NDArray[np.floating] = array(
11+
strt: NDArray[np.float64] = array(
1412
block="packagedata",
1513
dims=("nnodes",),
1614
default=1.0,
17-
converter=Converter(structure_array, takes_self=True, takes_field=True),
15+
converter=dict_to_array_converter,
1816
)
1917
export_array_ascii: bool = field(block="options", default=False)
2018
export_array_netcdf: bool = field(block="options", default=False)

0 commit comments

Comments
 (0)