Skip to content

Commit 131a661

Browse files
authored
simplify the writer, fix pyproject.toml and launch.json (#197)
consolidate filters, improve use of the devtools dfn spec, fix duplicate ruff error with newer pixi in pyproject.toml, etc
1 parent 1155fbe commit 131a661

File tree

11 files changed

+249
-336
lines changed

11 files changed

+249
-336
lines changed

.vscode/launch.json

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,6 @@
44
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
55
"version": "0.2.0",
66
"configurations": [
7-
{
8-
"name": "Python Debugger: Current File",
9-
"type": "debugpy",
10-
"request": "launch",
11-
"program": "${file}",
12-
"console": "integratedTerminal",
13-
"purpose": [
14-
"debug-test"
15-
],
16-
"justMyCode": false
17-
},
187
{
198
"name": "quickstart",
209
"type": "debugpy",

.vscode/settings.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
"editor.formatOnSave": true,
33
"files.insertFinalNewline": true,
44
"python.testing.pytestArgs": [
5-
"test",
6-
"-s"
5+
"test"
76
],
87
"python.testing.unittestEnabled": false,
98
"python.testing.pytestEnabled": true,

flopy4/mf6/codec/writer/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,11 @@
1111
trim_blocks=True,
1212
lstrip_blocks=True,
1313
)
14-
_JINJA_ENV.filters["is_dataset"] = filters.is_dataset
15-
_JINJA_ENV.filters["field_format"] = filters.field_format
14+
_JINJA_ENV.filters["field_type"] = filters.field_type
1615
_JINJA_ENV.filters["array_how"] = filters.array_how
1716
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
1817
_JINJA_ENV.filters["array2string"] = filters.array2string
1918
_JINJA_ENV.filters["data2list"] = filters.data2list
20-
_JINJA_ENV.filters["data2keystring"] = filters.data2keystring
2119
_JINJA_TEMPLATE_NAME = "blocks.jinja"
2220
_PRINT_OPTIONS = {
2321
"precision": 4,

flopy4/mf6/codec/writer/filters.py

Lines changed: 72 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,59 +4,32 @@
44

55
import numpy as np
66
import xarray as xr
7+
from modflow_devtools.dfn.schema.v2 import FieldType
78
from numpy.typing import NDArray
89

910
from flopy4.mf6.constants import FILL_DNODATA
1011

1112

12-
def _is_keystring_format(dataset: xr.Dataset) -> bool:
13-
"""Check if dataset should use keystring format based on metadata."""
14-
field_metadata = dataset.attrs.get("field_metadata", {})
15-
return any(meta.get("format") == "keystring" for meta in field_metadata.values())
13+
def field_type(value: Any) -> FieldType:
14+
"""Get a value's type according to the MF6 specification."""
1615

17-
18-
def _is_tabular_time_format(dataset: xr.Dataset) -> bool:
19-
"""True if a dataset has multiple columns and only one dimension 'nper'."""
20-
return len(dataset.data_vars) > 1 and all(
21-
"nper" in var.dims and len(var.dims) == 1 for var in dataset.data_vars.values()
22-
)
23-
24-
25-
def is_dataset(value: Any) -> bool:
26-
return isinstance(value, xr.Dataset)
27-
28-
29-
def field_format(value: Any) -> str:
30-
"""
31-
Get a field's formatting type as defined by the MF6 definition language:
32-
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types
33-
"""
3416
if isinstance(value, bool):
3517
return "keyword"
3618
if isinstance(value, int):
3719
return "integer"
3820
if isinstance(value, float):
39-
return "double precision"
21+
return "double"
4022
if isinstance(value, str):
4123
return "string"
42-
if isinstance(value, (dict, tuple)):
24+
if isinstance(value, tuple):
4325
return "record"
4426
if isinstance(value, xr.DataArray):
4527
if value.dtype == "object":
4628
return "list"
4729
return "array"
48-
if isinstance(value, (xr.Dataset, list)):
49-
if isinstance(value, xr.Dataset):
50-
if _is_keystring_format(value):
51-
return "keystring"
52-
if _is_tabular_time_format(value):
53-
return "list"
30+
if isinstance(value, (list, dict, xr.Dataset)):
5431
return "list"
55-
return "keystring"
56-
57-
58-
def has_time_dim(value: Any) -> bool:
59-
return isinstance(value, xr.DataArray) and "nper" in value.dims
32+
raise ValueError(f"Unsupported field type: {type(value)}")
6033

6134

6235
def array_how(value: xr.DataArray) -> str:
@@ -140,38 +113,42 @@ def array2string(value: NDArray) -> str:
140113
return buffer.getvalue().strip()
141114

142115

143-
def nonempty(arr: NDArray | xr.DataArray) -> NDArray:
144-
if isinstance(arr, xr.DataArray):
145-
arr = arr.values
146-
if arr.dtype == "object":
147-
mask = arr != None # noqa: E711
116+
def nonempty(value: NDArray | xr.DataArray) -> NDArray:
117+
"""
118+
Return a boolean mask of non-empty (non-nodata) values in an array.
119+
TODO: don't hardcode FILL_DNODATA, support different fill values
120+
"""
121+
if isinstance(value, xr.DataArray):
122+
value = value.values
123+
if value.dtype == "object":
124+
mask = value != None # noqa: E711
148125
else:
149-
mask = ~np.ma.masked_invalid(arr).mask
150-
mask = mask & (arr != FILL_DNODATA)
126+
mask = ~np.ma.masked_invalid(value).mask
127+
mask = mask & (value != FILL_DNODATA)
151128
return mask
152129

153130

154-
def data2list(value: list | xr.DataArray | xr.Dataset):
131+
def data2list(value: list | tuple | dict | xr.Dataset | xr.DataArray):
155132
"""
156-
Yield record tuples from a list, `DataArray` or `Dataset`.
157-
158-
Yields
159-
------
160-
tuple
161-
Tuples of (*cellid, *values) or (*values) depending on spatial dimensions
133+
Yield records (tuples) from data in a `list`, `dict`, `DataArray` or `Dataset`.
162134
"""
163135

164-
if isinstance(value, list):
165-
for item in value:
166-
yield item
136+
if isinstance(value, (list, tuple)):
137+
for rec in value:
138+
yield rec
139+
return
140+
141+
if isinstance(value, dict):
142+
for name, val in value.values():
143+
yield (name, val)
167144
return
168145

169146
if isinstance(value, xr.Dataset):
170147
yield from dataset2list(value)
171148
return
172149

173-
# handle scalar
174-
if value.ndim == 0:
150+
# otherwise we have a DataArray
151+
if value.ndim == 0: # handle scalar
175152
if not np.isnan(value.item()) and value.item() is not None:
176153
yield (value.item(),)
177154
return
@@ -184,90 +161,67 @@ def data2list(value: list | xr.DataArray | xr.Dataset):
184161
for i, val in enumerate(values):
185162
if has_spatial_dims:
186163
cellid = tuple(idx[i] + 1 for idx in indices)
187-
result = cellid + (val,)
164+
rec = cellid + (val,)
188165
else:
189-
result = (val,)
190-
yield result
166+
rec = (val,)
167+
yield rec
191168

192169

193170
def dataset2list(value: xr.Dataset):
194171
"""
195-
Yield record tuples from an xarray Dataset. For regular/tabular list-based format.
172+
Yield records (tuples) from an `xarray.Dataset`.
196173
197-
Yields
198-
------
199-
tuple
200-
Tuples of (*cellid, *values) or (*values) depending on spatial dimensions
174+
If the first data variable is a string type, assume all are
175+
string type. Then the dataset represents a keystring; yield
176+
tuples of (name, *value). Otherwise, yield tuples: (*value)
177+
if no spatial dimensions, or (*cellid, *value) when spatial
178+
dimensions are present.
201179
"""
202180
if value is None or not any(value.data_vars):
203181
return
204182

205-
# handle scalar
206-
first_arr = next(iter(value.data_vars.values()))
207-
if first_arr.ndim == 0:
208-
field_vals = []
209-
for field_name in value.data_vars.keys():
210-
field_val = value[field_name]
211-
if hasattr(field_val, "item"):
212-
field_vals.append(field_val.item())
213-
else:
214-
field_vals.append(field_val)
215-
yield tuple(field_vals)
183+
first = next(iter(value.data_vars.values()))
184+
is_union = first.dtype.type is np.str_
185+
186+
if first.ndim == 0: # handle scalar
187+
if is_union:
188+
for name in value.data_vars.keys():
189+
val = value[name]
190+
val = val.item() if val.shape == () else val
191+
yield (*name.split("_"), val)
192+
else:
193+
vals = []
194+
for name in value.data_vars.keys():
195+
val = value[name]
196+
val = val.item() if val.shape == () else val
197+
vals.append(val)
198+
yield tuple(vals)
216199
return
217200

218-
# build mask
219201
combined_mask: Any = None
220-
for field_name, arr in value.data_vars.items():
221-
mask = nonempty(arr)
202+
for name, first in value.data_vars.items():
203+
mask = nonempty(first)
222204
combined_mask = mask if combined_mask is None else combined_mask | mask
223205
if combined_mask is None or not np.any(combined_mask):
224206
return
225207

226-
spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nodes")]
208+
spatial_dims = [d for d in first.dims if d in ("nlay", "nrow", "ncol", "nodes")]
227209
has_spatial_dims = len(spatial_dims) > 0
228210
indices = np.where(combined_mask)
229211
for i in range(len(indices[0])):
230-
field_vals = []
231-
for field_name in value.data_vars.keys():
232-
field_val = value[field_name][tuple(idx[i] for idx in indices)]
233-
if hasattr(field_val, "item"):
234-
field_vals.append(field_val.item())
235-
else:
236-
field_vals.append(field_val)
237-
if has_spatial_dims:
238-
cellid = tuple(idx[i] + 1 for idx in indices)
239-
yield cellid + tuple(field_vals)
212+
if is_union:
213+
for name in value.data_vars.keys():
214+
val = value[name][tuple(idx[i] for idx in indices)]
215+
val = val.item() if val.shape == () else val
216+
yield (*name.split("_"), val)
240217
else:
241-
yield tuple(field_vals)
242-
243-
244-
def data2keystring(value: dict | xr.Dataset):
245-
"""
246-
Yield record tuples from a dict or dataset. For irregular list-based format, i.e. keystrings.
247-
248-
Yields
249-
------
250-
tuple
251-
Tuples of (field_name, value) for use with record macro
252-
"""
253-
if isinstance(value, dict):
254-
if not value:
255-
return
256-
for field_name, field_val in value.items():
257-
yield (field_name.upper(), field_val)
258-
elif isinstance(value, xr.Dataset):
259-
if value is None or not any(value.data_vars):
260-
return
261-
262-
for field_name in value.data_vars.keys():
263-
name = (
264-
field_name.replace("_", " ").upper()
265-
if np.issubdtype(value.data_vars[field_name].dtype, np.str_)
266-
else field_name.upper()
267-
)
268-
field_val = value[field_name]
269-
if hasattr(field_val, "item"):
270-
val = field_val.item()
218+
vals = []
219+
for name in value.data_vars.keys():
220+
val = value[name][tuple(idx[i] for idx in indices)]
221+
val = val.item() if val.shape == () else val
222+
vals.append(val)
223+
if has_spatial_dims:
224+
cellid = tuple(idx[i] + 1 for idx in indices)
225+
yield cellid + tuple(vals)
271226
else:
272-
val = field_val
273-
yield (name, val)
227+
yield tuple(vals)

blocks.jinja (Jinja template; full path not captured in this page extract)

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,48 @@
1-
{% set inset = " " %}
1+
{% set inset = " " %}
22

33
{% macro field(name, value) %}
4-
{% set format = value|field_format %}
5-
{% if format in ['keyword', 'integer', 'double precision', 'string'] %}
4+
{% set type = value|field_type %}
5+
{% if type in ['keyword', 'integer', 'double', 'string'] %}
66
{{ scalar(name, value) }}
7-
{% elif format == 'record' %}
8-
{{ record(name, value) }}
9-
{% elif format == 'keystring' %}
10-
{{ keystring(name, value) }}
11-
{% elif format == 'array' %}
7+
{% elif type == 'record' %}
8+
{{ record(value) }}
9+
{% elif type == 'array' %}
1210
{{ array(name, value, how=value|array_how) }}
13-
{% elif format == 'list' %}
11+
{% elif type == 'list' %}
1412
{{ list(name, value) }}
1513
{% endif %}
1614
{% endmacro %}
1715

1816
{% macro scalar(name, value) %}
19-
{% set format = value|field_format %}
20-
{% if value is not none %}{{ inset ~ name.upper() }}{% if format != 'keyword' %} {{ value }}{% endif %}{% endif %}
17+
{% set type = value|field_type %}
18+
{% if value is not none %}{{ inset ~ name.upper() }}{% if type != 'keyword' %} {{ value }}{% endif %}{% endif %}
2119
{% endmacro %}
2220

23-
{% macro keystring(name, value) %}
24-
{% for option in (value|data2keystring) -%}
25-
{{ record("", option) }}{% if not loop.last %}{{ "\n" }}{% endif %}
26-
{%- endfor %}
21+
{% macro record(value) %}
22+
{{ inset ~ value|join(" ") -}}
2723
{% endmacro %}
2824

29-
{% macro record(name, value) %}
30-
{%- if value is mapping %}
31-
{% for field_name, field_value in value.items() -%}
32-
{{ field_name.upper() }} {{ field(field_value) }}
25+
{% macro list(name, value) %}
26+
{% for row in (value|data2list) %}
27+
{{ record(row) }}{% if not loop.last %}{{ "\n" }}{% endif %}
3328
{%- endfor %}
34-
{% else %}
35-
{{ inset ~ value|join(" ") }}
36-
{%- endif %}
3729
{% endmacro %}
3830

3931
{% macro array(name, value, how="internal") %}
4032
{{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}
4133

4234
{% if how == "constant" %}
43-
CONSTANT {{ value.item() }}
35+
CONSTANT {{ value.item() }}
4436
{% elif how == "layered constant" %}
4537
{% for layer in value -%}
46-
CONSTANT {{ layer.item() }}
38+
CONSTANT {{ layer.item() }}
4739
{%- endfor %}
4840
{% elif how == "internal" %}
49-
INTERNAL
41+
INTERNAL
5042
{% for chunk in value|array_chunks -%}
5143
{{ (2 * inset) ~ chunk|array2string }}
5244
{%- endfor %}
5345
{% elif how == "external" %}
54-
OPEN/CLOSE {{ value }}
46+
OPEN/CLOSE {{ value }}
5547
{% endif %}
5648
{% endmacro %}
57-
58-
{% macro list(name, value) %}
59-
{% for row in (value|data2list) %}
60-
{{ inset ~ row|join(" ") }}{% if not loop.last %}{{ "\n" }}{% endif %}
61-
{%- endfor %}
62-
{% endmacro %}

0 commit comments

Comments
 (0)