Skip to content

Commit f7ac532

Browse files
committed
various fixes
1 parent d2b2e68 commit f7ac532

File tree

8 files changed

+94
-61
lines changed

8 files changed

+94
-61
lines changed

flopy4/mf6/codec/writer/filters.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def array2const(value: xr.DataArray) -> Scalar:
134134
return value.max().item()
135135
if np.issubdtype(value.dtype, np.floating):
136136
return f"{value.max().item():.8f}"
137+
return value.ravel()[0]
137138

138139

139140
def data2list(value: list | dict | xr.Dataset | xr.DataArray):
@@ -149,7 +150,7 @@ def data2list(value: list | dict | xr.Dataset | xr.DataArray):
149150
return
150151

151152
if isinstance(value, dict):
152-
for name, val in value.values():
153+
for name, val in value.items():
153154
yield (name, val)
154155
return
155156

flopy4/mf6/codec/writer/templates/macros.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
{% endmacro %}
2020

2121
{% macro record(value) %}
22-
{{ inset ~ value|join(" ")|upper -}}
22+
{{ inset ~ value|join(" ") -}}
2323
{% endmacro %}
2424

2525
{% macro list(name, value) %}

flopy4/mf6/component.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,15 @@ def write(self, format: str = MF6) -> None:
202202
def to_dict(self, blocks: bool = False) -> dict[str, Any]:
203203
"""Convert the component to a dictionary representation."""
204204
data = xattree_asdict(self)
205+
data.pop("filename")
206+
data.pop("workspace", None)
207+
data.pop("nodes", None) # TODO: find a better way to omit
205208
if blocks:
206-
blocks_ = {}
207-
for field_name, field_meta in self.data.attrs["metadata"].items():
208-
block_name = field_meta["block"]
209+
blocks_ = {} # type: ignore
210+
for field_name, field_value in data.items():
211+
block_name = self.dfn.fields[field_name].block
209212
if block_name not in blocks_:
210213
blocks_[block_name] = {}
211-
blocks_[block_name][field_name] = data[field_name]
214+
blocks_[block_name][field_name] = field_value
212215
return blocks_
213216
return data

flopy4/mf6/converter.py

Lines changed: 70 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str,
3232
if not isinstance(value, Context):
3333
return {}
3434

35-
blocks = {}
35+
blocks = {} # type: ignore
3636
xatspec = xattree.get_xatspec(type(value))
3737

3838
for child_name, child_spec in xatspec.children.items():
3939
if (child := getattr(value, child_name, None)) is None:
4040
continue
41-
if (block_name := child_spec.metadata["block"]) not in blocks:
41+
if (block_name := child_spec.metadata["block"]) not in blocks: # type: ignore
4242
blocks[block_name] = {}
4343
match child:
4444
case Component():
@@ -80,55 +80,57 @@ def has_tdis_dims(value: xr.DataArray) -> bool:
8080
return "nper" in value.dims
8181

8282

83-
def _hack_grid_dims(value, field_value):
84-
# terrible hack to convert flat nodes dimension to 3d structured dims.
85-
# long term solution for this is to use a custom xarray index. filters
86-
# should then have access to all dimensions needed.
87-
dims_ = set(field_value.dims).copy()
88-
parent = value.parent # type: ignore
89-
if parent is None:
90-
# TODO for standalone packages
91-
return field_value
83+
def _hack_structured_grid_dims(value: xr.DataArray, structured_grid_dims: Mapping):
84+
"""
85+
Temporary hack to convert flat nodes dimension to 3d structured dims.
86+
long term solution for this is to use a custom xarray index. filters
87+
should then have access to all dimensions needed.
88+
"""
9289

93-
if "nper" in dims_:
94-
dims_.remove("nper")
95-
shape = (
96-
field_value.sizes["nper"],
97-
parent.dims["nlay"],
98-
parent.dims["nrow"],
99-
parent.dims["ncol"],
90+
if "nper" in (old_dims := set(value.dims).copy()):
91+
old_dims.remove("nper")
92+
shape: tuple[int, ...] = (
93+
value.sizes["nper"],
94+
structured_grid_dims["nlay"],
95+
structured_grid_dims["nrow"],
96+
structured_grid_dims["ncol"],
10097
)
101-
dims = ("nper", "nlay", "nrow", "ncol")
98+
dims: tuple[str, ...] = ("nper", "nlay", "nrow", "ncol")
10299
coords = {
103-
"nper": field_value.coords["nper"],
104-
"nlay": range(parent.dims["nlay"]),
105-
"nrow": range(parent.dims["nrow"]),
106-
"ncol": range(parent.dims["ncol"]),
100+
"nper": value.coords["nper"],
101+
"nlay": range(structured_grid_dims["nlay"]),
102+
"nrow": range(structured_grid_dims["nrow"]),
103+
"ncol": range(structured_grid_dims["ncol"]),
107104
}
108105
else:
109106
shape = (
110-
parent.dims["nlay"],
111-
parent.dims["nrow"],
112-
parent.dims["ncol"],
107+
structured_grid_dims["nlay"],
108+
structured_grid_dims["nrow"],
109+
structured_grid_dims["ncol"],
113110
)
114111
dims = ("nlay", "nrow", "ncol")
115112
coords = {
116-
"nlay": range(parent.dims["nlay"]),
117-
"nrow": range(parent.dims["nrow"]),
118-
"ncol": range(parent.dims["ncol"]),
113+
"nlay": range(structured_grid_dims["nlay"]),
114+
"nrow": range(structured_grid_dims["nrow"]),
115+
"ncol": range(structured_grid_dims["ncol"]),
119116
}
120117

121-
if dims_ == {"nodes"}:
122-
field_value = xr.DataArray(
123-
field_value.data.reshape(shape),
118+
if old_dims == {"nodes"}:
119+
value = xr.DataArray(
120+
value.data.reshape(shape),
124121
dims=dims,
125122
coords=coords,
126123
)
127124

128-
return field_value
125+
return value
129126

130127

131-
def unstructure_field(name: str, value: Any) -> tuple[str, Any]:
128+
def unstructure_field(
129+
name: str,
130+
value: Any,
131+
# TODO: temporary, remove not needed
132+
structured_grid_dims: Mapping | None,
133+
) -> tuple[str, Any]:
132134
"""
133135
Convert:
134136
@@ -166,19 +168,34 @@ def unstructure_field(name: str, value: Any) -> tuple[str, Any]:
166168
return name, value.isoformat()
167169
case xr.DataArray():
168170
if name == "auxiliary":
169-
value = tuple(value.values.tolist())
171+
return name, tuple(value.values.tolist())
170172
if has_grid_dims(value):
171-
value = _hack_grid_dims(value, value)
173+
if structured_grid_dims is None:
174+
raise ValueError("Need structured grid dimension sizes")
175+
value = _hack_structured_grid_dims(value, structured_grid_dims=structured_grid_dims)
172176
if has_tdis_dims(value):
173177
value = {kper: value.isel(nper=kper) for kper in range(value.sizes["nper"])}
174178
return name, value
175179
case _:
176180
return name, value
177181

178182

179-
def unstructure_block(block: dict[str, Any]) -> dict[str, Any]:
183+
def unstructure_block(
184+
block: dict[str, Any],
185+
# TODO: temporary, remove not needed
186+
structured_grid_dims: Mapping | None,
187+
) -> dict[str, Any]:
180188
"""Unstructure a block of data, converting fields to a suitable format."""
181-
return dict([unstructure_field(block.get(field_name, None)) for field_name in block.keys()])
189+
return dict(
190+
[
191+
unstructure_field(
192+
name=field_name,
193+
value=block.get(field_name, None),
194+
structured_grid_dims=structured_grid_dims,
195+
)
196+
for field_name in block.keys()
197+
]
198+
)
182199

183200

184201
def _hack_field_metadata(
@@ -195,8 +212,8 @@ def _hack_field_metadata(
195212

196213
def segment_period_data(block: dict[str, Any], cls: type[Component]) -> dict[str, dict[str, Any]]:
197214
"""Partition period data by stress period"""
198-
arrays = {}
199-
blocks = {}
215+
arrays = {} # type: ignore
216+
blocks = {} # type: ignore
200217
period = PERIOD.upper()
201218

202219
for arr_name, periods in block.items():
@@ -220,14 +237,26 @@ def unstructure_component(value: Component) -> dict[str, Any]:
220237
data = value.to_dict(blocks=True)
221238
blocks: dict[str, dict[str, Any]] = {}
222239
blocks.update(binding_blocks := get_binding_blocks(value))
240+
241+
# temporary hack! TODO remove once we have a structured grid index
242+
if "nlay" in value.data.dims: # type: ignore
243+
structured_grid_dims = value.data.dims # type: ignore
244+
elif value.data.parent is not None and "nlay" in value.data.parent.dims: # type: ignore
245+
structured_grid_dims = value.data.parent.dims # type: ignore
246+
else:
247+
structured_grid_dims = None
248+
223249
blocks.update(
224250
{
225-
block_name: unstructure_block(data[block_name])
251+
block_name: unstructure_block(
252+
data[block_name], structured_grid_dims=structured_grid_dims
253+
)
226254
for block_name in dfn.blocks.keys()
227255
if block_name not in binding_blocks
228256
}
229257
)
230258
if period_block := blocks.pop(PERIOD, None):
259+
period_block = {k: v for k, v in period_block.items() if v is not None}
231260
blocks.update(segment_period_data(period_block, cls))
232261

233262
# total temporary hack! manually set solutiongroup 1.

flopy4/mf6/solution.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from abc import ABC
2-
from pathlib import Path
3-
from typing import ClassVar, Optional
2+
from typing import ClassVar
43

54
import attrs
6-
from xattree import field, xattree
5+
from xattree import xattree
76

87
from flopy4.mf6.package import Package
98

@@ -12,8 +11,7 @@
1211
class Solution(Package, ABC):
1312
slntype: ClassVar[str] = "sln"
1413

15-
slnfname: Optional[Path] = field(default=None) # type: ignore
1614
models: list[str] = attrs.field(default=attrs.Factory(list))
1715

1816
def default_filename(self) -> str:
19-
return str(self.slnfname) if self.slnfname else f"solution.{self.slntype.lower()}"
17+
return f"solution.{self.slntype.lower()}"

flopy4/spec.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Wrap `xattree` and `attrs` specification utilities.
33
These include field decorators and introspection functions.
4+
TODO: add `derived` option to dims? or more generic option
5+
to any field indicating it is not part of the formal spec?
46
"""
57

68
from attrs import NOTHING, Attribute

test/test_codec.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,11 @@ def test_dumps_drn():
253253
assert len(period2_lines) == 3
254254

255255
# node elev cond
256-
assert "1 1 5 10.00000000 1.00000000" in dumped # Period 1: (0,0,4)
257-
assert "2 5 1 8.00000000 2.00000000" in dumped # Period 1: (1,4,0)
258-
assert "1 2 2 12.00000000 1.50000000" in dumped # Period 2: (0,1,1)
259-
assert "1 3 4 9.00000000 0.80000000" in dumped # Period 2: (0,2,3)
260-
assert "2 4 3 7.00000000 2.20000000" in dumped # Period 2: (1,3,2)
256+
assert "1 1 5 10.0 1.0" in dumped # Period 1: (0,0,4)
257+
assert "2 5 1 8.0 2.0" in dumped # Period 1: (1,4,0)
258+
assert "1 2 2 12.0 1.5" in dumped # Period 2: (0,1,1)
259+
assert "1 3 4 9.0 0.8" in dumped # Period 2: (0,2,3)
260+
assert "2 4 3 7.0 2.2" in dumped # Period 2: (1,3,2)
261261
assert "1e+30" not in dumped
262262
assert "1.0e+30" not in dumped
263263

@@ -334,8 +334,8 @@ def test_dumps_wel_with_aux():
334334

335335
assert len(lines) == 2
336336
# node q aux_value
337-
assert "1 2 3 -75.00000000 1.00000000" in dumped # (0,1,2) -> node 8, q=-75.0, aux=1.0
338-
assert "2 4 5 -25.00000000 2.00000000" in dumped # (1,3,4) -> node 45, q=-25.0, aux=2.0
337+
assert "1 2 3 -75.0 1.0" in dumped # (0,1,2) -> node 8, q=-75.0, aux=1.0
338+
assert "2 4 5 -25.0 2.0" in dumped # (1,3,4) -> node 45, q=-25.0, aux=2.0
339339
assert "1e+30" not in dumped
340340
assert "1.0e+30" not in dumped
341341

test/test_mf6.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_gwf_chd01(function_tmpdir):
1313
time = ModelTime(perlen=[5.0], nstp=[1], tsmult=[1.0], time_units="days")
1414

1515
ims = Ims(
16-
slnfname="sln1.ims",
16+
filename="sln1.ims",
1717
models=[gwf_name],
1818
print_option="summary",
1919
outer_dvclose=1.00000000e-06,
@@ -91,7 +91,7 @@ def test_gwf_npf01(function_tmpdir):
9191
)
9292

9393
ims = Ims(
94-
# slnfname="sln1.ims",
94+
filename="sln1.ims",
9595
models=[gwf_name],
9696
print_option="summary",
9797
outer_dvclose=1.00000000e-06,

0 commit comments

Comments
 (0)