Skip to content

Commit 0a36f75

Browse files
authored
refactor input file writing (#171)
separate writer and converter such that converter just turns proper class instances into dicts/etc and back. rewrite the eager/expensive converter approach for arrays destined for list-based format, now with lazy iteration in the template, as described in #109 (comment). misc other cleanup/fixes as well.
1 parent 6060f9b commit 0a36f75

File tree

23 files changed

+1926
-1274
lines changed

23 files changed

+1926
-1274
lines changed

flopy4/mf6/attr_hooks.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Attribute hooks for attrs on_setattr callbacks."""
2+
3+
import numpy as np
4+
from attrs import fields
5+
6+
from flopy4.mf6.constants import FILL_DNODATA
7+
8+
9+
def update_maxbound(instance, attribute, new_value):
10+
"""
11+
Generalized function to update maxbound when period block arrays change.
12+
13+
This function automatically finds all period block arrays in the instance
14+
and calculates maxbound based on the maximum number of non-default values
15+
across all arrays.
16+
17+
Args:
18+
instance: The package instance
19+
attribute: The attribute being set (from attrs on_setattr)
20+
new_value: The new value being set
21+
22+
Returns:
23+
The new_value (unchanged)
24+
"""
25+
26+
period_arrays = []
27+
instance_fields = fields(instance.__class__)
28+
for field in instance_fields:
29+
if field.metadata and field.metadata.get("block") == "period" and "dims" in field.metadata:
30+
period_arrays.append(field.name)
31+
32+
maxbound_values = []
33+
for array_name in period_arrays:
34+
if attribute and attribute.name == array_name:
35+
array_val = new_value
36+
else:
37+
array_val = getattr(instance, array_name, None)
38+
39+
if array_val is not None:
40+
array_data = (
41+
array_val if array_val.data.shape == array_val.shape else array_val.todense()
42+
)
43+
44+
if array_data.dtype.kind in ["U", "S"]: # String arrays
45+
non_default_count = len(np.where(array_data != "")[0])
46+
else: # Numeric arrays
47+
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])
48+
49+
maxbound_values.append(non_default_count)
50+
if maxbound_values:
51+
instance.maxbound = max(maxbound_values)
52+
53+
return new_value

flopy4/mf6/codec/__init__.py

Lines changed: 1 addition & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,6 @@
1-
import sys
2-
from os import PathLike
3-
from typing import Any
4-
5-
import numpy as np
6-
import xattree
7-
from cattrs import Converter
8-
from jinja2 import Environment, PackageLoader
9-
10-
from flopy4.mf6 import filters
11-
from flopy4.mf6.codec.converter import (
12-
unstructure_array,
13-
unstructure_chd,
14-
unstructure_component,
15-
unstructure_oc,
16-
unstructure_tdis,
17-
)
18-
19-
_JINJA_ENV = Environment(
20-
loader=PackageLoader("flopy4.mf6"),
21-
trim_blocks=True,
22-
lstrip_blocks=True,
23-
)
24-
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
25-
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
26-
_JINJA_ENV.filters["field_type"] = filters.field_type
27-
_JINJA_ENV.filters["field_value"] = filters.field_value
28-
_JINJA_ENV.filters["array_how"] = filters.array_how
29-
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
30-
_JINJA_ENV.filters["array2string"] = filters.array2string
31-
32-
_JINJA_TEMPLATE_NAME = "blocks.jinja"
33-
34-
_PRINT_OPTIONS = {
35-
"precision": 4,
36-
"linewidth": sys.maxsize,
37-
"threshold": sys.maxsize,
38-
}
39-
40-
41-
def _make_converter() -> Converter:
42-
# TODO: document what is converter's responsibility vs Jinja's
43-
# TODO: how can we make sure writing remains lazy for list input?
44-
# don't eagerly unstructure to dict, lazily access from the template?
45-
46-
from flopy4.mf6.component import Component
47-
from flopy4.mf6.gwf.chd import Chd
48-
from flopy4.mf6.gwf.oc import Oc
49-
from flopy4.mf6.tdis import Tdis
50-
51-
converter = Converter()
52-
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
53-
converter.register_unstructure_hook(Component, unstructure_component)
54-
converter.register_unstructure_hook(Tdis, unstructure_tdis)
55-
converter.register_unstructure_hook(Chd, unstructure_chd)
56-
converter.register_unstructure_hook(Oc, unstructure_oc)
57-
return converter
58-
59-
60-
_CONVERTER = _make_converter()
61-
62-
63-
def loads(data: str) -> Any:
64-
# TODO
65-
pass
66-
67-
68-
def load(path: str | PathLike) -> Any:
69-
# TODO
70-
pass
71-
72-
73-
def dumps(data) -> str:
74-
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
75-
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
76-
return template.render(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))
77-
78-
79-
def dump(data, path: str | PathLike) -> None:
80-
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
81-
iterator = template.generate(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))
82-
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
83-
f.writelines(iterator)
84-
1+
from flopy4.mf6.codec.writer import dump, dumps, load, loads
852

863
__all__ = [
87-
"unstructure_array",
884
"loads",
895
"load",
906
"dumps",

flopy4/mf6/codec/converter.py

Lines changed: 0 additions & 154 deletions
This file was deleted.

flopy4/mf6/codec/reader/__init__.py

Whitespace-only changes.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import sys
2+
from os import PathLike
3+
from typing import Any
4+
5+
import numpy as np
6+
from jinja2 import Environment, PackageLoader
7+
8+
from flopy4.mf6 import filters
9+
10+
_JINJA_ENV = Environment(
11+
loader=PackageLoader("flopy4.mf6.codec.writer"),
12+
trim_blocks=True,
13+
lstrip_blocks=True,
14+
)
15+
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
16+
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
17+
_JINJA_ENV.filters["array_how"] = filters.array_how
18+
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
19+
_JINJA_ENV.filters["array2string"] = filters.array2string
20+
_JINJA_ENV.filters["field_type"] = filters.field_type
21+
_JINJA_ENV.filters["array2list"] = filters.array2list
22+
_JINJA_ENV.filters["keystring2list"] = filters.keystring2list
23+
_JINJA_ENV.filters["keystring2list_multifield"] = filters.keystring2list_multifield
24+
_JINJA_TEMPLATE_NAME = "blocks.jinja"
25+
_PRINT_OPTIONS = {
26+
"precision": 4,
27+
"linewidth": sys.maxsize,
28+
"threshold": sys.maxsize,
29+
}
30+
31+
32+
def loads(data: str) -> Any:
33+
# TODO
34+
pass
35+
36+
37+
def load(path: str | PathLike) -> Any:
38+
# TODO
39+
pass
40+
41+
42+
def dumps(data) -> str:
43+
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
44+
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
45+
return template.render(data=data)
46+
47+
48+
def dump(data, path: str | PathLike) -> None:
49+
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
50+
iterator = template.generate(data=data)
51+
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
52+
f.writelines(iterator)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{% import 'macros.jinja' as macros with context %}
2+
{% for block_name, block_value in (data|dict_blocks).items() %}
3+
BEGIN {{ block_name.upper() }}
4+
{% for field_name, field_value in block_value.items() if (field_value) is not none -%}
5+
{{ macros.field(field_name, field_value) }}
6+
{%- endfor %}
7+
END {{ block_name.upper() }}
8+
9+
{% endfor %}
10+
11+
{% for block_name, block_value in (data|list_blocks).items() -%}
12+
{{ macros.list(block_name, block_value, multi=block_name in ["period"]) }}
13+
{%- endfor %}

0 commit comments

Comments
 (0)