Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion flopy4/mf6/attr_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ def update_maxbound(instance, attribute, new_value):
period_arrays = []
instance_fields = fields(instance.__class__)
for field in instance_fields:
if field.metadata and field.metadata.get("block") == "period" and "dims" in field.metadata:
if (
field.metadata
and field.metadata.get("block") == "period"
and field.metadata.get("xattree", {}).get("dims")
):
period_arrays.append(field.name)

maxbound_values = []
Expand Down
14 changes: 6 additions & 8 deletions flopy4/mf6/codec/writer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
trim_blocks=True,
lstrip_blocks=True,
)
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
_JINJA_ENV.filters["is_dataset"] = filters.is_dataset
_JINJA_ENV.filters["field_format"] = filters.field_format
_JINJA_ENV.filters["array_how"] = filters.array_how
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
_JINJA_ENV.filters["array2string"] = filters.array2string
_JINJA_ENV.filters["field_type"] = filters.field_type
_JINJA_ENV.filters["array2list"] = filters.array2list
_JINJA_ENV.filters["keystring2list"] = filters.keystring2list
_JINJA_ENV.filters["keystring2list_multifield"] = filters.keystring2list_multifield
_JINJA_ENV.filters["data2list"] = filters.data2list
_JINJA_ENV.filters["data2keystring"] = filters.data2keystring
_JINJA_TEMPLATE_NAME = "blocks.jinja"
_PRINT_OPTIONS = {
"precision": 4,
Expand All @@ -31,11 +29,11 @@
def dumps(data) -> str:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
return template.render(data=data)
return template.render(blocks=data)


def dump(data, path: str | PathLike) -> None:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
iterator = template.generate(data=data)
iterator = template.generate(blocks=data)
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
f.writelines(iterator)
6 changes: 1 addition & 5 deletions flopy4/mf6/codec/writer/templates/blocks.jinja
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
{% import 'macros.jinja' as macros with context %}
{% for block_name, block_value in (data|dict_blocks).items() %}
{% for block_name, block_value in blocks.items() %}
BEGIN {{ block_name.upper() }}
{% for field_name, field_value in block_value.items() if (field_value) is not none -%}
{{ macros.field(field_name, field_value) }}
{%- endfor %}
END {{ block_name.upper() }}

{% endfor %}

{% for block_name, block_value in (data|list_blocks).items() -%}
{{ macros.list(block_name, block_value, multi=block_name in ["period"]) }}
{%- endfor %}
67 changes: 20 additions & 47 deletions flopy4/mf6/codec/writer/templates/macros.jinja
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
{% macro field(name, value) %}
{% set type = value|field_type %}
{% if type in ['keyword', 'integer', 'double precision', 'string'] %}
{% set format = value|field_format %}
{% if format in ['keyword', 'integer', 'double precision', 'string'] %}
{{ scalar(name, value) }}
{% elif type == 'record' %}
{% elif format == 'record' %}
{{ record(name, value) }}
{% elif type == 'keystring' %}
{% elif format == 'keystring' %}
{{ keystring(name, value) }}
{% elif type == 'recarray' %}
{{ recarray(name, value, how=value|array_how) }}
{% elif format == 'array' %}
{{ array(name, value, how=value|array_how) }}
{% elif format == 'list' %}
{{ list(name, value) }}
{% endif %}
{% endmacro %}

{% macro scalar(name, value) %}
{% set type = value|field_type %}
{% if value is not none %}{{ name.upper() }}{% if type != 'keyword' %} {{ value }}{% endif %}{% endif %}
{% set format = value|field_format %}
{% if value is not none %}{{ name.upper() }}{% if format != 'keyword' %} {{ value }}{% endif %}{% endif %}
{% endmacro %}

{% macro keystring(name, value) %}
{% for item in value.values() -%}
{{ field(item) }}
{%- endfor %}
{% for option in (value|data2keystring) -%}
{{ record("", option) }}
{% endfor %}
{% endmacro %}

{% macro record(name, value) %}
{% if value is mapping %}
{% for item in value.values() -%}
{{ item.name.upper() }} {{ field(item) }}
{%- if value is mapping %}
{% for field_name, field_value in value.items() -%}
{{ field_name.upper() }} {{ field(field_value) }}
{%- endfor %}
{% else %}
{{ value|join(" ") }}
{% endif %}
{%- endif %}
{% endmacro %}

{% macro recarray(name, value, how="internal") %}
{% macro array(name, value, how="internal") %}
{{ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}

{% if how == "constant" %}
Expand All @@ -51,37 +53,8 @@ OPEN/CLOSE {{ value }}
{% endif %}
{% endmacro %}

{% macro list(name, value, multi=False) %}
{% if multi %}
{# iterate through time periods and combine fields #}
{% set field_arrays = {} %}
{% for field_name, field_value in value.items() %}
{% if field_value is not none %}
{% set _ = field_arrays.update({field_name: field_value}) %}
{% endif %}
{% endfor %}

{% set first_array = field_arrays.values() | list | first %}
{% if first_array is not none %}
{% set nper = first_array.shape[0] %}
{% for period_idx in range(nper) %}
BEGIN {{ name.upper() }} {{ period_idx + 1 }}
{% for row in (field_arrays|keystring2list_multifield(period_idx)) %}
{% macro list(name, value) %}
{% for row in (value|data2list) %}
{{ row|join(" ") }}
{% endfor %}
END {{ name.upper() }} {{ period_idx + 1 }}

{% endfor %}
{% endif %}
{% else %}
{% for field_name, field_value in value.items() %}
{% if field_value is not none %}
BEGIN {{ name.upper() }}
{% for row in (field_value|array2list) %}
{{ row|join(" ") }}
{% endfor %}
END {{ name.upper() }}
{% endif %}
{% endfor %}
{% endif %}
{% endmacro %}
77 changes: 60 additions & 17 deletions flopy4/mf6/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,29 @@
from pathlib import Path
from typing import Any

import xarray as xr
import xattree
from cattrs import Converter

from flopy4.mf6.component import Component
from flopy4.mf6.spec import get_blocks
from flopy4.mf6.spec import fields_dict, get_blocks


def _transform_path_to_record(field_name: str, path_value: Path) -> tuple:
"""Transform a Path field to its corresponding MF6 record format."""
# Infer record structure from field name
def _attach_field_metadata(
dataset: xr.Dataset, component_type: type, field_names: list[str]
) -> None:
field_metadata = {}
component_fields = fields_dict(component_type)
for field_name in field_names:
if field_name in component_fields:
field_metadata[field_name] = component_fields[field_name].metadata
dataset.attrs["field_metadata"] = field_metadata


def _path_to_record(field_name: str, path_value: Path) -> tuple:
if field_name.endswith("_file"):
base_name = field_name.replace("_file", "").upper()
return (base_name, "FILEOUT", str(path_value))

# Default fallback
return (field_name.upper(), "FILEOUT", str(path_value))


Expand All @@ -26,27 +34,62 @@ def unstructure_component(value: Component) -> dict[str, Any]:
blocks: dict[str, dict[str, Any]] = {}
for block_name, block in blockspec.items():
blocks[block_name] = {}
period_data = {}
period_blocks = {} # type: ignore
for field_name in block.keys():
field_value = data[field_name]

# Transform Path fields to record format
# convert:
# - paths to records
# - datetime to ISO format
# - auxiliary fields to tuples
# - xarray DataArrays with 'nper' dimension to kper-sliced datasets
# (and split the period data into separate kper-indexed blocks)
# - other values to their original form
if isinstance(field_value, Path) and field_value is not None:
field_value = _transform_path_to_record(field_name, field_value)

# Transform datetime fields to string format
blocks[block_name][field_name] = _path_to_record(field_name, field_value)
elif isinstance(field_value, datetime) and field_value is not None:
field_value = field_value.isoformat()

# Transform auxiliary fields to tuple for single-line record format
blocks[block_name][field_name] = field_value.isoformat()
elif (
field_name == "auxiliary"
and hasattr(field_value, "values")
and field_value is not None
):
field_value = tuple(field_value.values.tolist())
blocks[block_name][field_name] = tuple(field_value.values.tolist())
elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims:
has_spatial_dims = any(
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nnodes"]
)
if has_spatial_dims:
period_data[field_name] = {
kper: field_value.isel(nper=kper)
for kper in range(field_value.sizes["nper"])
}
else:
if block_name not in period_data:
period_data[block_name] = {}
period_data[block_name][field_name] = field_value # type: ignore
else:
if field_value is not None:
blocks[block_name][field_name] = field_value

if block_name in period_data and isinstance(period_data[block_name], dict):
dataset = xr.Dataset(period_data[block_name])
_attach_field_metadata(dataset, type(value), list(period_data[block_name].keys())) # type: ignore
blocks[block_name] = {block_name: dataset}
del period_data[block_name]

for arr_name, periods in period_data.items():
for kper, arr in periods.items():
if kper not in period_blocks:
period_blocks[kper] = {}
period_blocks[kper][arr_name] = arr

for kper, block in period_blocks.items():
dataset = xr.Dataset(block)
_attach_field_metadata(dataset, type(value), list(block.keys()))
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}

blocks[block_name][field_name] = field_value
return blocks
return {name: block for name, block in blocks.items() if block}


def _make_converter() -> Converter:
Expand Down
Loading
Loading