Skip to content

Commit c82de32

Browse files
authored
Cleanup and coords prep (#85)
* cleanup, prep metadata for coords * eponymous coordinates.. not what we want, just hacking
1 parent bab9b1e commit c82de32

File tree

5 files changed

+97
-68
lines changed

5 files changed

+97
-68
lines changed

flopy4/__init__.py

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Mapping
2-
from itertools import chain
32
from pathlib import Path
43
from typing import Annotated, Any, Optional, get_origin
54

@@ -97,7 +96,7 @@ def resolve_array(
9796
tree: DataTree = None,
9897
strict: bool = False,
9998
**kwargs,
100-
) -> Optional[NDArray]:
99+
) -> tuple[Optional[NDArray], Optional[dict[str, NDArray]]]:
101100
"""
102101
Resolve an array-like value to the given variable's expected shape.
103102
If the value is a collection, check if the shape matches. If scalar,
@@ -118,15 +117,15 @@ def resolve_array(
118117
f"Component class '{type(self).__name__}' array "
119118
f"variable '{attr.name}' could not be resolved "
120119
)
121-
return None
120+
return None, None
122121
dims = attr.metadata.get("dims", None)
123122
if not dims:
124123
if strict:
125124
raise ValueError(
126125
f"Component class '{type(self).__name__}' array "
127126
f"variable '{attr.name}' needs 'dims' metadata"
128127
)
129-
return None
128+
return None, None
130129
shape = [find(tree or DataTree(), key=dim, default=dim) for dim in dims]
131130
shape = tuple(
132131
[
@@ -142,8 +141,10 @@ def resolve_array(
142141
f"variable '{attr.name}' failed dim resolution: "
143142
f"{', '.join(unresolved)}"
144143
)
145-
return None
146-
return reshape_array(value, shape)
144+
return None, None
145+
array = reshape_array(value, shape)
146+
coords = {dim: np.arange(size) for dim, size in zip(dims, shape)}
147+
return array, coords
147148

148149

149150
def bind_tree(
@@ -160,10 +161,11 @@ def bind_tree(
160161
data tree, as well as to any non-`attrs` attributes whose
161162
name matches a child's name.
162163
163-
TODO: this is massively duplicative, since each component
164-
has a subtree of its own, next to the one its parent owns
165-
and in which its tree appears. need to have a single tree
166-
at the root, then each component's data is a view into it.
164+
TODO: discover dimensions from self, parent and children.
165+
If the parent defines a dimension, it should be used for
166+
self and children. If a dimension is found in self or in
167+
a child which has scope broader than self, send it up to
168+
the parent.
167169
"""
168170

169171
cls = type(self)
@@ -172,30 +174,27 @@ def bind_tree(
172174

173175
# bind parent
174176
if parent:
175-
# try binding first by name
177+
# bind to parent attrs whose name
178+
# matches this component's name
176179
parent_spec = fields_dict(type(parent))
177180
parent_var = parent_spec.get(name, None)
178181
if parent_var:
179182
assert parent_var.metadata.get("bind", False)
180183
setattr(parent, name, self)
181-
# TODO bind multipackages by type
182-
# parent_bindings = {
183-
# n: v
184-
# for n, v in parent_spec.items()
185-
# if v.metadata.get("bind", False)
186-
# }
187-
# print(parent_bindings)
184+
185+
# bind parent data tree
188186
if name in parent.data:
189187
parent.data.update({name: self.data})
190188
else:
191189
parent.data = parent.data.assign({name: self.data})
192190
self.data = parent.data[self.data.name]
193191

194-
# bind grandparent
192+
# bind grandparent recursively
195193
grandparent = getattr(parent, "parent", None)
196194
if grandparent is not None:
197195
bind_tree(parent, parent=grandparent)
198196

197+
# update parent reference
199198
self.parent = parent
200199

201200
# bind children
@@ -234,11 +233,12 @@ def init_tree(
234233
cls = type(self)
235234
spec = fields_dict(cls)
236235
dimensions = set()
236+
coordinates = {}
237+
components = {}
237238
array_vars = {}
238239
scalar_vars = {}
239240
array_vals = {}
240241
scalar_vals = {}
241-
components = {}
242242

243243
for var in spec.values():
244244
bind = var.metadata.get("bind", False)
@@ -264,21 +264,23 @@ def _yield_scalars(spec, vals):
264264
def _yield_arrays(spec, vals):
265265
for var in spec.values():
266266
dims = var.metadata["dims"]
267-
val = resolve_array(
267+
val, coords = resolve_array(
268268
self,
269269
var,
270270
value=vals.pop(var.name, var.default),
271271
tree=parent.data.root if parent else None,
272272
**{**scalar_vals, **kwargs},
273273
)
274274
if val is not None:
275+
coordinates.update(coords)
275276
yield (var.name, (dims, val))
276277

277278
array_vals = dict(list(_yield_arrays(spec=array_vars, vals=self.__dict__)))
278279

279280
self.data = DataTree(
280281
Dataset(
281282
data_vars=array_vals,
283+
coords=coordinates,
282284
attrs={
283285
n: v for n, v in scalar_vals.items() if n not in dimensions
284286
},
@@ -332,7 +334,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
332334
return value
333335
if get_origin(attr.type) in [list, np.ndarray]:
334336
shape = attr.metadata["dims"]
335-
value = resolve_array(self, attr, value)
337+
value, _ = resolve_array(self, attr, value)
336338
value = (shape, value)
337339
bind = attr.metadata.get("bind", False)
338340
if bind:
@@ -341,6 +343,30 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
341343
self.data.update({attr.name: value})
342344

343345

346+
def pop_dims(**kwargs):
347+
"""
348+
Use dims from `Grid` and/or `ModelTime` instances
349+
passed to `grid` and `time` keyword arguments, if
350+
available.
351+
"""
352+
dims = {}
353+
grid: Grid = kwargs.pop("grid", None)
354+
time: ModelTime = kwargs.pop("time", None)
355+
grid_dims = ["nlay", "nrow", "ncol", "nnodes"]
356+
time_dims = ["nper", "nstp"]
357+
if grid:
358+
for dim in grid_dims:
359+
dims[dim] = getattr(grid, dim)
360+
if time:
361+
for dim in time_dims:
362+
dims[dim] = getattr(time, dim)
363+
for dim in grid_dims + time_dims:
364+
v = kwargs.pop(dim, None)
365+
if v is not None:
366+
dims[dim] = v
367+
return kwargs, dims
368+
369+
344370
def component(maybe_cls: Optional[type[_IsAttrs]] = None) -> type[_Component]:
345371
"""
346372
Attach a data tree to an `attrs` class instance, and use
@@ -362,28 +388,11 @@ def init(self, *args, **kwargs):
362388
children = kwargs.pop("children", None)
363389
parent = args[0] if args and any(args) else None
364390

365-
# use dims from grid and modeltime, if provided
366-
dim_kwargs = {}
367-
dims_used = set(
368-
chain(*[var.metadata.get("dims", []) for var in spec.values()])
369-
)
370-
grid: Grid = kwargs.pop("grid", None)
371-
time: ModelTime = kwargs.pop("time", None)
372-
if grid:
373-
grid_dims = ["nlay", "nrow", "ncol", "nnodes"]
374-
for dim in grid_dims:
375-
if dim in dims_used:
376-
dim_kwargs[dim] = getattr(grid, dim)
377-
if time:
378-
time_dims = ["nper", "ntstp"]
379-
for dim in time_dims:
380-
if dim in dims_used:
381-
dim_kwargs[dim] = getattr(time, dim)
382-
383391
# run the original __init__, then set up the tree
392+
kwargs, dimensions = pop_dims(**kwargs)
384393
init_self(self, **kwargs)
385394
init_tree(
386-
self, name=name, parent=parent, children=children, **dim_kwargs
395+
self, name=name, parent=parent, children=children, **dimensions
387396
)
388397
bind_tree(self, parent=parent, children=children)
389398

@@ -397,3 +406,7 @@ def init(self, *args, **kwargs):
397406
return wrap
398407

399408
return wrap(maybe_cls)
409+
410+
411+
# TODO: add separate `component()` decorator like `attrs.field()`?
412+
# for now, "bind" metadata indicates subcomponent, not a variable.

flopy4/mf6/__init__.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
"Package",
1313
"Model",
1414
"Simulation",
15-
"Sim",
15+
"Solution",
16+
"Exchange",
1617
"COMPONENTS",
1718
]
1819

@@ -34,6 +35,17 @@ class Model(Component):
3435
pass
3536

3637

38+
class Solution(Package):
39+
pass
40+
41+
42+
class Exchange(Package):
43+
exgtype: type = field()
44+
exgfile: Path = field()
45+
exgmnamea: Optional[str] = field(default=None)
46+
exgmnameb: Optional[str] = field(default=None)
47+
48+
3749
@component
3850
@define(slots=False, on_setattr=setattribute)
3951
class Tdis(Package):
@@ -44,7 +56,11 @@ class PeriodData:
4456
tsmult: float = field(default=1.0)
4557

4658
nper: int = field(
47-
default=1, metadata={"block": "dimensions", "dim": {"coord": "kper"}}
59+
default=1,
60+
metadata={
61+
"block": "dimensions",
62+
"dim": {"coord": "kper", "scope": "simulation"},
63+
},
4864
)
4965
perioddata: list[PeriodData] = field(
5066
default=Factory(list),
@@ -58,26 +74,9 @@ class PeriodData:
5874
)
5975

6076

61-
class Solution(Package):
62-
pass
63-
64-
65-
class Exchange(Package):
66-
exgtype: type = field()
67-
exgfile: Path = field()
68-
exgmnamea: Optional[str] = field(default=None)
69-
exgmnameb: Optional[str] = field(default=None)
70-
71-
72-
class Simulation(Component):
73-
pass
74-
75-
7677
@component
7778
@define(init=False, slots=False)
78-
class Sim(Simulation):
79-
# "bind" indicates this is a subcomponent, not a variable.
80-
# TODO: add separate `component()` decorator like `field`?
79+
class Simulation(Component):
8180
models: dict[str, Model] = field(metadata={"bind": True})
8281
exchanges: dict[str, Exchange] = field(metadata={"bind": True})
8382
solutions: dict[str, Solution] = field(metadata={"bind": True})

flopy4/mf6/gwf/dis.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,25 @@ class Dis(Package):
2323
default=False, metadata={"block": "options"}
2424
)
2525
nlay: int = field(
26-
default=1, metadata={"block": "dimensions", "dim": {"coord": "k"}}
26+
default=1,
27+
metadata={
28+
"block": "dimensions",
29+
"dim": {"coord": "k", "scope": "simulation"},
30+
},
2731
)
2832
ncol: int = field(
29-
default=2, metadata={"block": "dimensions", "dim": {"coord": "i"}}
33+
default=2,
34+
metadata={
35+
"block": "dimensions",
36+
"dim": {"coord": "i", "scope": "simulation"},
37+
},
3038
)
3139
nrow: int = field(
32-
default=2, metadata={"block": "dimensions", "dim": {"coord": "j"}}
40+
default=2,
41+
metadata={
42+
"block": "dimensions",
43+
"dim": {"coord": "j", "scope": "simulation"},
44+
},
3345
)
3446
delr: NDArray[np.floating] = field(
3547
default=1.0,

flopy4/todo.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ Each component should get a view into the root (as far as it's aware) tree
66
unless it's the root (i.e. simulation) itself, or it's not attached to any
77
parent context, in which case it's the root of its own tree.
88

9+
Currently it's massively duplicative, since each component
10+
has a subtree of its own, next to the one its parent owns
11+
and in which its tree appears. need to have a single tree
12+
at the root, then each component's data is a view into it.
13+
914
- subcomponent accessors
1015

1116
I think for access by name we want dict style e.g. `gwf["chd1"]`,

test/test_component.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from flopy.discretization.modeltime import ModelTime
44
from xarray import DataTree
55

6-
from flopy4.mf6 import COMPONENTS, Sim, Tdis
6+
from flopy4.mf6 import COMPONENTS, Simulation, Tdis
77
from flopy4.mf6.gwf import Dis, Gwf, Ic, Npf, Oc
88

99

1010
def test_registry():
11-
assert COMPONENTS["sim"] is Sim
11+
assert COMPONENTS["simulation"] is Simulation
1212
assert COMPONENTS["tdis"] is Tdis
1313
assert COMPONENTS["gwf"] is Gwf
1414
assert COMPONENTS["npf"] is Npf
@@ -17,7 +17,7 @@ def test_registry():
1717

1818

1919
def test_init_top_down():
20-
sim = Sim()
20+
sim = Simulation()
2121
tdis = Tdis(sim)
2222
gwf = Gwf(sim)
2323
dis = Dis(gwf)
@@ -70,7 +70,7 @@ def test_init_bottom_up():
7070
},
7171
)
7272
tdis = Tdis(time=time)
73-
sim = Sim(children={"tdis": tdis, "gwf": gwf})
73+
sim = Simulation(children={"tdis": tdis, "gwf": gwf})
7474

7575
assert sim.tdis is tdis
7676
# TODO test autoincrement

0 commit comments

Comments
 (0)