diff --git a/flopy4/__init__.py b/flopy4/__init__.py index e2053151..e64e0405 100644 --- a/flopy4/__init__.py +++ b/flopy4/__init__.py @@ -1,5 +1,4 @@ from collections.abc import Mapping -from itertools import chain from pathlib import Path from typing import Annotated, Any, Optional, get_origin @@ -97,7 +96,7 @@ def resolve_array( tree: DataTree = None, strict: bool = False, **kwargs, -) -> Optional[NDArray]: +) -> tuple[Optional[NDArray], Optional[dict[str, NDArray]]]: """ Resolve an array-like value to the given variable's expected shape. If the value is a collection, check if the shape matches. If scalar, @@ -118,7 +117,7 @@ def resolve_array( f"Component class '{type(self).__name__}' array " f"variable '{attr.name}' could not be resolved " ) - return None + return None, None dims = attr.metadata.get("dims", None) if not dims: if strict: @@ -126,7 +125,7 @@ def resolve_array( f"Component class '{type(self).__name__}' array " f"variable '{attr.name}' needs 'dims' metadata" ) - return None + return None, None shape = [find(tree or DataTree(), key=dim, default=dim) for dim in dims] shape = tuple( [ @@ -142,8 +141,10 @@ def resolve_array( f"variable '{attr.name}' failed dim resolution: " f"{', '.join(unresolved)}" ) - return None - return reshape_array(value, shape) + return None, None + array = reshape_array(value, shape) + coords = {dim: np.arange(size) for dim, size in zip(dims, shape)} + return array, coords def bind_tree( @@ -160,10 +161,11 @@ def bind_tree( data tree, as well as to any non-`attrs` attributes whose name matches a child's name. - TODO: this is massively duplicative, since each component - has a subtree of its own, next to the one its parent owns - and in which its tree appears. need to have a single tree - at the root, then each component's data is a view into it. + TODO: discover dimensions from self, parent and children. + If the parent defines a dimension, it should be used for + self and children. If a dimension is found in self or in + a child which has scope broader than self, send it up to + the parent. """ cls = type(self) @@ -172,30 +174,27 @@ def bind_tree( # bind parent if parent: - # try binding first by name + # bind to parent attrs whose name + # matches this component's name parent_spec = fields_dict(type(parent)) parent_var = parent_spec.get(name, None) if parent_var: assert parent_var.metadata.get("bind", False) setattr(parent, name, self) - # TODO bind multipackages by type - # parent_bindings = { - # n: v - # for n, v in parent_spec.items() - # if v.metadata.get("bind", False) - # } - # print(parent_bindings) + + # bind parent data tree if name in parent.data: parent.data.update({name: self.data}) else: parent.data = parent.data.assign({name: self.data}) self.data = parent.data[self.data.name] - # bind grandparent + # bind grandparent recursively grandparent = getattr(parent, "parent", None) if grandparent is not None: bind_tree(parent, parent=grandparent) + # update parent reference self.parent = parent # bind children @@ -234,11 +233,12 @@ def init_tree( cls = type(self) spec = fields_dict(cls) dimensions = set() + coordinates = {} + components = {} array_vars = {} scalar_vars = {} array_vals = {} scalar_vals = {} - components = {} for var in spec.values(): bind = var.metadata.get("bind", False) @@ -264,7 +264,7 @@ def _yield_scalars(spec, vals): def _yield_arrays(spec, vals): for var in spec.values(): dims = var.metadata["dims"] - val = resolve_array( + val, coords = resolve_array( self, var, value=vals.pop(var.name, var.default), @@ -272,6 +272,7 @@ def _yield_arrays(spec, vals): **{**scalar_vals, **kwargs}, ) if val is not None: + coordinates.update(coords) yield (var.name, (dims, val)) array_vals = dict(list(_yield_arrays(spec=array_vars, vals=self.__dict__))) @@ -279,6 +280,7 @@ def _yield_arrays(spec, vals): self.data = DataTree( Dataset( data_vars=array_vals, + coords=coordinates, attrs={ n: v for n, v in scalar_vals.items() if n not in dimensions }, @@ -332,7 +334,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any): return value if get_origin(attr.type) in [list, np.ndarray]: shape = attr.metadata["dims"] - value = resolve_array(self, attr, value) + value, _ = resolve_array(self, attr, value) value = (shape, value) bind = attr.metadata.get("bind", False) if bind: @@ -341,6 +343,30 @@ def setattribute(self: _Component, attr: Attribute, value: Any): self.data.update({attr.name: value}) +def pop_dims(**kwargs): + """ + Use dims from `Grid` and/or `ModelTime` instances + passed to `grid` and `time` keyword arguments, if + available. + """ + dims = {} + grid: Grid = kwargs.pop("grid", None) + time: ModelTime = kwargs.pop("time", None) + grid_dims = ["nlay", "nrow", "ncol", "nnodes"] + time_dims = ["nper", "nstp"] + if grid: + for dim in grid_dims: + dims[dim] = getattr(grid, dim) + if time: + for dim in time_dims: + dims[dim] = getattr(time, dim) + for dim in grid_dims + time_dims: + v = kwargs.pop(dim, None) + if v is not None: + dims[dim] = v + return kwargs, dims + + def component(maybe_cls: Optional[type[_IsAttrs]] = None) -> type[_Component]: """ Attach a data tree to an `attrs` class instance, and use @@ -362,28 +388,11 @@ def init(self, *args, **kwargs): children = kwargs.pop("children", None) parent = args[0] if args and any(args) else None - # use dims from grid and modeltime, if provided - dim_kwargs = {} - dims_used = set( - chain(*[var.metadata.get("dims", []) for var in spec.values()]) - ) - grid: Grid = kwargs.pop("grid", None) - time: ModelTime = kwargs.pop("time", None) - if grid: - grid_dims = ["nlay", "nrow", "ncol", "nnodes"] - for dim in grid_dims: - if dim in dims_used: - dim_kwargs[dim] = getattr(grid, dim) - if time: - time_dims = ["nper", "ntstp"] - for dim in time_dims: - if dim in dims_used: - dim_kwargs[dim] = getattr(time, dim) - # run the original __init__, then set up the tree + kwargs, dimensions = pop_dims(**kwargs) init_self(self, **kwargs) init_tree( - self, name=name, parent=parent, children=children, **dim_kwargs + self, name=name, parent=parent, children=children, **dimensions ) bind_tree(self, parent=parent, children=children) @@ -397,3 +406,7 @@ def init(self, *args, **kwargs): return wrap return wrap(maybe_cls) + + +# TODO: add separate `component()` decorator like `attrs.field()`? +# for now, "bind" metadata indicates subcomponent, not a variable. diff --git a/flopy4/mf6/__init__.py b/flopy4/mf6/__init__.py index 56c90d8e..e2b0ccd7 100644 --- a/flopy4/mf6/__init__.py +++ b/flopy4/mf6/__init__.py @@ -12,7 +12,8 @@ "Package", "Model", "Simulation", - "Sim", + "Solution", + "Exchange", "COMPONENTS", ] @@ -34,6 +35,17 @@ class Model(Component): pass +class Solution(Package): + pass + + +class Exchange(Package): + exgtype: type = field() + exgfile: Path = field() + exgmnamea: Optional[str] = field(default=None) + exgmnameb: Optional[str] = field(default=None) + + @component @define(slots=False, on_setattr=setattribute) class Tdis(Package): @@ -44,7 +56,11 @@ class PeriodData: tsmult: float = field(default=1.0) nper: int = field( - default=1, metadata={"block": "dimensions", "dim": {"coord": "kper"}} + default=1, + metadata={ + "block": "dimensions", + "dim": {"coord": "kper", "scope": "simulation"}, + }, ) perioddata: list[PeriodData] = field( default=Factory(list), @@ -58,26 +74,9 @@ class PeriodData: ) -class Solution(Package): - pass - - -class Exchange(Package): - exgtype: type = field() - exgfile: Path = field() - exgmnamea: Optional[str] = field(default=None) - exgmnameb: Optional[str] = field(default=None) - - -class Simulation(Component): - pass - - @component @define(init=False, slots=False) -class Sim(Simulation): - # "bind" indicates this is a subcomponent, not a variable. - # TODO: add separate `component()` decorator like `field`? +class Simulation(Component): models: dict[str, Model] = field(metadata={"bind": True}) exchanges: dict[str, Exchange] = field(metadata={"bind": True}) solutions: dict[str, Solution] = field(metadata={"bind": True}) diff --git a/flopy4/mf6/gwf/dis.py b/flopy4/mf6/gwf/dis.py index a629d13f..ead984aa 100644 --- a/flopy4/mf6/gwf/dis.py +++ b/flopy4/mf6/gwf/dis.py @@ -23,13 +23,25 @@ class Dis(Package): default=False, metadata={"block": "options"} ) nlay: int = field( - default=1, metadata={"block": "dimensions", "dim": {"coord": "k"}} + default=1, + metadata={ + "block": "dimensions", + "dim": {"coord": "k", "scope": "simulation"}, + }, ) ncol: int = field( - default=2, metadata={"block": "dimensions", "dim": {"coord": "i"}} + default=2, + metadata={ + "block": "dimensions", + "dim": {"coord": "i", "scope": "simulation"}, + }, ) nrow: int = field( - default=2, metadata={"block": "dimensions", "dim": {"coord": "j"}} + default=2, + metadata={ + "block": "dimensions", + "dim": {"coord": "j", "scope": "simulation"}, + }, ) delr: NDArray[np.floating] = field( default=1.0, diff --git a/flopy4/todo.md b/flopy4/todo.md index df5e4c5d..c7517d3f 100644 --- a/flopy4/todo.md +++ b/flopy4/todo.md @@ -6,6 +6,11 @@ Each component should get a view into the root (as far as it's aware) tree unless it's the root (i.e. simulation) itself, or it's not attached to any parent context, in which case it's the root of its own tree. +Currently it's massively duplicative, since each component +has a subtree of its own, next to the one its parent owns +and in which its tree appears. need to have a single tree +at the root, then each component's data is a view into it. + - subcomponent accessors I think for access by name we want dict style e.g. `gwf["chd1"]`, diff --git a/test/test_component.py b/test/test_component.py index 0618cead..c3f17358 100644 --- a/test/test_component.py +++ b/test/test_component.py @@ -3,12 +3,12 @@ from flopy.discretization.modeltime import ModelTime from xarray import DataTree -from flopy4.mf6 import COMPONENTS, Sim, Tdis +from flopy4.mf6 import COMPONENTS, Simulation, Tdis from flopy4.mf6.gwf import Dis, Gwf, Ic, Npf, Oc def test_registry(): - assert COMPONENTS["sim"] is Sim + assert COMPONENTS["simulation"] is Simulation assert COMPONENTS["tdis"] is Tdis assert COMPONENTS["gwf"] is Gwf assert COMPONENTS["npf"] is Npf @@ -17,7 +17,7 @@ def test_registry(): def test_init_top_down(): - sim = Sim() + sim = Simulation() tdis = Tdis(sim) gwf = Gwf(sim) dis = Dis(gwf) @@ -70,7 +70,7 @@ def test_init_bottom_up(): }, ) tdis = Tdis(time=time) - sim = Sim(children={"tdis": tdis, "gwf": gwf}) + sim = Simulation(children={"tdis": tdis, "gwf": gwf}) assert sim.tdis is tdis # TODO test autoincrement