Skip to content

Commit d8232a6

Browse files
committed
list input working for perioddata. unstructured oc working.
1 parent 2ca78e0 commit d8232a6

File tree

8 files changed

+106
-76
lines changed

8 files changed

+106
-76
lines changed

flopy4/mf6/codec/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
unstructure_component,
1515
unstructure_oc,
1616
)
17-
from flopy4.mf6.spec import get_blocks
1817

1918
_JINJA_ENV = Environment(
2019
loader=PackageLoader("flopy4.mf6"),
2120
trim_blocks=True,
2221
lstrip_blocks=True,
2322
)
24-
_JINJA_ENV.filters["blocks"] = get_blocks
23+
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
24+
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
2525
_JINJA_ENV.filters["field_type"] = filters.field_type
2626
_JINJA_ENV.filters["field_value"] = filters.field_value
2727
_JINJA_ENV.filters["array_how"] = filters.array_how

flopy4/mf6/codec/converter.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -146,52 +146,33 @@ def unstructure_component(value: Component) -> dict[str, Any]:
146146
def unstructure_oc(value: Any) -> dict[str, Any]:
147147
data = xattree.asdict(value)
148148
for block_name, block in get_blocks(value.dfn).items():
149-
if block_name == "perioddata":
150-
# Unstructure all four arrays
151-
save_head = unstructure_array(data.get("save_head", {}))
152-
save_budget = unstructure_array(data.get("save_budget", {}))
153-
print_head = unstructure_array(data.get("print_head", {}))
154-
print_budget = unstructure_array(data.get("print_budget", {}))
155-
156-
# Collect all unique periods
149+
if block_name == "period":
150+
# Dynamically collect all recarray fields in perioddata block
151+
array_fields = []
152+
for field_name, field in block.items():
153+
# Try to split field_name into action and kind, e.g. save_head -> ("save", "head")
154+
action, rtype = field_name.split("_")
155+
array_fields.append((action, rtype, field_name))
156+
157+
# Unstructure all arrays and collect all unique periods
158+
arrays = {}
157159
all_periods = set() # type: ignore
158-
for d in (save_head, save_budget, print_head, print_budget):
159-
if isinstance(d, dict):
160-
all_periods.update(d.keys())
160+
for action, rtype, field_name in array_fields:
161+
arr = unstructure_array(data.get(field_name, {}))
162+
arrays[(action, rtype)] = arr
163+
if isinstance(arr, dict):
164+
all_periods.update(arr.keys())
161165
all_periods = sorted(all_periods) # type: ignore
162166

163-
saverecord = {} # type: ignore
164-
printrecord = {} # type: ignore
167+
perioddata = {} # type: ignore
165168
for kper in all_periods:
166-
# Save head
167-
if kper in save_head:
168-
v = save_head[kper]
169-
if kper not in saverecord:
170-
saverecord[kper] = []
171-
saverecord[kper].append({"action": "save", "type": "head", "ocsetting": v})
172-
# Save budget
173-
if kper in save_budget:
174-
v = save_budget[kper]
175-
if kper not in saverecord:
176-
saverecord[kper] = []
177-
saverecord[kper].append({"action": "save", "type": "budget", "ocsetting": v})
178-
# Print head
179-
if kper in print_head:
180-
v = print_head[kper]
181-
if kper not in printrecord:
182-
printrecord[kper] = []
183-
printrecord[kper].append({"action": "print", "type": "head", "ocsetting": v})
184-
# Print budget
185-
if kper in print_budget:
186-
v = print_budget[kper]
187-
if kper not in printrecord:
188-
printrecord[kper] = []
189-
printrecord[kper].append({"action": "print", "type": "budget", "ocsetting": v})
190-
191-
data["saverecord"] = saverecord
192-
data["printrecord"] = printrecord
193-
data["save"] = "save"
194-
data["print"] = "print"
169+
for (action, rtype), arr in arrays.items():
170+
if kper in arr:
171+
if kper not in perioddata:
172+
perioddata[kper] = []
173+
perioddata[kper].append((action, rtype, arr[kper]))
174+
175+
data["period"] = perioddata
195176
else:
196177
for field_name, field in block.items():
197178
# unstructure arrays destined for list-based input

flopy4/mf6/filters.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,41 @@
44
import numpy as np
55
import xarray as xr
66
from jinja2 import pass_context
7-
from modflow_devtools.dfn import Field
7+
from modflow_devtools.dfn import Dfn, Field
88
from numpy.typing import NDArray
99

10+
from flopy4.mf6.spec import get_blocks
11+
12+
13+
def _is_list_block(block: dict) -> bool:
14+
return (
15+
len(block) == 1
16+
and (field := next(iter(block.values())))["type"] == "recarray"
17+
and field["reader"] != "readarray"
18+
) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values()))
19+
20+
21+
def dict_blocks(dfn: Dfn) -> dict:
22+
"""
23+
Get dictionary blocks from an MF6 input definition. A
24+
dictionary block is a standard block which can contain
25+
one or more fields, as opposed to a list block, which
26+
may only contain one recarray field, using list input.
27+
"""
28+
x = {
29+
block_name: block
30+
for block_name, block in get_blocks(dfn).items()
31+
if not _is_list_block(block)
32+
}
33+
return x
34+
35+
36+
def list_blocks(dfn: Dfn) -> dict:
37+
x = {
38+
block_name: block for block_name, block in get_blocks(dfn).items() if _is_list_block(block)
39+
}
40+
return x
41+
1042

1143
def field_type(field: Field) -> str:
1244
"""

flopy4/mf6/gwf/oc.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55
from attrs import Converter, define
6-
from modflow_devtools.dfn import Dfn, Field
6+
from modflow_devtools.dfn import Field
77
from numpy.typing import NDArray
88
from xattree import xattree
99

@@ -111,43 +111,44 @@ class Period:
111111
format: Optional[Format] = field(block="options", default=None, init=False)
112112
save_head: Optional[NDArray[np.object_]] = array(
113113
Steps,
114-
block="perioddata",
114+
block="period",
115115
default="all",
116116
dims=("nper",),
117117
converter=Converter(structure_array, takes_self=True, takes_field=True),
118118
reader="urword",
119119
)
120120
save_budget: Optional[NDArray[np.object_]] = array(
121121
Steps,
122-
block="perioddata",
122+
block="period",
123123
default="all",
124124
dims=("nper",),
125125
converter=Converter(structure_array, takes_self=True, takes_field=True),
126126
reader="urword",
127127
)
128128
print_head: Optional[NDArray[np.object_]] = array(
129129
Steps,
130-
block="perioddata",
130+
block="period",
131131
default="all",
132132
dims=("nper",),
133133
converter=Converter(structure_array, takes_self=True, takes_field=True),
134134
reader="urword",
135135
)
136136
print_budget: Optional[NDArray[np.object_]] = array(
137137
Steps,
138-
block="perioddata",
138+
block="period",
139139
default="all",
140140
dims=("nper",),
141141
converter=Converter(structure_array, takes_self=True, takes_field=True),
142142
reader="urword",
143143
)
144144

145-
@classmethod
146-
def get_dfn(cls) -> Dfn:
147-
"""Generate the component's MODFLOW 6 definition."""
148-
dfn = super().get_dfn()
149-
for field_name in list(dfn["perioddata"].keys()):
150-
dfn["perioddata"].pop(field_name)
151-
dfn["perioddata"]["saverecord"] = _oc_action_field("save")
152-
dfn["perioddata"]["printrecord"] = _oc_action_field("print")
153-
return dfn
145+
# original DFN
146+
# @classmethod
147+
# def get_dfn(cls) -> Dfn:
148+
# """Generate the component's MODFLOW 6 definition."""
149+
# dfn = super().get_dfn()
150+
# for field_name in list(dfn["perioddata"].keys()):
151+
# dfn["perioddata"].pop(field_name)
152+
# dfn["perioddata"]["saverecord"] = _oc_action_field("save")
153+
# dfn["perioddata"]["printrecord"] = _oc_action_field("print")
154+
# return dfn

flopy4/mf6/spec.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def block_sort_key(item: tuple[str, dict]) -> int:
134134
return 2
135135
elif k == "packagedata":
136136
return 3
137-
elif k == "perioddata":
137+
elif "period" in k:
138+
# some packages have block "period", some have "perioddata"
138139
return 4
139140
else:
140141
return 5

flopy4/mf6/templates/blocks.jinja

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
{% import 'macros.jinja' as macros with context %}
2-
{% for block_name, block_ in (dfn|blocks).items() %}
2+
{% for block_name, block_ in (dfn|dict_blocks).items() %}
33
BEGIN {{ block_name.upper() }}
44
{% for field in block_.values() -%}
55
{{ macros.field(field) }}
66
{%- endfor %}
77
END {{ block_name.upper() }}
88

99
{% endfor %}
10+
11+
{% for block_name, block_ in (dfn|list_blocks).items() -%}
12+
{{ macros.list(block_name, block_) }}
13+
{%- endfor%}

flopy4/mf6/templates/macros.jinja

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
{% elif type == 'keystring' %}
88
{{ keystring(f) }}
99
{% elif type == 'recarray' %}
10-
{{ recarray(f) }}
10+
{{ recarray(f, how=f|array_how) }}
1111
{% endif %}
1212
{% endmacro %}
1313

@@ -29,16 +29,9 @@
2929
{%- endfor %}
3030
{% endmacro %}
3131

32-
{% macro recarray(f) %}
32+
{% macro recarray(f, how="internal") %}
33+
{% set name = f.name %}
3334
{% set value = f|field_value %}
34-
{% if f.reader != 'readarray' %}
35-
{{ list(f) }}
36-
{% else %}
37-
{{ array(f.name, value, how=f|array_how) }}
38-
{% endif %}
39-
{% endmacro %}
40-
41-
{% macro array(name, value, how="internal") %}
4235
{{ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}
4336

4437
{% if how == "constant" %}
@@ -57,9 +50,28 @@ OPEN/CLOSE {{ value }}
5750
{% endif %}
5851
{% endmacro %}
5952

60-
{% macro list(f) %}
61-
{{ f }}
62-
{% for item in f.children.values() %}
63-
{{ field(item) }}
53+
{% macro list(block_name, block) %}
54+
{#
55+
from mf6's perspective, a list block (e.g. period data)
56+
always has just one variable, whose elements might be
57+
records or unions. where we spin those out into arrays
58+
for each individual leaf field to fit the xarray data
59+
model, we have to combine them back here.
60+
61+
this macro receives the block definition. from that
62+
it looks up the value of the one variable with the
63+
same name as the block, which custom converter has
64+
made sure exists in a sparse dict representation of
65+
an array. we need to spin this out into a block for
66+
each stress period.
67+
#}
68+
{% set dict = data[block_name] %}
69+
{% for kper, value in dict.items() %}
70+
BEGIN {{ block_name.upper() }} {{ kper }}
71+
{% for line in value %}
72+
{{ line|join(" ")|upper }}
73+
{% endfor %}
74+
END {{ block_name.upper() }} {{ kper }}
75+
6476
{% endfor %}
6577
{% endmacro %}

test/test_codec.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def test_dumps_ic():
2121
assert result
2222

2323

24-
@pytest.mark.xfail(reason="TODO period block unstructuring")
2524
def test_dumps_oc():
2625
from flopy4.mf6.gwf import Oc
2726

0 commit comments

Comments
 (0)