Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion flopy4/mf6/component.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from abc import ABC
from collections.abc import MutableMapping

from attrs import Attribute
from modflow_devtools.dfn import Dfn, Var
from xattree import xattree

from flopy4.mf6.spec import fields_dict

COMPONENTS = {}
"""MF6 component registry."""

Expand All @@ -13,6 +17,9 @@ class Component(ABC, MutableMapping):
def __attrs_init_subclass__(cls):
COMPONENTS[cls.__name__.lower()] = cls

def __attrs_post_init__(self):
self._dfn = self._get_dfn()

def __getitem__(self, key):
return self.children[key] # type: ignore

Expand All @@ -28,7 +35,46 @@ def __iter__(self):
def __len__(self):
return len(self.children) # type: ignore

@property
def dfn(self) -> Dfn:
"""Return the component's definition."""
return self._dfn

def _get_dfn(self) -> Dfn:
def _to_dfn_spec(attribute: Attribute) -> Var:
return Var(
name=attribute.name,
type=attribute.type,
shape=attribute.metadata.get("dims", None),
block=attribute.metadata.get("block", None),
default=attribute.default,
children={k: _to_dfn_spec(v) for k, v in fields_dict(attribute.type)} # type: ignore
if attribute.metadata.get("kind", None) == "child" # type: ignore
else None, # type: ignore
)

fields = {k: _to_dfn_spec(v) for k, v in fields_dict(self.__class__).items()}
blocks: dict[str, dict[str, Var]] = {}
for k, v in fields.items():
if (block := v.get("block", None)) is not None:
blocks.setdefault(block, {})[k] = v
else:
blocks[k] = v
return Dfn(
name=self.name, # type: ignore
advanced=getattr(self, "advanced_package", False),
multi=getattr(self, "multi_package", False),
ref=getattr(self, "sub_package", None),
sln=getattr(self, "solution_package", None),
**blocks,
)

def load(self) -> None:
# TODO: load
for child in self.children.values(): # type: ignore
child.load()

def write(self) -> None:
# TODO: write with jinja to file
# TODO: write
for child in self.children.values(): # type: ignore
child.write()
4 changes: 3 additions & 1 deletion flopy4/mf6/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC

from xattree import xattree

from flopy4.mf6.component import Component


@xattree
class Model(Component):
class Model(Component, ABC):
pass
4 changes: 3 additions & 1 deletion flopy4/mf6/package.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC

from xattree import xattree

from flopy4.mf6.component import Component


@xattree
class Package(Component):
class Package(Component, ABC):
pass
4 changes: 3 additions & 1 deletion flopy4/mf6/solution.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC

from xattree import xattree

from flopy4.mf6.package import Package


@xattree
class Solution(Package):
class Solution(Package, ABC):
pass
17 changes: 16 additions & 1 deletion flopy4/mf6/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
These include field decorators and introspection functions.
"""

from attrs import NOTHING, Attribute, fields_dict
from attrs import NOTHING, Attribute

from flopy4.spec import array as flopy_array
from flopy4.spec import coord as flopy_coord
from flopy4.spec import dim as flopy_dim
from flopy4.spec import field as flopy_field
from flopy4.spec import fields_dict as flopy_fields_dict


def field(
Expand Down Expand Up @@ -148,3 +149,17 @@ def blocks_dict(cls) -> dict[str, Block]:
blocks[block] = {}
blocks[block][k] = v
return dict(sorted(blocks.items(), key=_block_sort_key))


def fields(cls) -> list[Attribute]:
"""Return an ordered list of fields for a component class."""
return list(fields_dict(cls).values())


def fields_dict(cls) -> dict[str, Attribute]:
"""
Return an ordered dictionary of fields for a component class,
whose keys are field names. Each field is an `attrs.Attribute`.
"""
fields = flopy_fields_dict(cls)
return {k: v for k, v in fields.items() if "block" in v.metadata}
22 changes: 20 additions & 2 deletions flopy4/spec.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Wrap `xattree` and `attrs` specification utilities."""
"""
Wrap `xattree` and `attrs` specification utilities.
These include field decorators and introspection functions.
"""

from attrs import NOTHING
from attrs import NOTHING, Attribute
from xattree import array as xattree_array
from xattree import coord as xattree_coord
from xattree import dim as xattree_dim
from xattree import field as xattree_field
from xattree import fields_dict as xattree_fields_dict


def field(
Expand Down Expand Up @@ -87,3 +91,17 @@ def array(
eq=eq,
metadata=metadata,
)


def fields(cls) -> list[Attribute]:
"""Return an ordered list of fields for a component class."""
return list(fields_dict(cls).values())


def fields_dict(cls) -> dict[str, Attribute]:
"""
Return an ordered dictionary of fields for a component class,
whose keys are field names. Each field is an `attrs.Attribute`.
"""
fields = xattree_fields_dict(cls)
return {k: v for k, v in fields.items()}
Loading
Loading