diff --git a/flopy4/mf6/codec/__init__.py b/flopy4/mf6/codec/__init__.py index c5f3faba..34f14966 100644 --- a/flopy4/mf6/codec/__init__.py +++ b/flopy4/mf6/codec/__init__.py @@ -14,14 +14,14 @@ unstructure_component, unstructure_oc, ) -from flopy4.mf6.spec import get_blocks _JINJA_ENV = Environment( loader=PackageLoader("flopy4.mf6"), trim_blocks=True, lstrip_blocks=True, ) -_JINJA_ENV.filters["blocks"] = get_blocks +_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks +_JINJA_ENV.filters["list_blocks"] = filters.list_blocks _JINJA_ENV.filters["field_type"] = filters.field_type _JINJA_ENV.filters["field_value"] = filters.field_value _JINJA_ENV.filters["array_how"] = filters.array_how diff --git a/flopy4/mf6/codec/converter.py b/flopy4/mf6/codec/converter.py index 5ff838bf..b614d713 100644 --- a/flopy4/mf6/codec/converter.py +++ b/flopy4/mf6/codec/converter.py @@ -146,52 +146,33 @@ def unstructure_component(value: Component) -> dict[str, Any]: def unstructure_oc(value: Any) -> dict[str, Any]: data = xattree.asdict(value) for block_name, block in get_blocks(value.dfn).items(): - if block_name == "perioddata": - # Unstructure all four arrays - save_head = unstructure_array(data.get("save_head", {})) - save_budget = unstructure_array(data.get("save_budget", {})) - print_head = unstructure_array(data.get("print_head", {})) - print_budget = unstructure_array(data.get("print_budget", {})) - - # Collect all unique periods + if block_name == "period": + # Dynamically collect all recarray fields in perioddata block + array_fields = [] + for field_name, field in block.items(): + # Try to split field_name into action and kind, e.g. save_head -> ("save", "head") + action, rtype = field_name.split("_") + array_fields.append((action, rtype, field_name)) + + # Unstructure all arrays and collect all unique periods + arrays = {} all_periods = set() # type: ignore - for d in (save_head, save_budget, print_head, print_budget): - if isinstance(d, dict): - all_periods.update(d.keys()) + for action, rtype, field_name in array_fields: + arr = unstructure_array(data.get(field_name, {})) + arrays[(action, rtype)] = arr + if isinstance(arr, dict): + all_periods.update(arr.keys()) all_periods = sorted(all_periods) # type: ignore - saverecord = {} # type: ignore - printrecord = {} # type: ignore + perioddata = {} # type: ignore for kper in all_periods: - # Save head - if kper in save_head: - v = save_head[kper] - if kper not in saverecord: - saverecord[kper] = [] - saverecord[kper].append({"action": "save", "type": "head", "ocsetting": v}) - # Save budget - if kper in save_budget: - v = save_budget[kper] - if kper not in saverecord: - saverecord[kper] = [] - saverecord[kper].append({"action": "save", "type": "budget", "ocsetting": v}) - # Print head - if kper in print_head: - v = print_head[kper] - if kper not in printrecord: - printrecord[kper] = [] - printrecord[kper].append({"action": "print", "type": "head", "ocsetting": v}) - # Print budget - if kper in print_budget: - v = print_budget[kper] - if kper not in printrecord: - printrecord[kper] = [] - printrecord[kper].append({"action": "print", "type": "budget", "ocsetting": v}) - - data["saverecord"] = saverecord - data["printrecord"] = printrecord - data["save"] = "save" - data["print"] = "print" + for (action, rtype), arr in arrays.items(): + if kper in arr: + if kper not in perioddata: + perioddata[kper] = [] + perioddata[kper].append((action, rtype, arr[kper])) + + data["period"] = perioddata else: for field_name, field in block.items(): # unstructure arrays destined for list-based input diff --git a/flopy4/mf6/filters.py b/flopy4/mf6/filters.py index 96cc5daf..5d5f0739 100644 --- a/flopy4/mf6/filters.py +++ b/flopy4/mf6/filters.py @@ -4,9 +4,41 @@ import numpy as np import xarray as xr from jinja2 import pass_context -from modflow_devtools.dfn import Field +from modflow_devtools.dfn import Dfn, Field from numpy.typing import NDArray +from flopy4.mf6.spec import get_blocks + + +def _is_list_block(block: dict) -> bool: + return ( + len(block) == 1 + and (field := next(iter(block.values())))["type"] == "recarray" + and field["reader"] != "readarray" + ) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values())) + + +def dict_blocks(dfn: Dfn) -> dict: + """ + Get dictionary blocks from an MF6 input definition. A + dictionary block is a standard block which can contain + one or more fields, as opposed to a list block, which + may only contain one recarray field, using list input. + """ + x = { + block_name: block + for block_name, block in get_blocks(dfn).items() + if not _is_list_block(block) + } + return x + + +def list_blocks(dfn: Dfn) -> dict: + x = { + block_name: block for block_name, block in get_blocks(dfn).items() if _is_list_block(block) + } + return x + def field_type(field: Field) -> str: """ diff --git a/flopy4/mf6/gwf/oc.py b/flopy4/mf6/gwf/oc.py index ef600320..6bcae76a 100644 --- a/flopy4/mf6/gwf/oc.py +++ b/flopy4/mf6/gwf/oc.py @@ -3,7 +3,7 @@ import numpy as np from attrs import Converter, define -from modflow_devtools.dfn import Dfn, Field +from modflow_devtools.dfn import Field from numpy.typing import NDArray from xattree import xattree @@ -111,7 +111,7 @@ class Period: format: Optional[Format] = field(block="options", default=None, init=False) save_head: Optional[NDArray[np.object_]] = array( Steps, - block="perioddata", + block="period", default="all", dims=("nper",), converter=Converter(structure_array, takes_self=True, takes_field=True), @@ -119,7 +119,7 @@ class Period: ) save_budget: Optional[NDArray[np.object_]] = array( Steps, - block="perioddata", + block="period", default="all", dims=("nper",), converter=Converter(structure_array, takes_self=True, takes_field=True), @@ -127,7 +127,7 @@ class Period: ) print_head: Optional[NDArray[np.object_]] = array( Steps, - block="perioddata", + block="period", default="all", dims=("nper",), converter=Converter(structure_array, takes_self=True, takes_field=True), @@ -135,19 +135,20 @@ class Period: ) print_budget: Optional[NDArray[np.object_]] = array( Steps, - block="perioddata", + block="period", default="all", dims=("nper",), converter=Converter(structure_array, takes_self=True, takes_field=True), reader="urword", ) - @classmethod - def get_dfn(cls) -> Dfn: - """Generate the component's MODFLOW 6 definition.""" - dfn = super().get_dfn() - for field_name in list(dfn["perioddata"].keys()): - dfn["perioddata"].pop(field_name) - dfn["perioddata"]["saverecord"] = _oc_action_field("save") - dfn["perioddata"]["printrecord"] = _oc_action_field("print") - return dfn + # original DFN + # @classmethod + # def get_dfn(cls) -> Dfn: + # """Generate the component's MODFLOW 6 definition.""" + # dfn = super().get_dfn() + # for field_name in list(dfn["perioddata"].keys()): + # dfn["perioddata"].pop(field_name) + # dfn["perioddata"]["saverecord"] = _oc_action_field("save") + # dfn["perioddata"]["printrecord"] = _oc_action_field("print") + # return dfn diff --git a/flopy4/mf6/spec.py b/flopy4/mf6/spec.py index 63bcb40e..4a56fe1b 100644 --- a/flopy4/mf6/spec.py +++ b/flopy4/mf6/spec.py @@ -134,7 +134,8 @@ def block_sort_key(item: tuple[str, dict]) -> int: return 2 elif k == "packagedata": return 3 - elif k == "perioddata": + elif "period" in k: + # some packages have block "period", some have "perioddata" return 4 else: return 5 diff --git a/flopy4/mf6/templates/blocks.jinja b/flopy4/mf6/templates/blocks.jinja index 0acd2920..07392dac 100644 --- a/flopy4/mf6/templates/blocks.jinja +++ b/flopy4/mf6/templates/blocks.jinja @@ -1,5 +1,5 @@ {% import 'macros.jinja' as macros with context %} -{% for block_name, block_ in (dfn|blocks).items() %} +{% for block_name, block_ in (dfn|dict_blocks).items() %} BEGIN {{ block_name.upper() }} {% for field in block_.values() -%} {{ macros.field(field) }} @@ -7,3 +7,7 @@ BEGIN {{ block_name.upper() }} END {{ block_name.upper() }} {% endfor %} + +{% for block_name, block_ in (dfn|list_blocks).items() -%} +{{ macros.list(block_name, block_) }} +{%- endfor%} diff --git a/flopy4/mf6/templates/macros.jinja b/flopy4/mf6/templates/macros.jinja index baa3477f..86bf07ff 100644 --- a/flopy4/mf6/templates/macros.jinja +++ b/flopy4/mf6/templates/macros.jinja @@ -7,7 +7,7 @@ {% elif type == 'keystring' %} {{ keystring(f) }} {% elif type == 'recarray' %} -{{ recarray(f) }} +{{ recarray(f, how=f|array_how) }} {% endif %} {% endmacro %} @@ -29,16 +29,9 @@ {%- endfor %} {% endmacro %} -{% macro recarray(f) %} +{% macro recarray(f, how="internal") %} +{% set name = f.name %} {% set value = f|field_value %} -{% if f.reader != 'readarray' %} -{{ list(f) }} -{% else %} -{{ array(f.name, value, how=f|array_how) }} -{% endif %} -{% endmacro %} - -{% macro array(name, value, how="internal") %} {{ name.upper() }}{% if "layered" in how %} LAYERED{% endif %} {% if how == "constant" %} @@ -57,9 +50,28 @@ OPEN/CLOSE {{ value }} {% endif %} {% endmacro %} -{% macro list(f) %} -{{ f }} -{% for item in f.children.values() %} -{{ field(item) }} +{% macro list(block_name, block) %} +{# +from mf6's perspective, a list block (e.g. period data) +always has just one variable, whose elements might be +records or unions. where we spin those out into arrays +for each individual leaf field to fit the xarray data +model, we have to combine them back here. + +this macro receives the block definition. from that +it looks up the value of the one variable with the +same name as the block, which custom converter has +made sure exists in a sparse dict representation of +an array. we need to spin this out into a block for +each stress period. +#} +{% set dict = data[block_name] %} +{% for kper, value in dict.items() %} +BEGIN {{ block_name.upper() }} {{ kper }} +{% for line in value %} +{{ line|join(" ")|upper }} +{% endfor %} +END {{ block_name.upper() }} {{ kper }} + {% endfor %} {% endmacro %} diff --git a/test/test_codec.py b/test/test_codec.py index bc5b64fc..b365bde7 100644 --- a/test/test_codec.py +++ b/test/test_codec.py @@ -21,7 +21,6 @@ def test_dumps_ic(): assert result -@pytest.mark.xfail(reason="TODO period block unstructuring") def test_dumps_oc(): from flopy4.mf6.gwf import Oc