Skip to content

Commit fbe80d6

Browse files
committed
attr-style access, misc
1 parent 16d7623 commit fbe80d6

File tree

4 files changed

+100
-68
lines changed

4 files changed

+100
-68
lines changed

flopy4/__init__.py

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
Scalar = bool | int | float | str | Path
1919
"""A scalar value."""
2020

21-
_HasAttrs = Annotated[object, Is[lambda obj: attrs.has(type(obj))]]
21+
_IsAttrs = Annotated[object, Is[lambda obj: attrs.has(type(obj))]]
2222
"""Runtime-applied type hint for `attrs` based class instances."""
2323

24-
_HasData = Annotated[object, IsAttr["data", IsInstance[DataTree]]]
24+
_HasTree = Annotated[object, IsAttr["data", IsInstance[DataTree]]]
2525
"""Runtime-applied type hint for objects with a `DataTree` in `.data`."""
2626

2727
_Component = Annotated[
@@ -92,7 +92,7 @@ def _find_recursive(tree, key):
9292

9393

9494
def resolve_array(
95-
self: _HasAttrs,
95+
self: _IsAttrs,
9696
attr: Attribute,
9797
value: ArrayLike,
9898
tree: DataTree = None,
@@ -120,7 +120,7 @@ def resolve_array(
120120
shape = [find(tree or DataTree(), key=dim, default=dim) for dim in dims]
121121
shape = tuple(
122122
[
123-
(dim if isinstance(dim, int) else kwargs.get(dim, dim))
123+
(dim if isinstance(dim, int) else kwargs.pop(dim, dim))
124124
for dim in shape
125125
]
126126
)
@@ -140,30 +140,58 @@ def resolve_array(
140140
return value
141141

142142

143-
def bind_tree(self: _HasData, parent: _HasData):
143+
def bind_tree(
144+
self: _Component,
145+
parent: _Component = None,
146+
children: Optional[Mapping[str, _Component]] = None,
147+
):
144148
"""
145-
Bind a child component to a parent, linking their trees.
146-
If the parent isn't the root, rebind it to recursively
147-
upwards to the root.
149+
Bind a given component to a parent component, linking the
150+
two components and their data trees. If the parent is not
151+
the tree's root, rebind it to recursively up to the root.
152+
153+
Also attach any child components to the given component's
154+
data tree, as well as to any non-`attrs` attributes whose
155+
name matches a child's name.
148156
149157
TODO: this is massively duplicative, since each component
150158
has a subtree of its own, next to the one its parent owns
151159
and in which its tree appears. need to have a single tree
152160
at the root, then each component's data is a view into it.
153161
"""
154-
parent.data = parent.data.assign({self.data.name: self.data})
155-
self.data = parent.data[self.data.name]
156-
grandparent = getattr(parent, "parent", None)
157-
if grandparent is not None:
158-
bind_tree(parent, grandparent)
159-
self.parent = parent
162+
163+
cls = type(self)
164+
165+
if parent:
166+
parent_spec = fields_dict(type(parent))
167+
if self.data.name in parent_spec:
168+
setattr(parent, self.data.name, self)
169+
170+
# TODO
171+
# parent_bindings = {
172+
# k: v
173+
# for k, v in parent_spec.items()
174+
# if v.metadata.get("bind", False)
175+
# }
176+
177+
parent.data = parent.data.assign({self.data.name: self.data})
178+
self.data = parent.data[self.data.name]
179+
grandparent = getattr(parent, "parent", None)
180+
if grandparent is not None:
181+
bind_tree(parent, grandparent)
182+
self.parent = parent
183+
self.children = children
184+
spec = fields_dict(type(self))
185+
for n, c in (children or {}).items():
186+
if n in spec:
187+
setattr(self, n, c)
160188

161189

162190
def init_tree(
163-
self: _HasAttrs,
191+
self: _IsAttrs,
164192
name: Optional[str] = None,
165-
parent: Optional[_HasData] = None,
166-
children: Optional[Mapping[str, _HasData]] = None,
193+
parent: Optional[_HasTree] = None,
194+
children: Optional[Mapping[str, _HasTree]] = None,
167195
):
168196
"""
169197
Initialize a data tree for a component class instance.
@@ -180,40 +208,43 @@ class cannot use slots for this to work.
180208
spec = fields_dict(cls)
181209
data = Dataset()
182210
dims = set()
211+
arrays = {}
212+
scalars = {}
213+
children = children or {}
183214

184-
# set arrays, then scalars. filter array dims out
185-
# on the first pass thru, while we set up arrays,
215+
# set scalars and arrays. filter array dims out
186216
# so they're not attached as both vars and dims.
217+
# also filter out subcomponents, just want vars.
187218
for attr in spec.values():
219+
bind = attr.metadata.get("bind", False)
220+
if bind:
221+
continue
188222
dims_ = attr.metadata.get("dims", None)
189223
if dims_ is None:
224+
scalars[attr.name] = attr
190225
continue
191226
dims.update(dims_)
227+
arrays[attr.name] = attr
228+
scalars = {k: self.__dict__.pop(k, v.default) for k, v in scalars.items()}
229+
for attr in arrays.values():
230+
dims_ = attr.metadata["dims"]
192231
value = resolve_array(
193232
self,
194233
attr,
195-
value=self.__dict__.pop(attr.name),
234+
value=self.__dict__.pop(attr.name, attr.default),
196235
tree=parent.data.root if parent else None,
197-
**self.__dict__,
236+
**scalars,
198237
)
199238
data[attr.name] = (dims_, value)
200-
for attr in spec.values():
201-
if attr.name in data or attr.name in dims:
202-
continue
203-
data[attr.name] = self.__dict__.pop(attr.name, attr.default)
239+
for k, v in scalars.items():
240+
data.attrs[k] = v
204241

205-
# create tree
206242
self.data = DataTree(
207243
data,
208244
name=name or cls.__name__.lower(),
209-
children={
210-
n: c.data for n, c in (children or {}).items() if c is not None
211-
},
245+
children={n: c.data for n, c in children.items()},
212246
)
213-
214-
# bind tree
215-
if parent is not None:
216-
bind_tree(self, parent)
247+
bind_tree(self, parent=parent, children=children)
217248

218249

219250
def getattribute(self: Any, name: str) -> Any:
@@ -230,8 +261,11 @@ def getattribute(self: Any, name: str) -> Any:
230261
"""
231262
cls = type(self)
232263
spec = fields_dict(cls)
264+
if name == "data":
265+
raise AttributeError
233266
tree = self.data
234-
if name in spec:
267+
var = spec.get(name, None)
268+
if var:
235269
value = get(tree, name, None)
236270
if value is not None:
237271
return value
@@ -263,7 +297,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
263297
# TODO run validation?
264298

265299

266-
def component(cls: type[_HasAttrs]) -> type[_Component]:
300+
def component(cls: type[_IsAttrs]) -> type[_Component]:
267301
"""
268302
Attach a data tree to an `attrs` class instance, and use
269303
the data tree for attribute storage: intercept gets/sets
@@ -277,13 +311,13 @@ def component(cls: type[_HasAttrs]) -> type[_Component]:
277311

278312
old_init = cls.__init__
279313

280-
def init(self, *args, **kwargs):
314+
def _init(self, *args, **kwargs):
281315
name = kwargs.pop("name", None)
282316
parent = args[0] if args and any(args) else None
283317
children = kwargs.pop("children", None)
284318
old_init(self, **kwargs)
285319
init_tree(self, name=name, parent=parent, children=children)
286320
cls.__getattr__ = getattribute
287321

288-
cls.__init__ = init
322+
cls.__init__ = _init
289323
return cls

flopy4/mf6/__init__.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,31 +72,9 @@ class Simulation(Component):
7272

7373

7474
@component
75-
@define(slots=False, on_setattr=setattribute)
75+
@define(init=False, slots=False)
7676
class Sim(Simulation):
77-
pass
78-
# tdis: Tdis = field(metadata={"block": "timing"})
79-
# models: dict[str, Model] = field(metadata={"block": "models"})
80-
# exchanges: dict[str, Exchange] = field(metadata={"block": "exchanges"})
81-
# solutions: dict[str, Solution] = field(metadata={"block": "solutions"})
82-
83-
# def __init__(
84-
# self,
85-
# name=None,
86-
# path=None,
87-
# exe=None,
88-
# tdis=None,
89-
# models=None,
90-
# exchanges=None,
91-
# solutions=None,
92-
# ):
93-
# super().__init__(name, path, exe)
94-
# init_tree(
95-
# self,
96-
# children={
97-
# "tdis": tdis,
98-
# "models": models,
99-
# "exchanges": exchanges,
100-
# "solutions": solutions,
101-
# },
102-
# )
77+
tdis: Tdis = field(metadata={"bind": True})
78+
models: dict[str, Model] = field(metadata={"bind": True})
79+
exchanges: dict[str, Exchange] = field(metadata={"bind": True})
80+
solutions: dict[str, Solution] = field(metadata={"bind": True})

flopy4/todo.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ BFS is ok as a general solution but we should use all the info we have.
1919

2020
### components
2121

22-
I think for access by name we want dict style e.g. `gwf["chd1"]`.
22+
I think for access by name we want dict style e.g. `gwf["chd1"]`,
23+
like imod-python does it.
2324

2425
By type, e.g. `gwf.chd`, where it's either a single component,
25-
or a dict for multipackages.
26+
or a dict by name (or auto-increment index) for multipackages.
2627

2728
### variables
2829

test/test_component.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ def test_registry():
1212
assert COMPONENTS["npf"] is Npf
1313

1414

15-
def test_sim(benchmark):
15+
def test_init_top_down():
1616
sim = Sim()
17-
tdis = Tdis(sim, nper=1, perioddata=[Tdis.PeriodData()])
17+
tdis = Tdis(sim)
1818
gwf = Gwf(sim)
1919
dis = Dis(gwf)
2020
ic = Ic(gwf, strt=1.0)
@@ -24,6 +24,7 @@ def test_sim(benchmark):
2424
assert isinstance(sim.data, DataTree)
2525
sim.data # view the tree
2626

27+
assert sim.tdis is tdis
2728
assert "tdis" in sim.data.children
2829
assert "gwf" in sim.data.children
2930
assert "dis" in sim.data.children["gwf"].children
@@ -35,5 +36,23 @@ def test_sim(benchmark):
3536
sim.data.children["gwf"].children["npf"].k, np.ones((4))
3637
)
3738
assert np.array_equal(npf.k, npf.data.k)
39+
3840
# TODO: figure out how to deduplicate trees. components proxy root?
41+
# assert npf.k is npf.data.k
3942
# assert gwf.parent.data.children["gwf"].children["npf"] is npf.data
43+
44+
45+
def test_init_bottom_up():
46+
dis = Dis()
47+
gwf = Gwf(children={"dis": dis})
48+
tdis = Tdis()
49+
sim = Sim(children={"tdis": tdis, "gwf": gwf})
50+
51+
assert isinstance(sim.data, DataTree)
52+
sim.data # view the tree
53+
54+
assert sim.tdis is tdis
55+
56+
assert "tdis" in sim.data.children
57+
assert "gwf" in sim.data.children
58+
assert "dis" in sim.data.children["gwf"].children

0 commit comments

Comments
 (0)