Skip to content

Commit 1183e71

Browse files
committed
bottom up working
1 parent 420453b commit 1183e71

File tree

8 files changed

+47
-34
lines changed

8 files changed

+47
-34
lines changed

flopy4/__init__.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from collections.abc import Iterable, Mapping
1+
from collections.abc import Mapping
2+
from itertools import chain
23
from pathlib import Path
34
from typing import Annotated, Any, Optional, get_origin
45

@@ -215,6 +216,7 @@ def init_tree(
215216
name: Optional[str] = None,
216217
parent: Optional[_HasTree] = None,
217218
children: Optional[Mapping[str, _HasTree]] = None,
219+
**kwargs,
218220
):
219221
"""
220222
Initialize a data tree for a component class instance.
@@ -271,7 +273,7 @@ def _yield_arrays(spec, vals):
271273
var,
272274
value=vals.pop(var.name, var.default),
273275
tree=parent.data.root if parent else None,
274-
**scalar_vals,
276+
**{**scalar_vals, **kwargs},
275277
)
276278
if val is not None:
277279
yield (var.name, (dims, val))
@@ -347,11 +349,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
347349
self.data.update({attr.name: value})
348350

349351

350-
def component(
351-
maybe_cls: Optional[type[_IsAttrs]] = None,
352-
*,
353-
align: Optional[Iterable[str]] = None,
354-
) -> type[_Component]:
352+
def component(maybe_cls: Optional[type[_IsAttrs]] = None) -> type[_Component]:
355353
"""
356354
Attach a data tree to an `attrs` class instance, and use
357355
the data tree for attribute storage: intercept gets/sets
@@ -365,26 +363,37 @@ def component(
365363

366364
def wrap(cls):
367365
init_self = cls.__init__
366+
spec = fields_dict(cls)
368367

369368
def init(self, *args, **kwargs):
370369
name = kwargs.pop("name", None)
371370
children = kwargs.pop("children", None)
372371
parent = args[0] if args and any(args) else None
373372

374373
# resolve dims from grid and time discretizations
374+
# get dims from spec
375+
dim_kwargs = {}
376+
dims_used = set(
377+
chain(*[var.metadata.get("dims", []) for var in spec.values()])
378+
)
375379
grid: Grid = kwargs.pop("grid", None)
376380
time: ModelTime = kwargs.pop("time", None)
377-
diss = [dis for dis in [grid, time] if dis]
378-
if align:
379-
for dim in align:
380-
for dis in diss:
381-
attr = getattr(dis, dim, None)
382-
if attr is not None:
383-
kwargs[dim] = attr
381+
if grid:
382+
grid_dims = ["nlay", "nrow", "ncol", "nnodes"]
383+
for dim in grid_dims:
384+
if dim in dims_used:
385+
dim_kwargs[dim] = getattr(grid, dim)
386+
if time:
387+
time_dims = ["nper", "ntstp"]
388+
for dim in time_dims:
389+
if dim in dims_used:
390+
dim_kwargs[dim] = getattr(time, dim)
384391

385392
# run the original __init__, then set up the tree
386393
init_self(self, **kwargs)
387-
init_tree(self, name=name, parent=parent, children=children)
394+
init_tree(
395+
self, name=name, parent=parent, children=children, **dim_kwargs
396+
)
388397

389398
# override attribute access
390399
cls.__getattr__ = getattribute

flopy4/mf6/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class Model(Component):
3434
pass
3535

3636

37-
@component(align=["nper"])
37+
@component
3838
@define(slots=False, on_setattr=setattribute)
3939
class Tdis(Package):
4040
@define(slots=False)

flopy4/mf6/gwf/chd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from flopy4.mf6 import Package
88

99

10-
@component(align="nper")
10+
@component
1111
@define(slots=False, on_setattr=setattribute)
1212
class Chd(Package):
1313
multi = True

flopy4/mf6/gwf/dis.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from flopy4.mf6 import Package
99

1010

11-
@component(align=["nlay", "ncol", "nrow"])
11+
@component
1212
@define(slots=False, on_setattr=setattribute)
1313
class Dis(Package):
1414
length_units: str = field(
@@ -45,7 +45,10 @@ class Dis(Package):
4545
default=1,
4646
metadata={"block": "griddata", "dims": ("ncol", "nrow", "nlay")},
4747
)
48-
nodes: Optional[int] = field(default=None)
48+
nnodes: Optional[int] = field(default=None)
4949

5050
def __attrs_post_init__(self):
51-
self.nodes = self.ncol * self.nrow * self.nlay
51+
try:
52+
self.nnodes = self.ncol * self.nrow * self.nlay
53+
except:
54+
pass

flopy4/mf6/gwf/ic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from flopy4.mf6 import Package
77

88

9-
@component(align=["nodes"])
9+
@component
1010
@define(slots=False, on_setattr=setattribute)
1111
class Ic(Package):
1212
strt: NDArray[np.floating] = field(
1313
default=1.0,
14-
metadata={"block": "packagedata", "dims": ("nodes",)},
14+
metadata={"block": "packagedata", "dims": ("nnodes",)},
1515
)
1616
export_array_ascii: bool = field(
1717
default=False, metadata={"block": "options"}

flopy4/mf6/gwf/npf.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from flopy4.mf6 import Package
1010

1111

12-
@component(align=["nodes"])
12+
@component
1313
@define(slots=False, on_setattr=setattribute)
1414
class Npf(Package):
1515
@define(slots=False)
@@ -66,33 +66,33 @@ class Xt3dOptions:
6666
)
6767
icelltype: NDArray[np.integer] = field(
6868
default=0,
69-
metadata={"block": "griddata", "dims": ("nodes",)},
69+
metadata={"block": "griddata", "dims": ("nnodes",)},
7070
)
7171
k: NDArray[np.floating] = field(
7272
default=1.0,
73-
metadata={"block": "griddata", "dims": ("nodes",)},
73+
metadata={"block": "griddata", "dims": ("nnodes",)},
7474
)
7575
k22: Optional[NDArray[np.floating]] = field(
7676
default=None,
77-
metadata={"block": "griddata", "dims": ("nodes",)},
77+
metadata={"block": "griddata", "dims": ("nnodes",)},
7878
)
7979
k33: Optional[NDArray[np.floating]] = field(
8080
default=None,
81-
metadata={"block": "griddata", "dims": ("nodes",)},
81+
metadata={"block": "griddata", "dims": ("nnodes",)},
8282
)
8383
angle1: Optional[NDArray[np.floating]] = field(
8484
default=None,
85-
metadata={"block": "griddata", "dims": ("nodes",)},
85+
metadata={"block": "griddata", "dims": ("nnodes",)},
8686
)
8787
angle2: Optional[NDArray[np.floating]] = field(
8888
default=None,
89-
metadata={"block": "griddata", "dims": ("nodes",)},
89+
metadata={"block": "griddata", "dims": ("nnodes",)},
9090
)
9191
angle3: Optional[NDArray[np.floating]] = field(
9292
default=None,
93-
metadata={"block": "griddata", "dims": ("nodes",)},
93+
metadata={"block": "griddata", "dims": ("nnodes",)},
9494
)
9595
wetdry: Optional[NDArray[np.floating]] = field(
9696
default=None,
97-
metadata={"block": "griddata", "dims": ("nodes",)},
97+
metadata={"block": "griddata", "dims": ("nnodes",)},
9898
)

flopy4/mf6/gwf/oc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
@component(align="nper")
15+
@component
1616
@define(slots=False, on_setattr=setattribute)
1717
class Oc(Package):
1818
@define(slots=False)

test/test_component.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,19 @@ def test_init_top_down():
5555

5656
def test_init_bottom_up():
5757
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
58-
grid = StructuredGrid()
58+
grid = StructuredGrid(nlay=1, nrow=2, ncol=2)
5959
dis = Dis(grid=grid)
6060
ic = Ic(grid=grid)
6161
oc = Oc(grid=grid)
6262
npf = Npf(grid=grid)
6363
gwf = Gwf(
64+
grid=grid,
6465
children={
6566
"dis": dis,
6667
"ic": ic,
6768
"oc": oc,
6869
"npf": npf,
69-
}
70+
},
7071
)
7172
tdis = Tdis(time=time)
7273
sim = Sim(children={"tdis": tdis, "gwf": gwf})

0 commit comments

Comments
 (0)