Skip to content

Commit 2f59250

Browse files
mjrenowpbonelli
authored andcommitted
add multi period test with storage
1 parent 68da84e commit 2f59250

File tree

14 files changed

+382
-216
lines changed

14 files changed

+382
-216
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
}
1414
},
1515
"mypy-type-checker.importStrategy": "fromEnvironment",
16+
"python.analysis.typeCheckingMode": "off",
1617
"files.exclude": {
1718
"**/.git": true,
1819
"**/.svn": true,

flopy4/mf6/adapters.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ def plottable(self):
264264

265265
@property
266266
def has_stress_period_data(self):
267-
# TODO oc returns true? is stress package?
268267
return "nper" in self._data.dims
269268

270269
def check(self, f=None, verbose=True, level=1, checktype=None):

flopy4/mf6/codec/writer/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
)
1414
_JINJA_ENV.filters["field_type"] = filters.field_type
1515
_JINJA_ENV.filters["array_how"] = filters.array_how
16-
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
16+
_JINJA_ENV.filters["array2chunks"] = filters.array2chunks
1717
_JINJA_ENV.filters["array2string"] = filters.array2string
18+
_JINJA_ENV.filters["array2const"] = filters.array2const
1819
_JINJA_ENV.filters["data2list"] = filters.data2list
1920
_JINJA_TEMPLATE_NAME = "blocks.jinja"
2021
_PRINT_OPTIONS = {

flopy4/mf6/codec/writer/filters.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import xarray as xr
77
from modflow_devtools.dfn.schema.v2 import FieldType
88
from numpy.typing import NDArray
9+
from xattree import Scalar
910

1011
from flopy4.mf6.constants import FILL_DNODATA
1112

@@ -36,10 +37,12 @@ def array_how(value: xr.DataArray) -> str:
3637
# TODO
3738
# - detect constant arrays?
3839
# - above certain size, use external?
40+
if value.max() == value.min():
41+
return "constant"
3942
return "internal"
4043

4144

42-
def array_chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = None):
45+
def array2chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = None):
4346
"""
4447
Yield chunks from a dask-backed array of up to 3 dimensions.
4548
If it's not already chunked, split it into chunks of the
@@ -128,6 +131,13 @@ def nonempty(value: NDArray | xr.DataArray) -> NDArray:
128131
return mask
129132

130133

134+
def array2const(value: xr.DataArray) -> Scalar:
135+
if np.issubdtype(value.dtype, np.integer):
136+
return value.max().item()
137+
if np.issubdtype(value.dtype, np.floating):
138+
return f"{value.max().item():.8f}"
139+
140+
131141
def data2list(value: list | tuple | dict | xr.Dataset | xr.DataArray):
132142
"""
133143
Yield records (tuples) from data in a `list`, `dict`, `DataArray` or `Dataset`.

flopy4/mf6/codec/writer/templates/macros.jinja

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
{% endmacro %}
2020

2121
{% macro record(value) %}
22-
{{ inset ~ value|join(" ") -}}
22+
{{ inset ~ value|join(" ")|upper -}}
2323
{% endmacro %}
2424

2525
{% macro list(name, value) %}
@@ -32,14 +32,14 @@
3232
{{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}
3333

3434
{% if how == "constant" %}
35-
CONSTANT {{ value.item() }}
35+
{{ inset }}CONSTANT {{ value|array2const -}}
3636
{% elif how == "layered constant" %}
3737
{% for layer in value -%}
38-
CONSTANT {{ layer.item() }}
38+
{{ inset }}CONSTANT {{ layer|array2const -}}
3939
{%- endfor %}
4040
{% elif how == "internal" %}
4141
INTERNAL
42-
{% for chunk in value|array_chunks -%}
42+
{% for chunk in value|array2chunks -%}
4343
{{ (2 * inset) ~ chunk|array2string }}
4444
{%- endfor %}
4545
{% elif how == "external" %}

flopy4/mf6/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# TODO use https://environ-config.readthedocs.io/en/stable/?
22

3-
SPARSE_THRESHOLD = 1000
3+
SPARSE_THRESHOLD = 100000

flopy4/mf6/converter.py

Lines changed: 149 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -88,48 +88,98 @@ def _path_to_tuple(field_name: str, path_value: Path) -> tuple:
8888
return (field_name.upper(), "FILEOUT", str(path_value))
8989

9090

91-
def unstructure_component(value: Component) -> dict[str, Any]:
92-
blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore
93-
blocks: dict[str, dict[str, Any]] = {}
94-
xatspec = xattree.get_xatspec(type(value))
91+
def _user_dims(value, field_value):
92+
# terrible hack to convert flat nodes dimension to 3d structured dims.
93+
# long term solution for this is to use a custom xarray index. filters
94+
# should then have access to all dimensions needed.
95+
dims_ = set(field_value.dims).copy()
96+
parent = value.parent # type: ignore
97+
if parent is None:
98+
# TODO for standalone packages
99+
return field_value
100+
101+
if "nper" in dims_:
102+
dims_.remove("nper")
103+
shape = (
104+
field_value.sizes["nper"],
105+
parent.dims["nlay"],
106+
parent.dims["nrow"],
107+
parent.dims["ncol"],
108+
)
109+
dims = ("nper", "nlay", "nrow", "ncol")
110+
coords = {
111+
"nper": field_value.coords["nper"],
112+
"nlay": range(parent.dims["nlay"]),
113+
"nrow": range(parent.dims["nrow"]),
114+
"ncol": range(parent.dims["ncol"]),
115+
}
116+
else:
117+
shape = (
118+
parent.dims["nlay"],
119+
parent.dims["nrow"],
120+
parent.dims["ncol"],
121+
)
122+
dims = ("nlay", "nrow", "ncol")
123+
coords = {
124+
"nlay": range(parent.dims["nlay"]),
125+
"nrow": range(parent.dims["nrow"]),
126+
"ncol": range(parent.dims["ncol"]),
127+
}
128+
129+
if dims_ == {"nodes"}:
130+
field_value = xr.DataArray(
131+
field_value.data.reshape(shape),
132+
dims=dims,
133+
coords=coords,
134+
)
95135

96-
# Handle child component bindings before converting to dict
97-
if isinstance(value, Context):
98-
for field_name, child_spec in xatspec.children.items():
99-
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
100-
block_name = child_spec.metadata["block"] # type: ignore
101-
field_value = getattr(value, field_name, None)
136+
return field_value
102137

138+
139+
def _get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple]]]:
140+
if not isinstance(value, Context):
141+
return {}
142+
143+
blocks = {}
144+
for name, spec in xattree.get_xatspec(type(value)).children.items():
145+
block_name = spec.metadata["block"]
146+
match child := getattr(value, name):
147+
case None:
148+
continue
149+
case Component():
103150
if block_name not in blocks:
104151
blocks[block_name] = {}
152+
blocks[block_name][name] = [_Binding.from_component(child).to_tuple()]
153+
case MutableMapping():
154+
if block_name not in blocks:
155+
blocks[block_name] = {}
156+
blocks[block_name][name] = [
157+
_Binding.from_component(comp).to_tuple()
158+
for comp in child.values()
159+
if comp is not None
160+
]
161+
case Iterable():
162+
if block_name not in blocks:
163+
blocks[block_name] = {}
164+
blocks[block_name][name] = [
165+
_Binding.from_component(comp).to_tuple()
166+
for comp in child
167+
if comp is not None
168+
]
169+
case _:
170+
raise ValueError(f"Unexpected child type: {type(child)}")
105171

106-
if isinstance(field_value, Component):
107-
components = [_Binding.from_component(field_value).to_tuple()]
108-
elif isinstance(field_value, MutableMapping):
109-
components = [
110-
_Binding.from_component(comp).to_tuple()
111-
for comp in field_value.values()
112-
if comp is not None
113-
]
114-
elif isinstance(field_value, Iterable):
115-
components = [
116-
_Binding.from_component(comp).to_tuple()
117-
for comp in field_value
118-
if comp is not None
119-
]
120-
else:
121-
continue
122-
123-
if components:
124-
blocks[block_name][field_name] = components
172+
return blocks
125173

174+
def unstructure_component(value: Component) -> dict[str, Any]:
175+
dfnspec = value.dfn
176+
xatspec = xattree.get_xatspec(type(value))
177+
blocks: dict[str, dict[str, Any]] = _get_binding_blocks(value)
126178
data = xattree.asdict(value)
127179

128-
for block_name, block in blockspec.items():
180+
for block_name, block in dfnspec.blocks.items():
129181
if block_name not in blocks:
130182
blocks[block_name] = {}
131-
period_data = {}
132-
period_blocks = {} # type: ignore
133183

134184
for field_name in block.keys():
135185
# Skip child components that have been processed as bindings
@@ -141,105 +191,98 @@ def unstructure_component(value: Component) -> dict[str, Any]:
141191

142192
field_value = data[field_name]
143193
# convert:
194+
# - bools to keywords
144195
# - paths to records
145196
# - datetime to ISO format
146197
# - auxiliary fields to tuples
147198
# - xarray DataArrays with 'nper' dimension to kper-sliced datasets
148199
# (and split the period data into separate kper-indexed blocks)
149200
# - other values to their original form
201+
if isinstance(field_value, bool):
202+
if field_value: # only write if true
203+
blocks[block_name][field_name] = field_value
150204
if isinstance(field_value, Path):
151205
rec = _path_to_tuple(field_name, field_value)
152-
# name may have changed e.g dropping '_file' suffix
153-
blocks[block_name][rec[0]] = rec
206+
field_name = rec[0] # '_file' suffix dropped
207+
blocks[block_name][field_name] = rec
154208
elif isinstance(field_value, datetime):
155209
blocks[block_name][field_name] = field_value.isoformat()
156-
elif (
157-
field_name == "auxiliary"
158-
and hasattr(field_value, "values")
159-
and field_value is not None
160-
):
161-
blocks[block_name][field_name] = tuple(field_value.values.tolist())
162-
elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims:
163-
has_spatial_dims = any(
164-
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
165-
)
166-
if has_spatial_dims:
167-
# terrible hack to convert flat nodes dimension to 3d structured dims.
168-
# long term solution for this is to use a custom xarray index. filters
169-
# should then have access to all dimensions needed.
170-
dims_ = set(field_value.dims).copy()
171-
dims_.remove("nper")
172-
if dims_ == {"nodes"}:
173-
parent = value.parent # type: ignore
174-
field_value = xr.DataArray(
175-
field_value.data.reshape(
176-
(
177-
field_value.sizes["nper"],
178-
parent.dims["nlay"],
179-
parent.dims["nrow"],
180-
parent.dims["ncol"],
181-
)
182-
),
183-
dims=("nper", "nlay", "nrow", "ncol"),
184-
coords={
185-
"nper": field_value.coords["nper"],
186-
"nlay": range(parent.dims["nlay"]),
187-
"nrow": range(parent.dims["nrow"]),
188-
"ncol": range(parent.dims["ncol"]),
189-
},
190-
name=field_value.name,
191-
)
192-
193-
period_data[field_name] = {
194-
kper: field_value.isel(nper=kper)
195-
for kper in range(field_value.sizes["nper"])
196-
}
210+
elif isinstance(field_value, xr.DataArray):
211+
if field_name == "auxiliary":
212+
blocks[block_name][field_name] = tuple(field_value.values.tolist())
213+
elif "nper" not in field_value.dims:
214+
blocks[block_name][field_name] = _user_dims(value, field_value)
197215
else:
198-
if np.issubdtype(field_value.dtype, np.str_):
216+
period_data = {}
217+
period_blocks = {}
218+
has_spatial_dims = any(
219+
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
220+
)
221+
if has_spatial_dims:
222+
field_value = _user_dims(value, field_value)
223+
199224
period_data[field_name] = {
200-
kper: field_value[kper] for kper in range(field_value.sizes["nper"])
225+
kper: field_value.isel(nper=kper)
226+
for kper in range(field_value.sizes["nper"])
201227
}
202228
else:
203-
if block_name not in period_data:
204-
period_data[block_name] = {}
205-
period_data[block_name][field_name] = field_value # type: ignore
206-
else:
207-
if field_value is not None:
208-
if isinstance(field_value, bool):
209-
if field_value:
210-
blocks[block_name][field_name] = field_value
211-
else:
212-
blocks[block_name][field_name] = field_value
213-
214-
if block_name in period_data and isinstance(period_data[block_name], dict):
215-
dataset = xr.Dataset(period_data[block_name])
216-
_attach_field_metadata(dataset, type(value), list(period_data[block_name].keys())) # type: ignore
217-
blocks[block_name] = {block_name: dataset}
218-
del period_data[block_name]
219-
220-
for arr_name, periods in period_data.items():
221-
for kper, arr in periods.items():
222-
if kper not in period_blocks:
223-
period_blocks[kper] = {}
224-
period_blocks[kper][arr_name] = arr
225-
226-
for kper, block in period_blocks.items():
227-
dataset = xr.Dataset(block)
228-
_attach_field_metadata(dataset, type(value), list(block.keys()))
229-
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}
229+
if np.issubdtype(field_value.dtype, np.str_):
230+
period_data[field_name] = {
231+
kper: field_value[kper]
232+
for kper in range(field_value.sizes["nper"])
233+
if field_value[kper] is not None
234+
}
235+
else:
236+
if block_name not in period_data:
237+
period_data[block_name] = {}
238+
period_data[block_name][field_name] = field_value # type: ignore
239+
240+
dataset = xr.Dataset(period_data[block_name])
241+
_attach_field_metadata(dataset, type(value), list(period_data[block_name].keys())) # type: ignore
242+
blocks[block_name] = {block_name: dataset}
243+
del period_data[block_name]
244+
245+
for arr_name, periods in period_data.items():
246+
for kper, arr in periods.items():
247+
if isinstance(arr, xr.DataArray):
248+
max = arr.max()
249+
if max == arr.min() and max == FILL_DNODATA:
250+
# don't write empty period blocks unless
251+
# to intentionally reset data
252+
pass
253+
else:
254+
if kper not in period_blocks:
255+
period_blocks[kper] = {}
256+
period_blocks[kper][arr_name] = arr
257+
else:
258+
if kper not in period_blocks:
259+
period_blocks[kper] = {}
260+
period_blocks[kper][arr_name] = arr.upper()
261+
262+
for kper, block in period_blocks.items():
263+
dataset = xr.Dataset(block)
264+
_attach_field_metadata(dataset, type(value), list(block.keys()))
265+
blocks[f"{block_name} {kper + 1}"] = {block_name: dataset}
266+
elif field_value is not None:
267+
blocks[block_name][field_name] = field_value
230268

231269
# make sure options block always comes first
270+
# TODO: blocks should already be sorted here
232271
if "options" in blocks:
233272
options_block = blocks.pop("options")
234273
blocks = {"options": options_block, **blocks}
235274

236-
# total temporary hack! manually set solutiongroup 1. still need to support multiple..
275+
# total temporary hack! manually set solutiongroup 1.
276+
# TODO support multiple solution groups
237277
if "solutiongroup" in blocks:
238278
sg = blocks["solutiongroup"]
239279
blocks["solutiongroup 1"] = sg
240280
del blocks["solutiongroup"]
241281

242-
return {name: block for name, block in blocks.items() if name != "period"}
282+
# remove period block
283+
blocks.pop("period", None)
284+
285+
return blocks
243286

244287

245288
def _make_converter() -> Converter:

0 commit comments

Comments
 (0)