Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions flopy4/mf6/codec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,22 @@
from jinja2 import Environment, PackageLoader

from flopy4.mf6 import filters
from flopy4.mf6.codec.converter import structure_array, unstructure_array
from flopy4.mf6.codec.converter import (
structure_array,
unstructure_array,
unstructure_component,
unstructure_oc,
)
from flopy4.mf6.spec import get_blocks

_JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6"),
trim_blocks=True,
lstrip_blocks=True,
)
_JINJA_ENV.filters["blocks"] = filters.blocks
_JINJA_ENV.filters["blocks"] = get_blocks
_JINJA_ENV.filters["field_type"] = filters.field_type
_JINJA_ENV.filters["field_value"] = filters.field_value
_JINJA_ENV.filters["is_list"] = filters.is_list
_JINJA_ENV.filters["array_how"] = filters.array_how
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
_JINJA_ENV.filters["array2string"] = filters.array2string
Expand All @@ -31,10 +36,20 @@
"threshold": sys.maxsize,
}

_CONVERTER = Converter()
_CONVERTER.register_unstructure_hook_factory(
lambda cls: xattree.has(cls), lambda cls: xattree.asdict
)

def _make_converter() -> Converter:
from flopy4.mf6.component import Component
from flopy4.mf6.gwf.oc import Oc

converter = Converter()
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
converter.register_unstructure_hook(Component, unstructure_component)
converter.register_unstructure_hook(Oc, unstructure_oc)
return converter


_CONVERTER = _make_converter()


# TODO unstructure arrays into sparse dicts
# TODO combine OC fields into list input as defined in the MF6 dfn
Expand Down
85 changes: 78 additions & 7 deletions flopy4/mf6/codec/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import numpy as np
import sparse
import xattree
from numpy.typing import NDArray
from xarray import DataArray
from xattree import get_xatspec

from flopy4.mf6.component import Component
from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA
from flopy4.mf6.spec import get_blocks


# TODO: convert to a cattrs structuring hook so we don't have to
Expand Down Expand Up @@ -108,22 +111,90 @@ def unstructure_array(value: DataArray) -> dict:
MF6 list-based input format.
"""
# make sure dim 'kper' is present
if "kper" not in value.dims:
raise ValueError("array must have 'kper' dimension")
time_dim = "nper"
if time_dim not in value.dims:
raise ValueError(f"Array must have dimension '{time_dim}'")

if isinstance(value.data, sparse.COO):
coords = value.coords
data = value.data
else:
coords = np.array(np.nonzero(value)).T # type: ignore
data = value[tuple(coords.T)] # type: ignore
coords = np.array(np.nonzero(value.data)).T # type: ignore
data = value.data[tuple(coords.T)] # type: ignore
if not coords.size: # type: ignore
return {}
match value.ndim:
case 1:
return {k: v for k, v in zip(coords[:, 0], data)} # type: ignore
return {int(k): v for k, v in zip(coords[:, 0], data)} # type: ignore
case 2:
return {(k, j): v for (k, j), v in zip(coords, data)} # type: ignore
return {(int(k), int(j)): v for (k, j), v in zip(coords, data)} # type: ignore
case 3:
return {(k, i, j): v for (k, i, j), v in zip(coords, data)} # type: ignore
return {(int(k), int(i), int(j)): v for (k, i, j), v in zip(coords, data)} # type: ignore
return {}


def unstructure_component(value: Component) -> dict[str, Any]:
data = xattree.asdict(value)
for block in get_blocks(value.dfn).values():
for field_name, field in block.items():
# unstructure arrays destined for list-based input
if field["type"] == "recarray" and field["reader"] != "readarray":
data[field_name] = unstructure_array(data[field_name])
return data


def unstructure_oc(value: Any) -> dict[str, Any]:
data = xattree.asdict(value)
for block_name, block in get_blocks(value.dfn).items():
if block_name == "perioddata":
# Unstructure all four arrays
save_head = unstructure_array(data.get("save_head", {}))
save_budget = unstructure_array(data.get("save_budget", {}))
print_head = unstructure_array(data.get("print_head", {}))
print_budget = unstructure_array(data.get("print_budget", {}))

# Collect all unique periods
all_periods = set() # type: ignore
for d in (save_head, save_budget, print_head, print_budget):
if isinstance(d, dict):
all_periods.update(d.keys())
all_periods = sorted(all_periods) # type: ignore

saverecord = {} # type: ignore
printrecord = {} # type: ignore
for kper in all_periods:
# Save head
if kper in save_head:
v = save_head[kper]
if kper not in saverecord:
saverecord[kper] = []
saverecord[kper].append({"action": "save", "type": "head", "ocsetting": v})
# Save budget
if kper in save_budget:
v = save_budget[kper]
if kper not in saverecord:
saverecord[kper] = []
saverecord[kper].append({"action": "save", "type": "budget", "ocsetting": v})
# Print head
if kper in print_head:
v = print_head[kper]
if kper not in printrecord:
printrecord[kper] = []
printrecord[kper].append({"action": "print", "type": "head", "ocsetting": v})
# Print budget
if kper in print_budget:
v = print_budget[kper]
if kper not in printrecord:
printrecord[kper] = []
printrecord[kper].append({"action": "print", "type": "budget", "ocsetting": v})

data["saverecord"] = saverecord
data["printrecord"] = printrecord
data["save"] = "save"
data["print"] = "print"
else:
for field_name, field in block.items():
# unstructure arrays destined for list-based input
if field["type"] == "recarray" and field["reader"] != "readarray":
data[field_name] = unstructure_array(data[field_name])
return data
3 changes: 3 additions & 0 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC
from collections.abc import MutableMapping
from pathlib import Path
from typing import ClassVar

from modflow_devtools.dfn import Dfn, Field
from xattree import xattree
Expand Down Expand Up @@ -33,6 +34,8 @@ class Component(ABC, MutableMapping):

filename: str = field(default=None)

dfn: ClassVar[Dfn]

@property
def path(self) -> Path:
return Path.cwd() / self.filename
Expand Down
22 changes: 1 addition & 21 deletions flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,12 @@
from collections.abc import Hashable, Mapping
from io import StringIO
from typing import Any

import numpy as np
import xarray as xr
from jinja2 import pass_context
from modflow_devtools.dfn import Dfn, Field
from modflow_devtools.dfn import Field
from numpy.typing import NDArray

from flopy4.mf6.spec import block_sort_key


def blocks(dfn: Dfn) -> dict:
"""
Get blocks from an MF6 input definition. Anything not an
explicitly defined key in the `Dfn` typed dict is a block.
"""
return dict(
sorted(
{k: v for k, v in dfn.items() if k not in Dfn.__annotations__}.items(),
key=block_sort_key,
)
)


def field_type(field: Field) -> str:
"""
Expand Down Expand Up @@ -105,7 +89,3 @@ def array2string(value: NDArray) -> str:
)
np.savetxt(buffer, value, fmt=format, delimiter=" ")
return buffer.getvalue().strip()


def is_list(value: Any) -> bool:
return isinstance(value, list)
73 changes: 73 additions & 0 deletions flopy4/mf6/gwf/oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from attrs import Converter, define
from modflow_devtools.dfn import Dfn, Field
from numpy.typing import NDArray
from xattree import xattree

Expand All @@ -11,6 +12,64 @@
from flopy4.mf6.spec import array, field
from flopy4.utils import to_path

_OCSETTING = Field(
name="ocsetting",
type="keystring",
reader="urword",
children={
"all": Field(
name="all",
type="keyword",
reader="urword",
),
"first": Field(
name="first",
type="keyword",
reader="urword",
),
"last": Field(
name="last",
type="keyword",
reader="urword",
),
"steps": Field(
name="steps",
type="integer",
reader="urword",
),
"frequency": Field(
name="frequency",
type="integer",
reader="urword",
),
},
)

_RTYPE = Field(
name="rtype",
type="string",
reader="urword",
)


def _oc_action_field(action: str) -> Field:
return Field(
name=f"{action}record",
type="recarray",
dims=("nper",),
block="perioddata",
reader="urword",
children={
action: Field(
name=action,
type="keyword",
reader="urword",
),
"rtype": _RTYPE,
"ocsetting": _OCSETTING,
},
)


@xattree
class Oc(Package):
Expand Down Expand Up @@ -56,25 +115,39 @@ class Period:
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
save_budget: Optional[NDArray[np.object_]] = array(
Steps,
block="perioddata",
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
print_head: Optional[NDArray[np.object_]] = array(
Steps,
block="perioddata",
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
print_budget: Optional[NDArray[np.object_]] = array(
Steps,
block="perioddata",
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)

@classmethod
def get_dfn(cls) -> Dfn:
"""Generate the component's MODFLOW 6 definition."""
dfn = super().get_dfn()
for field_name in list(dfn["perioddata"].keys()):
dfn["perioddata"].pop(field_name)
dfn["perioddata"]["saverecord"] = _oc_action_field("save")
dfn["perioddata"]["printrecord"] = _oc_action_field("print")
return dfn
21 changes: 20 additions & 1 deletion flopy4/mf6/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np
from attrs import NOTHING, Attribute
from modflow_devtools.dfn import Field, FieldType
from modflow_devtools.dfn import Dfn, Field, FieldType, Reader

from flopy4.spec import array as flopy_array
from flopy4.spec import coord as flopy_coord
Expand All @@ -32,6 +32,7 @@ def field(
if block:
metadata = metadata or {}
metadata["block"] = block
metadata["reader"] = "urword"
return flopy_field(
default=default,
validator=validator,
Expand All @@ -57,6 +58,7 @@ def dim(
if block:
metadata = metadata or {}
metadata["block"] = block
metadata["reader"] = "urword"
return flopy_dim(
scope=scope,
coord=coord,
Expand All @@ -80,6 +82,7 @@ def coord(
if block:
metadata = metadata or {}
metadata["block"] = block
metadata["reader"] = "readarray"
return flopy_coord(
scope=scope,
default=default,
Expand All @@ -99,11 +102,13 @@ def array(
eq=None,
metadata=None,
block: str | None = None,
reader: Reader = "readarray",
):
"""Define an array field."""
if block:
metadata = metadata or {}
metadata["block"] = block
metadata["reader"] = reader
return flopy_array(
cls=cls,
dims=dims,
Expand Down Expand Up @@ -227,4 +232,18 @@ def to_dfn_field(attribute: Attribute) -> Field:
children={k: to_dfn_field(v) for k, v in fields_dict(attribute.type)} # type: ignore
if attribute.metadata.get("kind", None) == "child" # type: ignore
else None, # type: ignore
reader=attribute.metadata.get("reader", "urword"),
)


def get_blocks(dfn: Dfn) -> dict:
"""
Get blocks from an MF6 input definition. Anything not an
explicitly defined key in the `Dfn` typed dict is a block.
"""
return dict(
sorted(
{k: v for k, v in dfn.items() if k not in Dfn.__annotations__}.items(),
key=block_sort_key,
)
)
Loading
Loading