Skip to content

Commit 295befc

Browse files
committed
very basic decorator implementation
1 parent dbb5ff3 commit 295befc

File tree

1 file changed

+192
-87
lines changed

1 file changed

+192
-87
lines changed

docs/examples/attrs_demo.py

Lines changed: 192 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,175 @@
22

33
# This example demonstrates a tentative `attrs`-based object model.
44

5+
from os import PathLike
56
from pathlib import Path
6-
from typing import List, Literal, Optional
7+
from typing import Any, Literal, Optional, get_origin
8+
from warnings import warn
79

10+
import attrs
811
import numpy as np
9-
from attr import asdict, define, field
12+
from attr import define, field, fields_dict
1013
from cattr import Converter
11-
from flopy.discretization import StructuredGrid
12-
from numpy.typing import NDArray
14+
from numpy.typing import ArrayLike, NDArray
1315
from xarray import Dataset, DataTree
1416

1517

16-
@define
18+
def _parse_dim_names(s: str) -> tuple[str]:
19+
return tuple(
20+
[
21+
ss.strip()
22+
for ss in s.strip().replace("(", "").replace(")", "").split(",")
23+
if any(ss)
24+
]
25+
)
26+
27+
28+
def _try_resolve_dim(self, name) -> int | str:
29+
name = name.strip()
30+
value = getattr(self, name, None)
31+
if value:
32+
return value
33+
if hasattr(self, "model") and hasattr(self.model, "dis"):
34+
return getattr(self.model.dis, name, name)
35+
return name
36+
37+
38+
def _to_array(value: ArrayLike) -> Optional[NDArray]:
39+
return None if value is None else np.array(value)
40+
41+
42+
def _to_shaped_array(
43+
value: ArrayLike | str | PathLike, self_, field
44+
) -> Optional[NDArray]:
45+
if isinstance(value, (str, PathLike)):
46+
# TODO
47+
pass
48+
49+
value = _to_array(value)
50+
if value is None:
51+
return None
52+
dim_names = _parse_dim_names(field.metadata["shape"])
53+
shape = tuple([_try_resolve_dim(self_, n) for n in dim_names])
54+
unresolved = [d for d in shape if not isinstance(d, int)]
55+
if any(unresolved):
56+
warn(f"Failed to resolve dimension names: {', '.join(unresolved)}")
57+
return value
58+
elif value.shape == ():
59+
return np.ones(shape) ** value.item()
60+
elif value.shape != shape:
61+
raise ValueError(
62+
f"Shape mismatch, got {value.shape}, expected {shape}"
63+
)
64+
return value
65+
66+
67+
def _to_path(value) -> Optional[Path]:
68+
return Path(value) if value else None
69+
70+
71+
def datatree(cls):
72+
# TODO
73+
# - determine whether data array, data set, or data tree DONE
74+
# - shape check arrays (dynamic validator?)
75+
# check for parent and update dimensions
76+
# then try to realign existing packages?
77+
78+
old_post_init = getattr(cls, "__attrs_post_init__", None)
79+
80+
def __attrs_post_init__(self):
81+
print(f"Running datatree on {cls.__name__}")
82+
83+
if old_post_init:
84+
old_post_init(self)
85+
86+
fields = fields_dict(cls)
87+
arrays = {}
88+
for n, f in fields.items():
89+
if get_origin(f.type) is not np.ndarray:
90+
continue
91+
value = getattr(self, n)
92+
if value is None:
93+
continue
94+
arrays[n] = (_parse_dim_names(f.metadata["shape"]), value)
95+
dataset = Dataset(arrays)
96+
children = getattr(self, "children", None)
97+
if children:
98+
self.data = DataTree(
99+
dataset, name=cls.__name__, children=[c.data for c in children]
100+
)
101+
else:
102+
self.data = dataset
103+
104+
cls.__attrs_post_init__ = __attrs_post_init__
105+
106+
return cls
107+
108+
109+
@datatree
110+
@define(slots=False)
111+
class GwfDis:
112+
nlay: int = field(default=1, metadata={"block": "dimensions"})
113+
ncol: int = field(default=2, metadata={"block": "dimensions"})
114+
nrow: int = field(default=2, metadata={"block": "dimensions"})
115+
delr: NDArray[np.floating] = field(
116+
converter=attrs.Converter(
117+
_to_shaped_array, takes_self=True, takes_field=True
118+
),
119+
default=1.0,
120+
metadata={"block": "griddata", "shape": "(ncol,)"},
121+
)
122+
delc: NDArray[np.floating] = field(
123+
converter=attrs.Converter(
124+
_to_shaped_array, takes_self=True, takes_field=True
125+
),
126+
default=1.0,
127+
metadata={"block": "griddata", "shape": "(nrow,)"},
128+
)
129+
top: NDArray[np.floating] = field(
130+
converter=attrs.Converter(
131+
_to_shaped_array, takes_self=True, takes_field=True
132+
),
133+
default=1.0,
134+
metadata={"block": "griddata", "shape": "(ncol, nrow)"},
135+
)
136+
botm: NDArray[np.floating] = field(
137+
converter=attrs.Converter(
138+
_to_shaped_array, takes_self=True, takes_field=True
139+
),
140+
default=0.0,
141+
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"},
142+
)
143+
idomain: Optional[NDArray[np.integer]] = field(
144+
converter=attrs.Converter(
145+
_to_shaped_array, takes_self=True, takes_field=True
146+
),
147+
default=1,
148+
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"},
149+
)
150+
length_units: str = field(default=None, metadata={"block": "options"})
151+
nogrb: bool = field(default=False, metadata={"block": "options"})
152+
xorigin: float = field(default=None, metadata={"block": "options"})
153+
yorigin: float = field(default=None, metadata={"block": "options"})
154+
angrot: float = field(default=None, metadata={"block": "options"})
155+
export_array_netcdf: bool = field(
156+
default=False, metadata={"block": "options"}
157+
)
158+
nodes: int = field(init=False)
159+
data: Dataset = field(init=False)
160+
model: Optional[Any] = field(default=None)
161+
162+
def __attrs_post_init__(self):
163+
self.nodes = self.nlay * self.ncol * self.nrow
164+
165+
166+
@datatree
167+
@define(slots=False)
17168
class GwfIc:
18-
strt: NDArray[np.float64] = field(
19-
metadata={"block": "packagedata", "shape": "(nodes)"}
169+
strt: NDArray[np.floating] = field(
170+
converter=attrs.Converter(
171+
_to_shaped_array, takes_self=True, takes_field=True
172+
),
173+
metadata={"block": "packagedata", "shape": "(nodes)"},
20174
)
21175
export_array_ascii: bool = field(
22176
default=False, metadata={"block": "options"}
@@ -25,13 +179,12 @@ class GwfIc:
25179
default=False,
26180
metadata={"block": "options"},
27181
)
182+
data: Dataset = field(init=False)
183+
model: Optional[Any] = field(default=None)
28184

29-
def __attrs_post_init__(self):
30-
# TODO: setup attributes for blocks?
31-
self.data = DataTree(Dataset({"strt": self.strt}), name="ic")
32185

33-
34-
@define
186+
@datatree
187+
@define(slots=False)
35188
class GwfOc:
36189
@define
37190
class Format:
@@ -40,96 +193,42 @@ class Format:
40193
digits: int
41194
format: Literal["exponential", "fixed", "general", "scientific"]
42195

43-
periods: List[List[tuple]] = field(metadata={"block": "perioddata"})
44196
budget_file: Optional[Path] = field(
45-
default=None, metadata={"block": "options"}
197+
converter=_to_path, default=None, metadata={"block": "options"}
46198
)
47199
budget_csv_file: Optional[Path] = field(
48-
default=None, metadata={"block": "options"}
200+
converter=_to_path, default=None, metadata={"block": "options"}
49201
)
50202
head_file: Optional[Path] = field(
51-
default=None, metadata={"block": "options"}
203+
converter=_to_path, default=None, metadata={"block": "options"}
52204
)
53205
printhead: Optional[Format] = field(
54206
default=None, metadata={"block": "options"}
55207
)
56-
57-
58-
@define
59-
class GwfDis:
60-
nlay: int = field(metadata={"block": "dimensions"})
61-
ncol: int = field(metadata={"block": "dimensions"})
62-
nrow: int = field(metadata={"block": "dimensions"})
63-
delr: NDArray[np.float64] = field(
64-
metadata={"block": "griddata", "shape": "(ncol,)"}
65-
)
66-
delc: NDArray[np.float64] = field(
67-
metadata={"block": "griddata", "shape": "(nrow,)"}
68-
)
69-
top: NDArray[np.float64] = field(
70-
metadata={"block": "griddata", "shape": "(ncol, nrow)"}
71-
)
72-
botm: NDArray[np.float64] = field(
73-
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"}
74-
)
75-
idomain: NDArray[np.float64] = field(
76-
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"}
77-
)
78-
length_units: str = field(default=None, metadata={"block": "options"})
79-
nogrb: bool = field(default=False, metadata={"block": "options"})
80-
xorigin: float = field(default=None, metadata={"block": "options"})
81-
yorigin: float = field(default=None, metadata={"block": "options"})
82-
angrot: float = field(default=None, metadata={"block": "options"})
83-
export_array_netcdf: bool = field(
84-
default=False, metadata={"block": "options"}
208+
perioddata: Optional[list[list[tuple]]] = field(
209+
default=None, metadata={"block": "perioddata"}
85210
)
86-
87-
def __attrs_post_init__(self):
88-
self.data = DataTree(
89-
Dataset(
90-
{
91-
"nlay": self.nlay,
92-
"ncol": self.ncol,
93-
"nrow": self.nrow,
94-
"delr": self.delr,
95-
"delc": self.delc,
96-
"top": self.top,
97-
"botm": self.botm,
98-
"idomain": self.idomain,
99-
}
100-
),
101-
name="dis",
102-
)
103-
# TODO: check for parent and update dimensions
104-
# then try to realign any existing packages?
211+
data: Dataset = field(init=False)
212+
model: Optional[Any] = field(default=None)
105213

106214

107-
@define
215+
@datatree
216+
@define(slots=False)
108217
class Gwf:
109-
dis: GwfDis = field()
110-
ic: GwfIc = field()
111-
112-
def __attrs_post_init__(self):
113-
self.data = DataTree.from_dict(
114-
{"/dis": self.dis, "/ic": self.ic}, name="gwf"
115-
)
116-
self.grid = StructuredGrid(**asdict(self.dis))
117-
118-
@ic.validator
119-
def _check_dims(self, attribute, value):
120-
assert value.strt.shape == (
121-
self.dis.nlay * self.dis.nrow * self.dis.ncol
122-
)
218+
dis: Optional[GwfDis] = field(default=None)
219+
ic: Optional[GwfIc] = field(default=None)
220+
oc: Optional[GwfOc] = field(default=None)
221+
data: DataTree = field(init=False)
123222

124223

125224
# We can define a package with some data.
126225

127226

128227
oc = GwfOc(
129228
budget_file="some/file/path.cbc",
130-
periods=[[("print", "budget", "steps", 1, 3, 5)]],
229+
perioddata=[[("print", "budget", "steps", 1, 3, 5)]],
131230
)
132-
assert isinstance(oc.budget_file, str) # TODO path
231+
assert isinstance(oc.budget_file, Path)
133232

134233

135234
# We now set up a `cattrs` converter to convert an unstructured
@@ -142,7 +241,7 @@ def _check_dims(self, attribute, value):
142241
# as would be returned by a separate IO layer in the future.
143242
# (Either hand-written or using e.g. lark.)
144243

145-
gwfoc = converter.structure(
244+
oc = converter.structure(
146245
{
147246
"budget_file": "some/file/path.cbc",
148247
"head_file": "some/file/path.hds",
@@ -152,7 +251,7 @@ def _check_dims(self, attribute, value):
152251
"digits": 8,
153252
"format": "scientific",
154253
},
155-
"periods": [
254+
"perioddata": [
156255
[
157256
("print", "budget", "steps", 1, 3, 5),
158257
("save", "head", "frequency", 2),
@@ -161,9 +260,15 @@ def _check_dims(self, attribute, value):
161260
},
162261
GwfOc,
163262
)
164-
assert gwfoc.budget_file == Path("some/file/path.cbc")
165-
assert gwfoc.printhead.width == 10
166-
assert gwfoc.printhead.format == "scientific"
167-
period = gwfoc.periods[0]
263+
assert oc.budget_file == Path("some/file/path.cbc")
264+
assert oc.printhead.width == 10
265+
assert oc.printhead.format == "scientific"
266+
period = oc.perioddata[0]
168267
assert len(period) == 2
169268
assert period[0] == ("print", "budget", "steps", 1, 3, 5)
269+
270+
271+
# Creating a model by constructor.
272+
273+
274+
gwf = Gwf(dis=GwfDis(), ic=GwfIc(strt=1), oc=oc)

0 commit comments

Comments
 (0)