Skip to content

Commit b2bdd95

Browse files
authored
writer progress (#173)
use datasets instead of dicts and move the logic to slice period data arrays into blocks into the converter. makes the templates and macros much simpler. should not be expensive as we don't actually materialize the array yet, just slice it. miscellaneous other cleanup and fixes. output format looking pretty good for everything except model and simulation, still need to solve parent/child binding problem (coming next).
1 parent f95b38f commit b2bdd95

File tree

10 files changed

+272
-372
lines changed

10 files changed

+272
-372
lines changed

flopy4/mf6/attr_hooks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def update_maxbound(instance, attribute, new_value):
2626
period_arrays = []
2727
instance_fields = fields(instance.__class__)
2828
for field in instance_fields:
29-
if field.metadata and field.metadata.get("block") == "period" and "dims" in field.metadata:
29+
if (
30+
field.metadata
31+
and field.metadata.get("block") == "period"
32+
and field.metadata.get("xattree", {}).get("dims")
33+
):
3034
period_arrays.append(field.name)
3135

3236
maxbound_values = []

flopy4/mf6/codec/writer/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111
trim_blocks=True,
1212
lstrip_blocks=True,
1313
)
14-
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
15-
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
14+
_JINJA_ENV.filters["is_dataset"] = filters.is_dataset
15+
_JINJA_ENV.filters["field_format"] = filters.field_format
1616
_JINJA_ENV.filters["array_how"] = filters.array_how
1717
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
1818
_JINJA_ENV.filters["array2string"] = filters.array2string
19-
_JINJA_ENV.filters["field_type"] = filters.field_type
20-
_JINJA_ENV.filters["array2list"] = filters.array2list
21-
_JINJA_ENV.filters["keystring2list"] = filters.keystring2list
22-
_JINJA_ENV.filters["keystring2list_multifield"] = filters.keystring2list_multifield
19+
_JINJA_ENV.filters["data2list"] = filters.data2list
20+
_JINJA_ENV.filters["data2keystring"] = filters.data2keystring
2321
_JINJA_TEMPLATE_NAME = "blocks.jinja"
2422
_PRINT_OPTIONS = {
2523
"precision": 4,
@@ -31,11 +29,11 @@
3129
def dumps(data) -> str:
3230
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
3331
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
34-
return template.render(data=data)
32+
return template.render(blocks=data)
3533

3634

3735
def dump(data, path: str | PathLike) -> None:
3836
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
39-
iterator = template.generate(data=data)
37+
iterator = template.generate(blocks=data)
4038
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
4139
f.writelines(iterator)
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
{% import 'macros.jinja' as macros with context %}
2-
{% for block_name, block_value in (data|dict_blocks).items() %}
2+
{% for block_name, block_value in blocks.items() %}
33
BEGIN {{ block_name.upper() }}
44
{% for field_name, field_value in block_value.items() if (field_value) is not none -%}
55
{{ macros.field(field_name, field_value) }}
66
{%- endfor %}
77
END {{ block_name.upper() }}
88

99
{% endfor %}
10-
11-
{% for block_name, block_value in (data|list_blocks).items() -%}
12-
{{ macros.list(block_name, block_value, multi=block_name in ["period"]) }}
13-
{%- endfor %}
Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,40 @@
11
{% macro field(name, value) %}
2-
{% set type = value|field_type %}
3-
{% if type in ['keyword', 'integer', 'double precision', 'string'] %}
2+
{% set format = value|field_format %}
3+
{% if format in ['keyword', 'integer', 'double precision', 'string'] %}
44
{{ scalar(name, value) }}
5-
{% elif type == 'record' %}
5+
{% elif format == 'record' %}
66
{{ record(name, value) }}
7-
{% elif type == 'keystring' %}
7+
{% elif format == 'keystring' %}
88
{{ keystring(name, value) }}
9-
{% elif type == 'recarray' %}
10-
{{ recarray(name, value, how=value|array_how) }}
9+
{% elif format == 'array' %}
10+
{{ array(name, value, how=value|array_how) }}
11+
{% elif format == 'list' %}
12+
{{ list(name, value) }}
1113
{% endif %}
1214
{% endmacro %}
1315

1416
{% macro scalar(name, value) %}
15-
{% set type = value|field_type %}
16-
{% if value is not none %}{{ name.upper() }}{% if type != 'keyword' %} {{ value }}{% endif %}{% endif %}
17+
{% set format = value|field_format %}
18+
{% if value is not none %}{{ name.upper() }}{% if format != 'keyword' %} {{ value }}{% endif %}{% endif %}
1719
{% endmacro %}
1820

1921
{% macro keystring(name, value) %}
20-
{% for item in value.values() -%}
21-
{{ field(item) }}
22-
{%- endfor %}
22+
{% for option in (value|data2keystring) -%}
23+
{{ record("", option) }}
24+
{% endfor %}
2325
{% endmacro %}
2426

2527
{% macro record(name, value) %}
26-
{% if value is mapping %}
27-
{% for item in value.values() -%}
28-
{{ item.name.upper() }} {{ field(item) }}
28+
{%- if value is mapping %}
29+
{% for field_name, field_value in value.items() -%}
30+
{{ field_name.upper() }} {{ field(field_value) }}
2931
{%- endfor %}
3032
{% else %}
3133
{{ value|join(" ") }}
32-
{% endif %}
34+
{%- endif %}
3335
{% endmacro %}
3436

35-
{% macro recarray(name, value, how="internal") %}
37+
{% macro array(name, value, how="internal") %}
3638
{{ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}
3739

3840
{% if how == "constant" %}
@@ -51,37 +53,8 @@ OPEN/CLOSE {{ value }}
5153
{% endif %}
5254
{% endmacro %}
5355

54-
{% macro list(name, value, multi=False) %}
55-
{% if multi %}
56-
{# iterate through time periods and combine fields #}
57-
{% set field_arrays = {} %}
58-
{% for field_name, field_value in value.items() %}
59-
{% if field_value is not none %}
60-
{% set _ = field_arrays.update({field_name: field_value}) %}
61-
{% endif %}
62-
{% endfor %}
63-
64-
{% set first_array = field_arrays.values() | list | first %}
65-
{% if first_array is not none %}
66-
{% set nper = first_array.shape[0] %}
67-
{% for period_idx in range(nper) %}
68-
BEGIN {{ name.upper() }} {{ period_idx + 1 }}
69-
{% for row in (field_arrays|keystring2list_multifield(period_idx)) %}
56+
{% macro list(name, value) %}
57+
{% for row in (value|data2list) %}
7058
{{ row|join(" ") }}
7159
{% endfor %}
72-
END {{ name.upper() }} {{ period_idx + 1 }}
73-
74-
{% endfor %}
75-
{% endif %}
76-
{% else %}
77-
{% for field_name, field_value in value.items() %}
78-
{% if field_value is not none %}
79-
BEGIN {{ name.upper() }}
80-
{% for row in (field_value|array2list) %}
81-
{{ row|join(" ") }}
82-
{% endfor %}
83-
END {{ name.upper() }}
84-
{% endif %}
85-
{% endfor %}
86-
{% endif %}
8760
{% endmacro %}

flopy4/mf6/converter.py

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,29 @@
22
from pathlib import Path
33
from typing import Any
44

5+
import xarray as xr
56
import xattree
67
from cattrs import Converter
78

89
from flopy4.mf6.component import Component
9-
from flopy4.mf6.spec import get_blocks
10+
from flopy4.mf6.spec import fields_dict, get_blocks
1011

1112

12-
def _transform_path_to_record(field_name: str, path_value: Path) -> tuple:
13-
"""Transform a Path field to its corresponding MF6 record format."""
14-
# Infer record structure from field name
13+
def _attach_field_metadata(
14+
dataset: xr.Dataset, component_type: type, field_names: list[str]
15+
) -> None:
16+
field_metadata = {}
17+
component_fields = fields_dict(component_type)
18+
for field_name in field_names:
19+
if field_name in component_fields:
20+
field_metadata[field_name] = component_fields[field_name].metadata
21+
dataset.attrs["field_metadata"] = field_metadata
22+
23+
24+
def _path_to_record(field_name: str, path_value: Path) -> tuple:
1525
if field_name.endswith("_file"):
1626
base_name = field_name.replace("_file", "").upper()
1727
return (base_name, "FILEOUT", str(path_value))
18-
19-
# Default fallback
2028
return (field_name.upper(), "FILEOUT", str(path_value))
2129

2230

@@ -26,27 +34,62 @@ def unstructure_component(value: Component) -> dict[str, Any]:
2634
blocks: dict[str, dict[str, Any]] = {}
2735
for block_name, block in blockspec.items():
2836
blocks[block_name] = {}
37+
period_data = {}
38+
period_blocks = {} # type: ignore
2939
for field_name in block.keys():
3040
field_value = data[field_name]
31-
32-
# Transform Path fields to record format
41+
# convert:
42+
# - paths to records
43+
# - datetime to ISO format
44+
# - auxiliary fields to tuples
45+
# - xarray DataArrays with 'nper' dimension to kper-sliced datasets
46+
# (and split the period data into separate kper-indexed blocks)
47+
# - other values to their original form
3348
if isinstance(field_value, Path) and field_value is not None:
34-
field_value = _transform_path_to_record(field_name, field_value)
35-
36-
# Transform datetime fields to string format
49+
blocks[block_name][field_name] = _path_to_record(field_name, field_value)
3750
elif isinstance(field_value, datetime) and field_value is not None:
38-
field_value = field_value.isoformat()
39-
40-
# Transform auxiliary fields to tuple for single-line record format
51+
blocks[block_name][field_name] = field_value.isoformat()
4152
elif (
4253
field_name == "auxiliary"
4354
and hasattr(field_value, "values")
4455
and field_value is not None
4556
):
46-
field_value = tuple(field_value.values.tolist())
57+
blocks[block_name][field_name] = tuple(field_value.values.tolist())
58+
elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims:
59+
has_spatial_dims = any(
60+
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nnodes"]
61+
)
62+
if has_spatial_dims:
63+
period_data[field_name] = {
64+
kper: field_value.isel(nper=kper)
65+
for kper in range(field_value.sizes["nper"])
66+
}
67+
else:
68+
if block_name not in period_data:
69+
period_data[block_name] = {}
70+
period_data[block_name][field_name] = field_value # type: ignore
71+
else:
72+
if field_value is not None:
73+
blocks[block_name][field_name] = field_value
74+
75+
if block_name in period_data and isinstance(period_data[block_name], dict):
76+
dataset = xr.Dataset(period_data[block_name])
77+
_attach_field_metadata(dataset, type(value), list(period_data[block_name].keys())) # type: ignore
78+
blocks[block_name] = {block_name: dataset}
79+
del period_data[block_name]
80+
81+
for arr_name, periods in period_data.items():
82+
for kper, arr in periods.items():
83+
if kper not in period_blocks:
84+
period_blocks[kper] = {}
85+
period_blocks[kper][arr_name] = arr
86+
87+
for kper, block in period_blocks.items():
88+
dataset = xr.Dataset(block)
89+
_attach_field_metadata(dataset, type(value), list(block.keys()))
90+
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}
4791

48-
blocks[block_name][field_name] = field_value
49-
return blocks
92+
return {name: block for name, block in blocks.items() if block}
5093

5194

5295
def _make_converter() -> Converter:

0 commit comments

Comments
 (0)