Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flopy4/mf6/codec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
unstructure_array,
unstructure_component,
unstructure_oc,
unstructure_tdis,
)

_JINJA_ENV = Environment(
Expand Down Expand Up @@ -40,10 +41,12 @@
def _make_converter() -> Converter:
from flopy4.mf6.component import Component
from flopy4.mf6.gwf.oc import Oc
from flopy4.mf6.tdis import Tdis

converter = Converter()
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
converter.register_unstructure_hook(Component, unstructure_component)
converter.register_unstructure_hook(Tdis, unstructure_tdis)
converter.register_unstructure_hook(Oc, unstructure_oc)
return converter

Expand Down
41 changes: 35 additions & 6 deletions flopy4/mf6/codec/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,39 @@ def unstructure_component(value: Component) -> dict[str, Any]:
return data


def unstructure_tdis(value: Any) -> dict[str, Any]:
data = xattree.asdict(value)
blocks = get_blocks(value.dfn)
for block_name, block in blocks.items():
if block_name == "perioddata":
array_fields = list(block.keys())

# Unstructure all arrays and collect all unique periods
arrays = {}
periods = set() # type: ignore
for field_name in array_fields:
arr = unstructure_array(data.get(field_name, {}))
arrays[field_name] = arr
periods.update(arr.keys())
periods = sorted(periods) # type: ignore

perioddata = {} # type: ignore
for kper in periods:
line = []
for arr in arrays.values():
if kper not in perioddata:
perioddata[kper] = [] # type: ignore
line.append(arr[kper])
perioddata[kper] = tuple(line)

data["perioddata"] = perioddata
return data


def unstructure_oc(value: Any) -> dict[str, Any]:
data = xattree.asdict(value)
for block_name, block in get_blocks(value.dfn).items():
blocks = get_blocks(value.dfn)
for block_name, block in blocks.items():
if block_name == "period":
# Dynamically collect all recarray fields in perioddata block
array_fields = []
Expand All @@ -156,16 +186,15 @@ def unstructure_oc(value: Any) -> dict[str, Any]:

# Unstructure all arrays and collect all unique periods
arrays = {}
all_periods = set() # type: ignore
periods = set() # type: ignore
for action, rtype, field_name in array_fields:
arr = unstructure_array(data.get(field_name, {}))
arrays[(action, rtype)] = arr
if isinstance(arr, dict):
all_periods.update(arr.keys())
all_periods = sorted(all_periods) # type: ignore
periods.update(arr.keys())
periods = sorted(periods) # type: ignore

perioddata = {} # type: ignore
for kper in all_periods:
for kper in periods:
for (action, rtype), arr in arrays.items():
if kper in arr:
if kper not in perioddata:
Expand Down
4 changes: 2 additions & 2 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ def _preio(self, format: str) -> None:
self.filename = self.default_filename()

def load(self, format: str) -> None:
"""Load the component from an input file."""
"""Load the component and any children."""
self._preio(format=format)
self._load(format=format)
for child in self.children.values(): # type: ignore
child.load(format=format)

def write(self, format: str) -> None:
"""Write the component to an input file."""
"""Write the component and any children."""
self._preio(format=format)
self._write(format=format)
for child in self.children.values(): # type: ignore
Expand Down
6 changes: 2 additions & 4 deletions flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,17 @@ def dict_blocks(dfn: Dfn) -> dict:
one or more fields, as opposed to a list block, which
may only contain one recarray field, using list input.
"""
x = {
return {
block_name: block
for block_name, block in get_blocks(dfn).items()
if not _is_list_block(block)
}
return x


def list_blocks(dfn: Dfn) -> dict:
x = {
return {
block_name: block for block_name, block in get_blocks(dfn).items() if _is_list_block(block)
}
return x


def field_type(field: Field) -> str:
Expand Down
8 changes: 4 additions & 4 deletions flopy4/mf6/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def run(self, exe: str | PathLike = "mf6", verbose: bool = False) -> None:
)

def load(self, format="ascii"):
"""Load the simulation in the specified format."""
"""Load the simulation."""
with cd(self.workspace):
super().load(format)
super().load(format=format)

def write(self, format="ascii"):
"""Write the simulation in the specified format."""
"""Write the simulation."""
with cd(self.workspace):
super().write(format)
super().write(format=format)
3 changes: 3 additions & 0 deletions flopy4/mf6/tdis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,21 @@ class PeriodData:
default=1.0,
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
nstp: NDArray[np.integer] = array(
block="perioddata",
default=1,
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)
tsmult: NDArray[np.floating] = array(
block="perioddata",
default=1.0,
dims=("nper",),
converter=Converter(structure_array, takes_self=True, takes_field=True),
reader="urword",
)

def to_time(self) -> ModelTime:
Expand Down
2 changes: 1 addition & 1 deletion flopy4/mf6/templates/blocks.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ END {{ block_name.upper() }}
{% endfor %}

{% for block_name, block_ in (dfn|list_blocks).items() -%}
{{ macros.list(block_name, block_) }}
{{ macros.list(block_name, block_, stress=block_name == "period") }}
{%- endfor%}
14 changes: 11 additions & 3 deletions flopy4/mf6/templates/macros.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ OPEN/CLOSE {{ value }}
{% endif %}
{% endmacro %}

{% macro list(block_name, block) %}
{% macro list(block_name, block, stress=False) %}
{#
from mf6's perspective, a list block (e.g. period data)
always has just one variable, whose elements might be
Expand All @@ -65,13 +65,21 @@ made sure exists in a sparse dict representation of
an array. we need to expand this into a block for
each stress period.
#}
{% set dict = data[block_name] %}
{% for kper, value in dict.items() %}
{% set d = data[block_name] %}
{% if stress %}
{% for kper, value in d.items() %}
BEGIN {{ block_name.upper() }} {{ kper }}
{% for line in value %}
{{ line|join(" ")|upper }}
{% endfor %}
END {{ block_name.upper() }} {{ kper }}

{% endfor %}
{% else %}
BEGIN {{ block_name.upper() }}
{% for line in d.values() %}
{{ line|join(" ")|upper }}
{% endfor %}
END {{ block_name.upper() }}
{% endif %}
{% endmacro %}
Loading