Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion flopy4/mf6/codec/writer/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from io import StringIO
from typing import Any, Literal

import attrs
import numpy as np
import xarray as xr
from modflow_devtools.dfn.schema.v2 import FieldType
Expand Down Expand Up @@ -202,7 +203,8 @@ def dataset2list(value: xr.Dataset):
return

# special case OC for now.
is_oc = all(
# TODO remove after properly handling object dtype period data arrays
is_oc = any(
str(v.name).startswith("save_") or str(v.name).startswith("print_")
for v in value.data_vars.values()
)
Expand All @@ -211,9 +213,17 @@ def dataset2list(value: xr.Dataset):
if (first := next(iter(value.data_vars.values()))).ndim == 0:
if is_oc:
for name in value.data_vars.keys():
if not (name.startswith("save_") or name.startswith("print_")):
# TODO: not working yet
if name == "perioddata":
val = value[name]
val = val.item() if val.shape == () else val
yield attrs.astuple(val, recurse=True)
continue
val = value[name]
val = val.item() if val.shape == () else val
yield (*name.split("_"), val)

else:
vals = []
for name in value.data_vars.keys():
Expand Down
51 changes: 51 additions & 0 deletions flopy4/mf6/converter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from pathlib import Path
from typing import Any

import cattr
import xattree
from cattr import Converter
from cattrs.gen import make_hetero_tuple_unstructure_fn

from flopy4.mf6.component import Component
from flopy4.mf6.context import Context
from flopy4.mf6.converter.structure import structure_array
from flopy4.mf6.converter.unstructure import (
unstructure_component,
)
from flopy4.mf6.gwf.oc import Oc

__all__ = [
"structure",
"unstructure",
"structure_array",
"unstructure_array",
"COMPONENT_CONVERTER",
]


def _make_converter() -> Converter:
converter = Converter(unstruct_strat=cattr.UnstructureStrategy.AS_TUPLE)
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
converter.register_unstructure_hook(Component, unstructure_component)
converter.register_unstructure_hook(
Oc.PrintSaveSetting, make_hetero_tuple_unstructure_fn(Oc.PrintSaveSetting, converter)
)
converter.register_unstructure_hook(
Oc.Steps, make_hetero_tuple_unstructure_fn(Oc.Steps, converter)
)
return converter


COMPONENT_CONVERTER = _make_converter()


def structure(data: dict[str, Any], path: Path) -> Component:
component = COMPONENT_CONVERTER.structure(data, Component)
if isinstance(component, Context):
component.workspace = path.parent
component.filename = path.name
return component


def unstructure(component: Component) -> dict[str, Any]:
return COMPONENT_CONVERTER.unstructure(component)
83 changes: 83 additions & 0 deletions flopy4/mf6/converter/structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Any

import numpy as np
import sparse
from numpy.typing import NDArray
from xattree import get_xatspec

from flopy4.adapters import get_nn
from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA


def structure_array(value, self_, field) -> NDArray:
"""
Convert a sparse dictionary representation of an array to a
dense numpy array or a sparse COO array.

TODO: generalize this not only to dictionaries but to any
form that can be converted to an array (e.g. nested list)
"""

if not isinstance(value, dict):
# if not a dict, assume it's a numpy array
# and let xarray deal with it if it isn't
return value

spec = get_xatspec(type(self_)).flat
field = spec[field.name]
if not field.dims:
raise ValueError(f"Field {field} missing dims")

# resolve dims
explicit_dims = self_.__dict__.get("dims", {})
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
dims = inherited_dims | explicit_dims
shape = [dims.get(d, d) for d in field.dims]
unresolved = [d for d in shape if isinstance(d, str)]
if any(unresolved):
raise ValueError(f"Couldn't resolve dims: {unresolved}")

if np.prod(shape) > SPARSE_THRESHOLD:
a: dict[tuple[Any, ...], Any] = dict()

def set_(arr, val, *ind):
arr[tuple(ind)] = val

def final(arr):
coords = np.array(list(map(list, zip(*arr.keys()))))
return sparse.COO(
coords,
list(arr.values()),
shape=shape,
fill_value=field.default or FILL_DNODATA,
)
else:
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore

def set_(arr, val, *ind):
arr[ind] = val

def final(arr):
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
return arr

if "nper" in dims:
for kper, period in value.items():
if kper == "*":
kper = 0
match len(shape):
case 1:
set_(a, period, kper)
case _:
for cellid, v in period.items():
nn = get_nn(cellid, **dims)
set_(a, v, kper, nn)
if kper == "*":
break
else:
for cellid, v in value.items():
nn = get_nn(cellid, **dims)
set_(a, v, nn)

return final(a)
112 changes: 5 additions & 107 deletions flopy4/mf6/converter.py → flopy4/mf6/converter/unstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,17 @@
from pathlib import Path
from typing import Any

import numpy as np
import sparse
import xarray as xr
import xattree
from cattrs import Converter
from modflow_devtools.dfn.schema.block import block_sort_key
from numpy.typing import NDArray
from xattree import get_xatspec

from flopy4.adapters import get_nn
from flopy4.mf6.binding import Binding
from flopy4.mf6.component import Component
from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA
from flopy4.mf6.context import Context
from flopy4.mf6.spec import FileInOut


def path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
def _path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
t = [name.upper()]
if name.endswith("_file"):
t[0] = name.replace("_file", "").upper()
Expand All @@ -31,7 +23,7 @@ def path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
return tuple(t)


def make_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
def _make_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
if not isinstance(value, Context):
return {}

Expand Down Expand Up @@ -104,7 +96,7 @@ def unstructure_component(value: Component) -> dict[str, Any]:
data = xattree.asdict(value)

# create child component binding blocks
blocks.update(make_binding_blocks(value))
blocks.update(_make_binding_blocks(value))

# process blocks in order, unstructuring fields as needed,
# then slice period data into separate kper-indexed blocks
Expand Down Expand Up @@ -132,6 +124,7 @@ def unstructure_component(value: Component) -> dict[str, Any]:
# - 'auxiliary' fields to tuples
# - xarray DataArrays with 'nper' dim to dict of kper-sliced datasets
# - other values to their original form
# TODO: use cattrs converters for field unstructuring?
match field_value := data[field_name]:
case None:
continue
Expand All @@ -141,7 +134,7 @@ def unstructure_component(value: Component) -> dict[str, Any]:
case Path():
field_spec = xatspec.attrs[field_name]
field_meta = getattr(field_spec, "metadata", {})
t = path_to_tuple(
t = _path_to_tuple(
field_name, field_value, inout=field_meta.get("inout", "fileout")
)
# name may have changed e.g dropping '_file' suffix
Expand Down Expand Up @@ -197,98 +190,3 @@ def unstructure_component(value: Component) -> dict[str, Any]:
del blocks["solutiongroup"]

return {name: block for name, block in blocks.items() if name != period_block_name}


def _make_converter() -> Converter:
converter = Converter()
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
converter.register_unstructure_hook(Component, unstructure_component)
return converter


COMPONENT_CONVERTER = _make_converter()


def dict_to_array(value, self_, field) -> NDArray:
"""
Convert a sparse dictionary representation of an array to a
dense numpy array or a sparse COO array.

TODO: generalize this not only to dictionaries but to any
form that can be converted to an array (e.g. nested list)
"""

if not isinstance(value, dict):
# if not a dict, assume it's a numpy array
# and let xarray deal with it if it isn't
return value

spec = get_xatspec(type(self_)).flat
field = spec[field.name]
if not field.dims:
raise ValueError(f"Field {field} missing dims")

# resolve dims
explicit_dims = self_.__dict__.get("dims", {})
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
dims = inherited_dims | explicit_dims
shape = [dims.get(d, d) for d in field.dims]
unresolved = [d for d in shape if isinstance(d, str)]
if any(unresolved):
raise ValueError(f"Couldn't resolve dims: {unresolved}")

if np.prod(shape) > SPARSE_THRESHOLD:
a: dict[tuple[Any, ...], Any] = dict()

def set_(arr, val, *ind):
arr[tuple(ind)] = val

def final(arr):
coords = np.array(list(map(list, zip(*arr.keys()))))
return sparse.COO(
coords,
list(arr.values()),
shape=shape,
fill_value=field.default or FILL_DNODATA,
)
else:
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore

def set_(arr, val, *ind):
arr[ind] = val

def final(arr):
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
return arr

if "nper" in dims:
for kper, period in value.items():
if kper == "*":
kper = 0
match len(shape):
case 1:
set_(a, period, kper)
case _:
for cellid, v in period.items():
nn = get_nn(cellid, **dims)
set_(a, v, kper, nn)
if kper == "*":
break
else:
for cellid, v in value.items():
nn = get_nn(cellid, **dims)
set_(a, v, nn)

return final(a)


def structure(data: dict[str, Any], path: Path) -> Component:
component = COMPONENT_CONVERTER.structure(data, Component)
if isinstance(component, Context):
component.workspace = path.parent
component.filename = path.name
return component


def unstructure(component: Component) -> dict[str, Any]:
return COMPONENT_CONVERTER.unstructure(component)
8 changes: 4 additions & 4 deletions flopy4/mf6/gwf/chd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from xattree import xattree

from flopy4.mf6.constants import LENBOUNDNAME
from flopy4.mf6.converter import dict_to_array
from flopy4.mf6.converter import structure_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field, path
from flopy4.mf6.utils.grid_utils import update_maxbound
Expand Down Expand Up @@ -38,7 +38,7 @@ class Chd(Package):
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
on_setattr=update_maxbound,
)
aux: Optional[NDArray[np.float64]] = array(
Expand All @@ -48,7 +48,7 @@ class Chd(Package):
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
on_setattr=update_maxbound,
)
boundname: Optional[NDArray[np.str_]] = array(
Expand All @@ -59,6 +59,6 @@ class Chd(Package):
"nodes",
),
default=None,
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
on_setattr=update_maxbound,
)
Loading
Loading