Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flopy4/mf6/codec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
unstructure_component,
unstructure_oc,
)
from flopy4.mf6.spec import get_blocks

_JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6"),
trim_blocks=True,
lstrip_blocks=True,
)
_JINJA_ENV.filters["blocks"] = get_blocks
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
_JINJA_ENV.filters["field_type"] = filters.field_type
_JINJA_ENV.filters["field_value"] = filters.field_value
_JINJA_ENV.filters["array_how"] = filters.array_how
Expand Down
65 changes: 23 additions & 42 deletions flopy4/mf6/codec/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,52 +146,33 @@ def unstructure_component(value: Component) -> dict[str, Any]:
def unstructure_oc(value: Any) -> dict[str, Any]:
data = xattree.asdict(value)
for block_name, block in get_blocks(value.dfn).items():
if block_name == "perioddata":
# Unstructure all four arrays
save_head = unstructure_array(data.get("save_head", {}))
save_budget = unstructure_array(data.get("save_budget", {}))
print_head = unstructure_array(data.get("print_head", {}))
print_budget = unstructure_array(data.get("print_budget", {}))

# Collect all unique periods
if block_name == "period":
# Dynamically collect all recarray fields in perioddata block
array_fields = []
for field_name, field in block.items():
# Try to split field_name into action and kind, e.g. save_head -> ("save", "head")
action, rtype = field_name.split("_")
array_fields.append((action, rtype, field_name))

# Unstructure all arrays and collect all unique periods
arrays = {}
all_periods = set() # type: ignore
for d in (save_head, save_budget, print_head, print_budget):
if isinstance(d, dict):
all_periods.update(d.keys())
for action, rtype, field_name in array_fields:
arr = unstructure_array(data.get(field_name, {}))
arrays[(action, rtype)] = arr
if isinstance(arr, dict):
all_periods.update(arr.keys())
all_periods = sorted(all_periods) # type: ignore

saverecord = {} # type: ignore
printrecord = {} # type: ignore
perioddata = {} # type: ignore
for kper in all_periods:
# Save head
if kper in save_head:
v = save_head[kper]
if kper not in saverecord:
saverecord[kper] = []
saverecord[kper].append({"action": "save", "type": "head", "ocsetting": v})
# Save budget
if kper in save_budget:
v = save_budget[kper]
if kper not in saverecord:
saverecord[kper] = []
saverecord[kper].append({"action": "save", "type": "budget", "ocsetting": v})
# Print head
if kper in print_head:
v = print_head[kper]
if kper not in printrecord:
printrecord[kper] = []
printrecord[kper].append({"action": "print", "type": "head", "ocsetting": v})
# Print budget
if kper in print_budget:
v = print_budget[kper]
if kper not in printrecord:
printrecord[kper] = []
printrecord[kper].append({"action": "print", "type": "budget", "ocsetting": v})

data["saverecord"] = saverecord
data["printrecord"] = printrecord
data["save"] = "save"
data["print"] = "print"
for (action, rtype), arr in arrays.items():
if kper in arr:
if kper not in perioddata:
perioddata[kper] = []
perioddata[kper].append((action, rtype, arr[kper]))

data["period"] = perioddata
else:
for field_name, field in block.items():
# unstructure arrays destined for list-based input
Expand Down
34 changes: 33 additions & 1 deletion flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,41 @@
import numpy as np
import xarray as xr
from jinja2 import pass_context
from modflow_devtools.dfn import Field
from modflow_devtools.dfn import Dfn, Field
from numpy.typing import NDArray

from flopy4.mf6.spec import get_blocks


def _is_list_block(block: dict) -> bool:
return (
len(block) == 1
and (field := next(iter(block.values())))["type"] == "recarray"
and field["reader"] != "readarray"
) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values()))


def dict_blocks(dfn: Dfn) -> dict:
"""
Get dictionary blocks from an MF6 input definition. A
dictionary block is a standard block which can contain
one or more fields, as opposed to a list block, which
may only contain one recarray field, using list input.
"""
x = {
block_name: block
for block_name, block in get_blocks(dfn).items()
if not _is_list_block(block)
}
return x


def list_blocks(dfn: Dfn) -> dict:
x = {
block_name: block for block_name, block in get_blocks(dfn).items() if _is_list_block(block)
}
return x


def field_type(field: Field) -> str:
"""
Expand Down
29 changes: 15 additions & 14 deletions flopy4/mf6/gwf/oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
from attrs import Converter, define
from modflow_devtools.dfn import Dfn, Field
from modflow_devtools.dfn import Field
from numpy.typing import NDArray
from xattree import xattree

Expand Down Expand Up @@ -111,43 +111,44 @@ class Period:
format: Optional[Format] = field(block="options", default=None, init=False)
save_head: Optional[NDArray[np.object_]] = array(
Steps,
block="perioddata",
block="period",
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
save_budget: Optional[NDArray[np.object_]] = array(
Steps,
block="perioddata",
block="period",
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
print_head: Optional[NDArray[np.object_]] = array(
Steps,
block="perioddata",
block="period",
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
print_budget: Optional[NDArray[np.object_]] = array(
Steps,
block="perioddata",
block="period",
default="all",
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)

@classmethod
def get_dfn(cls) -> Dfn:
"""Generate the component's MODFLOW 6 definition."""
dfn = super().get_dfn()
for field_name in list(dfn["perioddata"].keys()):
dfn["perioddata"].pop(field_name)
dfn["perioddata"]["saverecord"] = _oc_action_field("save")
dfn["perioddata"]["printrecord"] = _oc_action_field("print")
return dfn
# original DFN
# @classmethod
# def get_dfn(cls) -> Dfn:
# """Generate the component's MODFLOW 6 definition."""
# dfn = super().get_dfn()
# for field_name in list(dfn["perioddata"].keys()):
# dfn["perioddata"].pop(field_name)
# dfn["perioddata"]["saverecord"] = _oc_action_field("save")
# dfn["perioddata"]["printrecord"] = _oc_action_field("print")
# return dfn
3 changes: 2 additions & 1 deletion flopy4/mf6/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def block_sort_key(item: tuple[str, dict]) -> int:
return 2
elif k == "packagedata":
return 3
elif k == "perioddata":
elif "period" in k:
# some packages have block "period", some have "perioddata"
return 4
else:
return 5
Expand Down
6 changes: 5 additions & 1 deletion flopy4/mf6/templates/blocks.jinja
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
{% import 'macros.jinja' as macros with context %}
{% for block_name, block_ in (dfn|blocks).items() %}
{% for block_name, block_ in (dfn|dict_blocks).items() %}
BEGIN {{ block_name.upper() }}
{% for field in block_.values() -%}
{{ macros.field(field) }}
{%- endfor %}
END {{ block_name.upper() }}

{% endfor %}

{% for block_name, block_ in (dfn|list_blocks).items() -%}
{{ macros.list(block_name, block_) }}
{%- endfor%}
40 changes: 26 additions & 14 deletions flopy4/mf6/templates/macros.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
{% elif type == 'keystring' %}
{{ keystring(f) }}
{% elif type == 'recarray' %}
{{ recarray(f) }}
{{ recarray(f, how=f|array_how) }}
{% endif %}
{% endmacro %}

Expand All @@ -29,16 +29,9 @@
{%- endfor %}
{% endmacro %}

{% macro recarray(f) %}
{% macro recarray(f, how="internal") %}
{% set name = f.name %}
{% set value = f|field_value %}
{% if f.reader != 'readarray' %}
{{ list(f) }}
{% else %}
{{ array(f.name, value, how=f|array_how) }}
{% endif %}
{% endmacro %}

{% macro array(name, value, how="internal") %}
{{ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}

{% if how == "constant" %}
Expand All @@ -57,9 +50,28 @@ OPEN/CLOSE {{ value }}
{% endif %}
{% endmacro %}

{% macro list(f) %}
{{ f }}
{% for item in f.children.values() %}
{{ field(item) }}
{% macro list(block_name, block) %}
{#
from mf6's perspective, a list block (e.g. period data)
always has just one variable, whose elements might be
records or unions. where we spin those out into arrays
for each individual leaf field to fit the xarray data
model, we have to combine them back here.

this macro receives the block definition. from that
it looks up the value of the one variable with the
same name as the block, which custom converter has
made sure exists in a sparse dict representation of
an array. we need to spin this out into a block for
each stress period.
#}
{% set dict = data[block_name] %}
{% for kper, value in dict.items() %}
BEGIN {{ block_name.upper() }} {{ kper }}
{% for line in value %}
{{ line|join(" ")|upper }}
{% endfor %}
END {{ block_name.upper() }} {{ kper }}

{% endfor %}
{% endmacro %}
1 change: 0 additions & 1 deletion test/test_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def test_dumps_ic():
assert result


@pytest.mark.xfail(reason="TODO period block unstructuring")
def test_dumps_oc():
from flopy4.mf6.gwf import Oc

Expand Down
Loading