Skip to content

Commit 66e79a9

Browse files
committed
minimal working example
1 parent 9719f97 commit 66e79a9

File tree

1 file changed

+75
-57
lines changed

1 file changed

+75
-57
lines changed

docs/examples/attrs_demo.py

Lines changed: 75 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
from os import PathLike
66
from pathlib import Path
7-
from typing import Any, Literal, Optional, get_origin
7+
from typing import Literal, Optional, get_origin
88
from warnings import warn
99

10-
import attrs
1110
import numpy as np
1211
from attr import define, field, fields_dict
1312
from cattr import Converter
@@ -35,23 +34,28 @@ def _try_resolve_dim(self, name) -> int | str:
3534
return name
3635

3736

38-
def _to_array(value: ArrayLike) -> Optional[NDArray]:
37+
def _try_resolve_shape(self, field) -> tuple[int | str]:
38+
dim_names = _parse_dim_names(field.metadata["shape"])
39+
return tuple([_try_resolve_dim(self, n) for n in dim_names])
40+
41+
42+
def _to_array(value: Optional[ArrayLike]) -> Optional[NDArray]:
3943
return None if value is None else np.array(value)
4044

4145

4246
def _to_shaped_array(
43-
value: ArrayLike | str | PathLike, self_, field
47+
value: Optional[ArrayLike | str | PathLike], self_, field
4448
) -> Optional[NDArray]:
4549
if isinstance(value, (str, PathLike)):
46-
# TODO
50+
# TODO handle external arrays
4751
pass
4852

4953
value = _to_array(value)
5054
if value is None:
5155
return None
52-
dim_names = _parse_dim_names(field.metadata["shape"])
53-
shape = tuple([_try_resolve_dim(self_, n) for n in dim_names])
54-
unresolved = [d for d in shape if not isinstance(d, int)]
56+
57+
shape = _try_resolve_shape(self_, field)
58+
unresolved = [dim for dim in shape if not isinstance(dim, int)]
5559
if any(unresolved):
5660
warn(f"Failed to resolve dimension names: {', '.join(unresolved)}")
5761
return value
@@ -69,20 +73,10 @@ def _to_path(value) -> Optional[Path]:
6973

7074

7175
def datatree(cls):
72-
# TODO
73-
# - determine whether data array, data set, or data tree DONE
74-
# - shape check arrays (dynamic validator?)
75-
# check for parent and update dimensions
76-
# then try to realign existing packages?
77-
78-
old_post_init = getattr(cls, "__attrs_post_init__", None)
79-
80-
def __attrs_post_init__(self):
81-
print(f"Running datatree on {cls.__name__}")
82-
83-
if old_post_init:
84-
old_post_init(self)
76+
post_init_name = "__attrs_post_init__"
77+
post_init_prev = getattr(cls, post_init_name, None)
8578

79+
def _set_data_on_self(self, cls):
8680
fields = fields_dict(cls)
8781
arrays = {}
8882
for n, f in fields.items():
@@ -91,59 +85,71 @@ def __attrs_post_init__(self):
9185
value = getattr(self, n)
9286
if value is None:
9387
continue
94-
arrays[n] = (_parse_dim_names(f.metadata["shape"]), value)
88+
arrays[n] = (
89+
_parse_dim_names(f.metadata["shape"]),
90+
_to_shaped_array(value, self, f),
91+
)
9592
dataset = Dataset(arrays)
96-
children = getattr(self, "children", None)
97-
if children:
98-
self.data = DataTree(
99-
dataset, name=cls.__name__, children=[c.data for c in children]
93+
self.data = (
94+
DataTree(dataset, name=cls.__name__.lower()[3:])
95+
if issubclass(cls, Model)
96+
else dataset
97+
)
98+
99+
def _set_self_on_model(self, cls):
100+
model = getattr(self, "model", None)
101+
if model:
102+
self_name = cls.__name__.lower()[3:]
103+
setattr(model, self_name, self)
104+
model.data = model.data.assign(
105+
{self_name: DataTree(self.data, name=self_name)}
100106
)
101-
else:
102-
self.data = dataset
103107

104-
cls.__attrs_post_init__ = __attrs_post_init__
108+
def __attrs_post_init__(self):
109+
if post_init_prev:
110+
post_init_prev(self)
111+
112+
_set_data_on_self(self, cls)
113+
_set_self_on_model(self, cls)
105114

115+
# TODO: figure out why classes need to have a
116+
# __attrs_post_init__ method for this to work
117+
setattr(cls, post_init_name, __attrs_post_init__)
106118
return cls
107119

108120

121+
class Model:
122+
pass
123+
124+
109125
@datatree
110126
@define(slots=False)
111127
class GwfDis:
112128
nlay: int = field(default=1, metadata={"block": "dimensions"})
113129
ncol: int = field(default=2, metadata={"block": "dimensions"})
114130
nrow: int = field(default=2, metadata={"block": "dimensions"})
115131
delr: NDArray[np.floating] = field(
116-
converter=attrs.Converter(
117-
_to_shaped_array, takes_self=True, takes_field=True
118-
),
132+
converter=_to_array,
119133
default=1.0,
120134
metadata={"block": "griddata", "shape": "(ncol,)"},
121135
)
122136
delc: NDArray[np.floating] = field(
123-
converter=attrs.Converter(
124-
_to_shaped_array, takes_self=True, takes_field=True
125-
),
137+
converter=_to_array,
126138
default=1.0,
127139
metadata={"block": "griddata", "shape": "(nrow,)"},
128140
)
129141
top: NDArray[np.floating] = field(
130-
converter=attrs.Converter(
131-
_to_shaped_array, takes_self=True, takes_field=True
132-
),
142+
converter=_to_array,
133143
default=1.0,
134144
metadata={"block": "griddata", "shape": "(ncol, nrow)"},
135145
)
136146
botm: NDArray[np.floating] = field(
137-
converter=attrs.Converter(
138-
_to_shaped_array, takes_self=True, takes_field=True
139-
),
147+
converter=_to_array,
140148
default=0.0,
141149
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"},
142150
)
143151
idomain: Optional[NDArray[np.integer]] = field(
144-
converter=attrs.Converter(
145-
_to_shaped_array, takes_self=True, takes_field=True
146-
),
152+
converter=_to_array,
147153
default=1,
148154
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"},
149155
)
@@ -156,8 +162,7 @@ class GwfDis:
156162
default=False, metadata={"block": "options"}
157163
)
158164
nodes: int = field(init=False)
159-
data: Dataset = field(init=False)
160-
model: Optional[Any] = field(default=None)
165+
model: Optional[Model] = field(default=None)
161166

162167
def __attrs_post_init__(self):
163168
self.nodes = self.nlay * self.ncol * self.nrow
@@ -167,9 +172,8 @@ def __attrs_post_init__(self):
167172
@define(slots=False)
168173
class GwfIc:
169174
strt: NDArray[np.floating] = field(
170-
converter=attrs.Converter(
171-
_to_shaped_array, takes_self=True, takes_field=True
172-
),
175+
converter=_to_array,
176+
default=1.0,
173177
metadata={"block": "packagedata", "shape": "(nodes)"},
174178
)
175179
export_array_ascii: bool = field(
@@ -179,8 +183,11 @@ class GwfIc:
179183
default=False,
180184
metadata={"block": "options"},
181185
)
182-
data: Dataset = field(init=False)
183-
model: Optional[Any] = field(default=None)
186+
model: Optional[Model] = field(default=None)
187+
188+
def __attrs_post_init__(self):
189+
# for some reason this is necessary..
190+
pass
184191

185192

186193
@datatree
@@ -208,17 +215,23 @@ class Format:
208215
perioddata: Optional[list[list[tuple]]] = field(
209216
default=None, metadata={"block": "perioddata"}
210217
)
211-
data: Dataset = field(init=False)
212-
model: Optional[Any] = field(default=None)
218+
model: Optional[Model] = field(default=None)
219+
220+
def __attrs_post_init__(self):
221+
# for some reason this is necessary..
222+
pass
213223

214224

215225
@datatree
216226
@define(slots=False)
217-
class Gwf:
227+
class Gwf(Model):
218228
dis: Optional[GwfDis] = field(default=None)
219229
ic: Optional[GwfIc] = field(default=None)
220230
oc: Optional[GwfOc] = field(default=None)
221-
data: DataTree = field(init=False)
231+
232+
def __attrs_post_init__(self):
233+
# for some reason this is necessary..
234+
pass
222235

223236

224237
# We can define a package with some data.
@@ -268,7 +281,12 @@ class Gwf:
268281
assert period[0] == ("print", "budget", "steps", 1, 3, 5)
269282

270283

271-
# Creating a model by constructor.
284+
# Create a model.
285+
272286

287+
gwf = Gwf()
288+
dis = GwfDis(model=gwf)
289+
ic = GwfIc(model=gwf, strt=1)
290+
oc.model = gwf
273291

274-
gwf = Gwf(dis=GwfDis(), ic=GwfIc(strt=1), oc=oc)
292+
# View the data tree.

0 commit comments

Comments
 (0)