
Commit 96d2b3a
simplify the writer
1 parent 1155fbe

10 files changed (+224 additions, -306 deletions)

.vscode/launch.json

Lines changed: 0 additions & 11 deletions
@@ -4,17 +4,6 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
-        {
-            "name": "Python Debugger: Current File",
-            "type": "debugpy",
-            "request": "launch",
-            "program": "${file}",
-            "console": "integratedTerminal",
-            "purpose": [
-                "debug-test"
-            ],
-            "justMyCode": false
-        },
         {
             "name": "quickstart",
            "type": "debugpy",

.vscode/settings.json

Lines changed: 1 addition & 2 deletions
@@ -2,8 +2,7 @@
     "editor.formatOnSave": true,
     "files.insertFinalNewline": true,
     "python.testing.pytestArgs": [
-        "test",
-        "-s"
+        "test"
     ],
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,

flopy4/mf6/codec/writer/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -11,13 +11,11 @@
     trim_blocks=True,
     lstrip_blocks=True,
 )
-_JINJA_ENV.filters["is_dataset"] = filters.is_dataset
-_JINJA_ENV.filters["field_format"] = filters.field_format
+_JINJA_ENV.filters["field_type"] = filters.field_type
 _JINJA_ENV.filters["array_how"] = filters.array_how
 _JINJA_ENV.filters["array_chunks"] = filters.array_chunks
 _JINJA_ENV.filters["array2string"] = filters.array2string
 _JINJA_ENV.filters["data2list"] = filters.data2list
-_JINJA_ENV.filters["data2keystring"] = filters.data2keystring
 _JINJA_TEMPLATE_NAME = "blocks.jinja"
 _PRINT_OPTIONS = {
     "precision": 4,

flopy4/mf6/codec/writer/filters.py

Lines changed: 53 additions & 105 deletions
@@ -4,39 +4,21 @@
 
 import numpy as np
 import xarray as xr
+from modflow_devtools.dfn.schema.v2 import FieldType
 from numpy.typing import NDArray
 
 from flopy4.mf6.constants import FILL_DNODATA
 
 
-def _is_keystring_format(dataset: xr.Dataset) -> bool:
-    """Check if dataset should use keystring format based on metadata."""
-    field_metadata = dataset.attrs.get("field_metadata", {})
-    return any(meta.get("format") == "keystring" for meta in field_metadata.values())
+def field_type(value: Any) -> FieldType:
+    """Get a value's type according to the MF6 specification."""
 
-
-def _is_tabular_time_format(dataset: xr.Dataset) -> bool:
-    """True if a dataset has multiple columns and only one dimension 'nper'."""
-    return len(dataset.data_vars) > 1 and all(
-        "nper" in var.dims and len(var.dims) == 1 for var in dataset.data_vars.values()
-    )
-
-
-def is_dataset(value: Any) -> bool:
-    return isinstance(value, xr.Dataset)
-
-
-def field_format(value: Any) -> str:
-    """
-    Get a field's formatting type as defined by the MF6 definition language:
-    https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types
-    """
     if isinstance(value, bool):
         return "keyword"
     if isinstance(value, int):
         return "integer"
     if isinstance(value, float):
-        return "double precision"
+        return "double"
     if isinstance(value, str):
         return "string"
     if isinstance(value, (dict, tuple)):

@@ -45,18 +27,9 @@ def field_format(value: Any) -> str:
         if value.dtype == "object":
             return "list"
         return "array"
-    if isinstance(value, (xr.Dataset, list)):
-        if isinstance(value, xr.Dataset):
-            if _is_keystring_format(value):
-                return "keystring"
-            if _is_tabular_time_format(value):
-                return "list"
+    if isinstance(value, (list, xr.Dataset)):
         return "list"
-    return "keystring"
-
-
-def has_time_dim(value: Any) -> bool:
-    return isinstance(value, xr.DataArray) and "nper" in value.dims
+    raise ValueError(f"Unsupported field type: {type(value)}")
 
 
 def array_how(value: xr.DataArray) -> str:

@@ -140,20 +113,26 @@ def array2string(value: NDArray) -> str:
     return buffer.getvalue().strip()
 
 
-def nonempty(arr: NDArray | xr.DataArray) -> NDArray:
-    if isinstance(arr, xr.DataArray):
-        arr = arr.values
-    if arr.dtype == "object":
-        mask = arr != None  # noqa: E711
+def nonempty(value: NDArray | xr.DataArray) -> NDArray:
+    """
+    Return a boolean mask of non-empty (non-nodata) values in an array.
+    TODO: don't hardcode FILL_DNODATA, support different fill values
+    """
+    if isinstance(value, xr.DataArray):
+        value = value.values
+    if value.dtype == "object":
+        mask = value != None  # noqa: E711
     else:
-        mask = ~np.ma.masked_invalid(arr).mask
-        mask = mask & (arr != FILL_DNODATA)
+        mask = ~np.ma.masked_invalid(value).mask
+        mask = mask & (value != FILL_DNODATA)
     return mask
 
 
-def data2list(value: list | xr.DataArray | xr.Dataset):
+def data2list(value: list | dict | xr.Dataset | xr.DataArray):
     """
-    Yield record tuples from a list, `DataArray` or `Dataset`.
+    Yield records (tuples) from data in a `list`, `dict`, `DataArray` or `Dataset`.
+    Data can be regular or irregular: every item in a `list` is of the same record
+    type, while items in a `dict` or `Dataset` can be of different types.
 
     Yields
     ------

@@ -162,16 +141,21 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
     """
 
     if isinstance(value, list):
-        for item in value:
-            yield item
+        for rec in value:
+            yield rec
+        return
+
+    if isinstance(value, dict):
+        for name, val in value.items():
+            yield (name.upper(), val)
         return
 
     if isinstance(value, xr.Dataset):
         yield from dataset2list(value)
         return
 
-    # handle scalar
-    if value.ndim == 0:
+    # otherwise we have a DataArray
+    if value.ndim == 0:  # handle scalar
         if not np.isnan(value.item()) and value.item() is not None:
             yield (value.item(),)
         return

@@ -184,15 +168,15 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
     for i, val in enumerate(values):
         if has_spatial_dims:
             cellid = tuple(idx[i] + 1 for idx in indices)
-            result = cellid + (val,)
+            rec = cellid + (val,)
         else:
-            result = (val,)
-        yield result
+            rec = (val,)
+        yield rec
 
 
 def dataset2list(value: xr.Dataset):
     """
-    Yield record tuples from an xarray Dataset. For regular/tabular list-based format.
+    Yield record tuples from an `xarray.Dataset`. For regular/tabular list-based format.
 
     Yields
     ------

@@ -202,72 +186,36 @@ def dataset2list(value: xr.Dataset):
     if value is None or not any(value.data_vars):
         return
 
-    # handle scalar
-    first_arr = next(iter(value.data_vars.values()))
-    if first_arr.ndim == 0:
-        field_vals = []
-        for field_name in value.data_vars.keys():
-            field_val = value[field_name]
-            if hasattr(field_val, "item"):
-                field_vals.append(field_val.item())
-            else:
-                field_vals.append(field_val)
-        yield tuple(field_vals)
+    first = next(iter(value.data_vars.values()))
+    if first.ndim == 0:  # handle scalar
+        vals = []
+        for name in value.data_vars.keys():
+            val = value[name]
+            val = val.item() if val.shape == () else val
+            vals.append(val)
+        yield tuple(vals)
         return
 
-    # build mask
     combined_mask: Any = None
-    for field_name, arr in value.data_vars.items():
-        mask = nonempty(arr)
+    for name, first in value.data_vars.items():
+        mask = nonempty(first)
         combined_mask = mask if combined_mask is None else combined_mask | mask
     if combined_mask is None or not np.any(combined_mask):
         return
 
-    spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nodes")]
+    spatial_dims = [d for d in first.dims if d in ("nlay", "nrow", "ncol", "nodes")]
     has_spatial_dims = len(spatial_dims) > 0
     indices = np.where(combined_mask)
     for i in range(len(indices[0])):
-        field_vals = []
-        for field_name in value.data_vars.keys():
-            field_val = value[field_name][tuple(idx[i] for idx in indices)]
-            if hasattr(field_val, "item"):
-                field_vals.append(field_val.item())
+        vals = []
+        for name in value.data_vars.keys():
+            val = value[name][tuple(idx[i] for idx in indices)]
+            if hasattr(val, "item"):
+                vals.append(val.item())
            else:
-                field_vals.append(field_val)
+                vals.append(val)
         if has_spatial_dims:
             cellid = tuple(idx[i] + 1 for idx in indices)
-            yield cellid + tuple(field_vals)
+            yield cellid + tuple(vals)
         else:
-            yield tuple(field_vals)
-
-
-def data2keystring(value: dict | xr.Dataset):
-    """
-    Yield record tuples from a dict or dataset. For irregular list-based format, i.e. keystrings.
-
-    Yields
-    ------
-    tuple
-        Tuples of (field_name, value) for use with record macro
-    """
-    if isinstance(value, dict):
-        if not value:
-            return
-        for field_name, field_val in value.items():
-            yield (field_name.upper(), field_val)
-    elif isinstance(value, xr.Dataset):
-        if value is None or not any(value.data_vars):
-            return
-
-        for field_name in value.data_vars.keys():
-            name = (
-                field_name.replace("_", " ").upper()
-                if np.issubdtype(value.data_vars[field_name].dtype, np.str_)
-                else field_name.upper()
-            )
-            field_val = value[field_name]
-            if hasattr(field_val, "item"):
-                val = field_val.item()
-            else:
-                val = field_val
-            yield (name, val)
+            yield tuple(vals)
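As a quick illustration, a minimal sketch (not from the repo's tests) of how the new field_type filter classifies plain Python values, following the code above; the import path is taken from this commit's file tree:

    from flopy4.mf6.codec.writer.filters import field_type

    assert field_type(True) == "keyword"   # bool is checked before int
    assert field_type(3) == "integer"
    assert field_type(1.5) == "double"     # was "double precision" before this commit
    assert field_type("GHB") == "string"
    assert field_type([(1, 1, 1, -0.5)]) == "list"
    # anything unrecognized now raises ValueError rather than falling back to "keystring"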
Lines changed: 20 additions & 28 deletions
@@ -1,32 +1,24 @@
-{% set inset = " " %}
+{% set inset = " " %}
 
 {% macro field(name, value) %}
-{% set format = value|field_format %}
-{% if format in ['keyword', 'integer', 'double precision', 'string'] %}
+{% set type = value|field_type %}
+{% if type in ['keyword', 'integer', 'double', 'string'] %}
 {{ scalar(name, value) }}
-{% elif format == 'record' %}
-{{ record(name, value) }}
-{% elif format == 'keystring' %}
-{{ keystring(name, value) }}
-{% elif format == 'array' %}
+{% elif type == 'record' %}
+{{ record([name, value]) }}
+{% elif type == 'array' %}
 {{ array(name, value, how=value|array_how) }}
-{% elif format == 'list' %}
+{% elif type == 'list' %}
 {{ list(name, value) }}
 {% endif %}
 {% endmacro %}
 
 {% macro scalar(name, value) %}
-{% set format = value|field_format %}
-{% if value is not none %}{{ inset ~ name.upper() }}{% if format != 'keyword' %} {{ value }}{% endif %}{% endif %}
+{% set type = value|field_type %}
+{% if value is not none %}{{ inset ~ name.upper() }}{% if type != 'keyword' %} {{ value }}{% endif %}{% endif %}
 {% endmacro %}
 
-{% macro keystring(name, value) %}
-{% for option in (value|data2keystring) -%}
-{{ record("", option) }}{% if not loop.last %}{{ "\n" }}{% endif %}
-{%- endfor %}
-{% endmacro %}
-
-{% macro record(name, value) %}
+{% macro record(value) %}
 {%- if value is mapping %}
 {% for field_name, field_value in value.items() -%}
 {{ field_name.upper() }} {{ field(field_value) }}

@@ -36,27 +28,27 @@
 {%- endif %}
 {% endmacro %}
 
+{% macro list(name, value) %}
+{% for row in (value|data2list) %}
+{{ record(row) }}{% if not loop.last %}{{ "\n" }}{% endif %}
+{%- endfor %}
+{% endmacro %}
+
 {% macro array(name, value, how="internal") %}
 {{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}
 
 {% if how == "constant" %}
-CONSTANT {{ value.item() }}
+CONSTANT {{ value.item() }}
 {% elif how == "layered constant" %}
 {% for layer in value -%}
-CONSTANT {{ layer.item() }}
+CONSTANT {{ layer.item() }}
 {%- endfor %}
 {% elif how == "internal" %}
-INTERNAL
+INTERNAL
 {% for chunk in value|array_chunks -%}
 {{ (2 * inset) ~ chunk|array2string }}
 {%- endfor %}
 {% elif how == "external" %}
-OPEN/CLOSE {{ value }}
+OPEN/CLOSE {{ value }}
 {% endif %}
 {% endmacro %}
-
-{% macro list(name, value) %}
-{% for row in (value|data2list) %}
-{{ inset ~ row|join(" ") }}{% if not loop.last %}{{ "\n" }}{% endif %}
-{%- endfor %}
-{% endmacro %}
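With the keystring macro and data2keystring filter removed, dicts now flow through the generic list/record path. A minimal sketch (not from the repo's tests) of what data2list yields for a dict, per the new dict branch shown earlier:

    from flopy4.mf6.codec.writer.filters import data2list

    options = {"print_input": True, "save_flows": True}
    print(list(data2list(options)))
    # [('PRINT_INPUT', True), ('SAVE_FLOWS', True)] -- one (NAME, value) record per key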

flopy4/mf6/converter.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
 def _attach_field_metadata(
     dataset: xr.Dataset, component_type: type, field_names: list[str]
 ) -> None:
+    # TODO: attach metadata to array attrs instead of dataset attrs
     field_metadata = {}
     component_fields = fields_dict(component_type)
     for field_name in field_names:
