Skip to content

Commit 53f8f81

Browse files
authored
add dfn property to Component (#139)
introspect the component after init and set a dfn property which reflects its mf6io definition (specification)
1 parent 772a313 commit 53f8f81

File tree

9 files changed

+573
-470
lines changed

9 files changed

+573
-470
lines changed

flopy4/mf6/component.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from abc import ABC
22
from collections.abc import MutableMapping
33

4+
from attrs import Attribute
5+
from modflow_devtools.dfn import Dfn, Var
46
from xattree import xattree
57

8+
from flopy4.mf6.spec import fields_dict
9+
610
COMPONENTS = {}
711
"""MF6 component registry."""
812

@@ -13,6 +17,9 @@ class Component(ABC, MutableMapping):
1317
def __attrs_init_subclass__(cls):
1418
COMPONENTS[cls.__name__.lower()] = cls
1519

20+
def __attrs_post_init__(self):
21+
self._dfn = self._get_dfn()
22+
1623
def __getitem__(self, key):
1724
return self.children[key] # type: ignore
1825

@@ -28,7 +35,46 @@ def __iter__(self):
2835
def __len__(self):
2936
return len(self.children) # type: ignore
3037

38+
@property
39+
def dfn(self) -> Dfn:
40+
"""Return the component's definition."""
41+
return self._dfn
42+
43+
def _get_dfn(self) -> Dfn:
44+
def _to_dfn_spec(attribute: Attribute) -> Var:
45+
return Var(
46+
name=attribute.name,
47+
type=attribute.type,
48+
shape=attribute.metadata.get("dims", None),
49+
block=attribute.metadata.get("block", None),
50+
default=attribute.default,
51+
children={k: _to_dfn_spec(v) for k, v in fields_dict(attribute.type)} # type: ignore
52+
if attribute.metadata.get("kind", None) == "child" # type: ignore
53+
else None, # type: ignore
54+
)
55+
56+
fields = {k: _to_dfn_spec(v) for k, v in fields_dict(self.__class__).items()}
57+
blocks: dict[str, dict[str, Var]] = {}
58+
for k, v in fields.items():
59+
if (block := v.get("block", None)) is not None:
60+
blocks.setdefault(block, {})[k] = v
61+
else:
62+
blocks[k] = v
63+
return Dfn(
64+
name=self.name, # type: ignore
65+
advanced=getattr(self, "advanced_package", False),
66+
multi=getattr(self, "multi_package", False),
67+
ref=getattr(self, "sub_package", None),
68+
sln=getattr(self, "solution_package", None),
69+
**blocks,
70+
)
71+
72+
def load(self) -> None:
73+
# TODO: load
74+
for child in self.children.values(): # type: ignore
75+
child.load()
76+
3177
def write(self) -> None:
32-
# TODO: write with jinja to file
78+
# TODO: write
3379
for child in self.children.values(): # type: ignore
3480
child.write()

flopy4/mf6/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from abc import ABC
2+
13
from xattree import xattree
24

35
from flopy4.mf6.component import Component
46

57

68
@xattree
7-
class Model(Component):
9+
class Model(Component, ABC):
810
pass

flopy4/mf6/package.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from abc import ABC
2+
13
from xattree import xattree
24

35
from flopy4.mf6.component import Component
46

57

68
@xattree
7-
class Package(Component):
9+
class Package(Component, ABC):
810
pass

flopy4/mf6/solution.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from abc import ABC
2+
13
from xattree import xattree
24

35
from flopy4.mf6.package import Package
46

57

68
@xattree
7-
class Solution(Package):
9+
class Solution(Package, ABC):
810
pass

flopy4/mf6/spec.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
These include field decorators and introspection functions.
44
"""
55

6-
from attrs import NOTHING, Attribute, fields_dict
6+
from attrs import NOTHING, Attribute
77

88
from flopy4.spec import array as flopy_array
99
from flopy4.spec import coord as flopy_coord
1010
from flopy4.spec import dim as flopy_dim
1111
from flopy4.spec import field as flopy_field
12+
from flopy4.spec import fields_dict as flopy_fields_dict
1213

1314

1415
def field(
@@ -148,3 +149,17 @@ def blocks_dict(cls) -> dict[str, Block]:
148149
blocks[block] = {}
149150
blocks[block][k] = v
150151
return dict(sorted(blocks.items(), key=_block_sort_key))
152+
153+
154+
def fields(cls) -> list[Attribute]:
155+
"""Return an ordered list of fields for a component class."""
156+
return list(fields_dict(cls).values())
157+
158+
159+
def fields_dict(cls) -> dict[str, Attribute]:
160+
"""
161+
Return an ordered dictionary of fields for a component class,
162+
whose keys are field names. Each field is an `attrs.Attribute`.
163+
"""
164+
fields = flopy_fields_dict(cls)
165+
return {k: v for k, v in fields.items() if "block" in v.metadata}

flopy4/spec.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
"""Wrap `xattree` and `attrs` specification utilities."""
1+
"""
2+
Wrap `xattree` and `attrs` specification utilities.
3+
These include field decorators and introspection functions.
4+
"""
25

3-
from attrs import NOTHING
6+
from attrs import NOTHING, Attribute
47
from xattree import array as xattree_array
58
from xattree import coord as xattree_coord
69
from xattree import dim as xattree_dim
710
from xattree import field as xattree_field
11+
from xattree import fields_dict as xattree_fields_dict
812

913

1014
def field(
@@ -87,3 +91,17 @@ def array(
8791
eq=eq,
8892
metadata=metadata,
8993
)
94+
95+
96+
def fields(cls) -> list[Attribute]:
97+
"""Return an ordered list of fields for a component class."""
98+
return list(fields_dict(cls).values())
99+
100+
101+
def fields_dict(cls) -> dict[str, Attribute]:
102+
"""
103+
Return an ordered dictionary of fields for a component class,
104+
whose keys are field names. Each field is an `attrs.Attribute`.
105+
"""
106+
fields = xattree_fields_dict(cls)
107+
return {k: v for k, v in fields.items()}

0 commit comments

Comments
 (0)