Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 54 additions & 41 deletions flopy4/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Mapping
from itertools import chain
from pathlib import Path
from typing import Annotated, Any, Optional, get_origin

Expand Down Expand Up @@ -97,7 +96,7 @@ def resolve_array(
tree: DataTree = None,
strict: bool = False,
**kwargs,
) -> Optional[NDArray]:
) -> tuple[Optional[NDArray], Optional[dict[str, NDArray]]]:
"""
Resolve an array-like value to the given variable's expected shape.
If the value is a collection, check if the shape matches. If scalar,
Expand All @@ -118,15 +117,15 @@ def resolve_array(
f"Component class '{type(self).__name__}' array "
f"variable '{attr.name}' could not be resolved "
)
return None
return None, None
dims = attr.metadata.get("dims", None)
if not dims:
if strict:
raise ValueError(
f"Component class '{type(self).__name__}' array "
f"variable '{attr.name}' needs 'dims' metadata"
)
return None
return None, None
shape = [find(tree or DataTree(), key=dim, default=dim) for dim in dims]
shape = tuple(
[
Expand All @@ -142,8 +141,10 @@ def resolve_array(
f"variable '{attr.name}' failed dim resolution: "
f"{', '.join(unresolved)}"
)
return None
return reshape_array(value, shape)
return None, None
array = reshape_array(value, shape)
coords = {dim: np.arange(size) for dim, size in zip(dims, shape)}
return array, coords


def bind_tree(
Expand All @@ -160,10 +161,11 @@ def bind_tree(
data tree, as well as to any non-`attrs` attributes whose
name matches a child's name.

TODO: this is massively duplicative, since each component
has a subtree of its own, next to the one its parent owns
and in which its tree appears. need to have a single tree
at the root, then each component's data is a view into it.
TODO: discover dimensions from self, parent and children.
If the parent defines a dimension, it should be used for
self and children. If a dimension is found in self or in
a child which has scope broader than self, send it up to
the parent.
"""

cls = type(self)
Expand All @@ -172,30 +174,27 @@ def bind_tree(

# bind parent
if parent:
# try binding first by name
# bind to parent attrs whose name
# matches this component's name
parent_spec = fields_dict(type(parent))
parent_var = parent_spec.get(name, None)
if parent_var:
assert parent_var.metadata.get("bind", False)
setattr(parent, name, self)
# TODO bind multipackages by type
# parent_bindings = {
# n: v
# for n, v in parent_spec.items()
# if v.metadata.get("bind", False)
# }
# print(parent_bindings)

# bind parent data tree
if name in parent.data:
parent.data.update({name: self.data})
else:
parent.data = parent.data.assign({name: self.data})
self.data = parent.data[self.data.name]

# bind grandparent
# bind grandparent recursively
grandparent = getattr(parent, "parent", None)
if grandparent is not None:
bind_tree(parent, parent=grandparent)

# update parent reference
self.parent = parent

# bind children
Expand Down Expand Up @@ -234,11 +233,12 @@ def init_tree(
cls = type(self)
spec = fields_dict(cls)
dimensions = set()
coordinates = {}
components = {}
array_vars = {}
scalar_vars = {}
array_vals = {}
scalar_vals = {}
components = {}

for var in spec.values():
bind = var.metadata.get("bind", False)
Expand All @@ -264,21 +264,23 @@ def _yield_scalars(spec, vals):
def _yield_arrays(spec, vals):
for var in spec.values():
dims = var.metadata["dims"]
val = resolve_array(
val, coords = resolve_array(
self,
var,
value=vals.pop(var.name, var.default),
tree=parent.data.root if parent else None,
**{**scalar_vals, **kwargs},
)
if val is not None:
coordinates.update(coords)
yield (var.name, (dims, val))

array_vals = dict(list(_yield_arrays(spec=array_vars, vals=self.__dict__)))

self.data = DataTree(
Dataset(
data_vars=array_vals,
coords=coordinates,
attrs={
n: v for n, v in scalar_vals.items() if n not in dimensions
},
Expand Down Expand Up @@ -332,7 +334,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
return value
if get_origin(attr.type) in [list, np.ndarray]:
shape = attr.metadata["dims"]
value = resolve_array(self, attr, value)
value, _ = resolve_array(self, attr, value)
value = (shape, value)
bind = attr.metadata.get("bind", False)
if bind:
Expand All @@ -341,6 +343,30 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
self.data.update({attr.name: value})


def pop_dims(**kwargs):
"""
Use dims from `Grid` and/or `ModelTime` instances
passed to `grid` and `time` keyword arguments, if
available.
"""
dims = {}
grid: Grid = kwargs.pop("grid", None)
time: ModelTime = kwargs.pop("time", None)
grid_dims = ["nlay", "nrow", "ncol", "nnodes"]
time_dims = ["nper", "nstp"]
if grid:
for dim in grid_dims:
dims[dim] = getattr(grid, dim)
if time:
for dim in time_dims:
dims[dim] = getattr(time, dim)
for dim in grid_dims + time_dims:
v = kwargs.pop(dim, None)
if v is not None:
dims[dim] = v
return kwargs, dims


def component(maybe_cls: Optional[type[_IsAttrs]] = None) -> type[_Component]:
"""
Attach a data tree to an `attrs` class instance, and use
Expand All @@ -362,28 +388,11 @@ def init(self, *args, **kwargs):
children = kwargs.pop("children", None)
parent = args[0] if args and any(args) else None

# use dims from grid and modeltime, if provided
dim_kwargs = {}
dims_used = set(
chain(*[var.metadata.get("dims", []) for var in spec.values()])
)
grid: Grid = kwargs.pop("grid", None)
time: ModelTime = kwargs.pop("time", None)
if grid:
grid_dims = ["nlay", "nrow", "ncol", "nnodes"]
for dim in grid_dims:
if dim in dims_used:
dim_kwargs[dim] = getattr(grid, dim)
if time:
time_dims = ["nper", "ntstp"]
for dim in time_dims:
if dim in dims_used:
dim_kwargs[dim] = getattr(time, dim)

# run the original __init__, then set up the tree
kwargs, dimensions = pop_dims(**kwargs)
init_self(self, **kwargs)
init_tree(
self, name=name, parent=parent, children=children, **dim_kwargs
self, name=name, parent=parent, children=children, **dimensions
)
bind_tree(self, parent=parent, children=children)

Expand All @@ -397,3 +406,7 @@ def init(self, *args, **kwargs):
return wrap

return wrap(maybe_cls)


# TODO: add separate `component()` decorator like `attrs.field()`?
# for now, "bind" metadata indicates subcomponent, not a variable.
39 changes: 19 additions & 20 deletions flopy4/mf6/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
"Package",
"Model",
"Simulation",
"Sim",
"Solution",
"Exchange",
"COMPONENTS",
]

Expand All @@ -34,6 +35,17 @@ class Model(Component):
pass


class Solution(Package):
pass


class Exchange(Package):
exgtype: type = field()
exgfile: Path = field()
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)


@component
@define(slots=False, on_setattr=setattribute)
class Tdis(Package):
Expand All @@ -44,7 +56,11 @@ class PeriodData:
tsmult: float = field(default=1.0)

nper: int = field(
default=1, metadata={"block": "dimensions", "dim": {"coord": "kper"}}
default=1,
metadata={
"block": "dimensions",
"dim": {"coord": "kper", "scope": "simulation"},
},
)
perioddata: list[PeriodData] = field(
default=Factory(list),
Expand All @@ -58,26 +74,9 @@ class PeriodData:
)


class Solution(Package):
pass


class Exchange(Package):
exgtype: type = field()
exgfile: Path = field()
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)


class Simulation(Component):
pass


@component
@define(init=False, slots=False)
class Sim(Simulation):
# "bind" indicates this is a subcomponent, not a variable.
# TODO: add separate `component()` decorator like `field`?
class Simulation(Component):
models: dict[str, Model] = field(metadata={"bind": True})
exchanges: dict[str, Exchange] = field(metadata={"bind": True})
solutions: dict[str, Solution] = field(metadata={"bind": True})
Expand Down
18 changes: 15 additions & 3 deletions flopy4/mf6/gwf/dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,25 @@ class Dis(Package):
default=False, metadata={"block": "options"}
)
nlay: int = field(
default=1, metadata={"block": "dimensions", "dim": {"coord": "k"}}
default=1,
metadata={
"block": "dimensions",
"dim": {"coord": "k", "scope": "simulation"},
},
)
ncol: int = field(
default=2, metadata={"block": "dimensions", "dim": {"coord": "i"}}
default=2,
metadata={
"block": "dimensions",
"dim": {"coord": "i", "scope": "simulation"},
},
)
nrow: int = field(
default=2, metadata={"block": "dimensions", "dim": {"coord": "j"}}
default=2,
metadata={
"block": "dimensions",
"dim": {"coord": "j", "scope": "simulation"},
},
)
delr: NDArray[np.floating] = field(
default=1.0,
Expand Down
5 changes: 5 additions & 0 deletions flopy4/todo.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Each component should get a view into the root (as far as it's aware) tree
unless it's the root (i.e. simulation) itself, or it's not attached to any
parent context, in which case it's the root of its own tree.

Currently it's massively duplicative, since each component
has a subtree of its own, next to the one its parent owns
and in which its tree appears. need to have a single tree
at the root, then each component's data is a view into it.

- subcomponent accessors

I think for access by name we want dict style e.g. `gwf["chd1"]`,
Expand Down
8 changes: 4 additions & 4 deletions test/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from flopy.discretization.modeltime import ModelTime
from xarray import DataTree

from flopy4.mf6 import COMPONENTS, Sim, Tdis
from flopy4.mf6 import COMPONENTS, Simulation, Tdis
from flopy4.mf6.gwf import Dis, Gwf, Ic, Npf, Oc


def test_registry():
assert COMPONENTS["sim"] is Sim
assert COMPONENTS["simulation"] is Simulation
assert COMPONENTS["tdis"] is Tdis
assert COMPONENTS["gwf"] is Gwf
assert COMPONENTS["npf"] is Npf
Expand All @@ -17,7 +17,7 @@ def test_registry():


def test_init_top_down():
sim = Sim()
sim = Simulation()
tdis = Tdis(sim)
gwf = Gwf(sim)
dis = Dis(gwf)
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_init_bottom_up():
},
)
tdis = Tdis(time=time)
sim = Sim(children={"tdis": tdis, "gwf": gwf})
sim = Simulation(children={"tdis": tdis, "gwf": gwf})

assert sim.tdis is tdis
# TODO test autoincrement
Expand Down
Loading