
Commit ad85b25

mjreno authored and committed
add multi period test with storage
1 parent 1155fbe commit ad85b25

File tree: 13 files changed (+307, -142 lines)

flopy4/mf6/adapters.py (0 additions, 1 deletion)

@@ -264,7 +264,6 @@ def plottable(self):
 
     @property
     def has_stress_period_data(self):
-        # TODO oc returns true? is stress package?
         return "nper" in self._data.dims
 
     def check(self, f=None, verbose=True, level=1, checktype=None):

flopy4/mf6/codec/writer/__init__.py (1 addition, 0 deletions)

@@ -16,6 +16,7 @@
 _JINJA_ENV.filters["array_how"] = filters.array_how
 _JINJA_ENV.filters["array_chunks"] = filters.array_chunks
 _JINJA_ENV.filters["array2string"] = filters.array2string
+_JINJA_ENV.filters["data2const"] = filters.data2const
 _JINJA_ENV.filters["data2list"] = filters.data2list
 _JINJA_ENV.filters["data2keystring"] = filters.data2keystring
 _JINJA_TEMPLATE_NAME = "blocks.jinja"
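
The new data2const entry follows the same registration pattern as the existing filters. As a rough standalone sketch (this toy Environment and DataArray are illustrative stand-ins, not the package's actual _JINJA_ENV setup), a filter registered this way becomes available with the | syntax in templates:

import jinja2
import numpy as np
import xarray as xr

def data2const(value: xr.DataArray):
    # mirrors the filter added in filters.py below: integers verbatim, floats to 8 decimals
    if np.issubdtype(value.dtype, np.integer):
        return value.max().item()
    if np.issubdtype(value.dtype, np.floating):
        return f"{value.max().item():.8f}"

env = jinja2.Environment()
env.filters["data2const"] = data2const
template = env.from_string("CONSTANT {{ value|data2const }}")
arr = xr.DataArray(np.full((3, 4), 2.0), dims=("nrow", "ncol"))
print(template.render(value=arr))  # CONSTANT 2.00000000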

flopy4/mf6/codec/writer/filters.py (12 additions, 0 deletions)

@@ -63,6 +63,8 @@ def array_how(value: xr.DataArray) -> str:
     # TODO
     # - detect constant arrays?
     # - above certain size, use external?
+    if value.max() == value.min():
+        return "constant"
     return "internal"
 
 
@@ -151,6 +153,13 @@ def nonempty(arr: NDArray | xr.DataArray) -> NDArray:
     return mask
 
 
+def data2const(value: xr.DataArray):
+    if np.issubdtype(value.dtype, np.integer):
+        return value.max().item()
+    if np.issubdtype(value.dtype, np.floating):
+        return f"{value.max().item():.8f}"
+
+
 def data2list(value: list | xr.DataArray | xr.Dataset):
     """
     Yield record tuples from a list, `DataArray` or `Dataset`.
@@ -234,6 +243,9 @@ def dataset2list(value: xr.Dataset):
                 field_vals.append(field_val.item())
             else:
                 field_vals.append(field_val)
+        for j, v in enumerate(field_vals):
+            if isinstance(field_vals[j], float):
+                field_vals[j] = f"{field_vals[j]:.8f}"
         if has_spatial_dims:
             cellid = tuple(idx[i] + 1 for idx in indices)
             yield cellid + tuple(field_vals)
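
In plain terms, array_how now collapses uniform arrays to CONSTANT output, and dataset2list pads float record values to a fixed eight decimal places, which is what the updated test expectations below reflect. A minimal sketch of both behaviors, assuming only numpy and xarray:

import numpy as np
import xarray as xr

def array_how(value: xr.DataArray) -> str:
    # a uniform array is written as a single CONSTANT line instead of an internal array
    if value.max() == value.min():
        return "constant"
    return "internal"

uniform = xr.DataArray(np.full((2, 3), 1.0), dims=("nrow", "ncol"))
varying = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("nrow", "ncol"))
print(array_how(uniform))  # constant
print(array_how(varying))  # internal

# the dataset2list change formats float fields to 8 decimals before records are written
field_vals = [5, 10.0, 1.5]
for j, v in enumerate(field_vals):
    if isinstance(v, float):
        field_vals[j] = f"{v:.8f}"
print(field_vals)  # [5, '10.00000000', '1.50000000']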

flopy4/mf6/codec/writer/templates/macros.jinja (1 addition, 1 deletion)

@@ -40,7 +40,7 @@
 {{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}
 
 {% if how == "constant" %}
-CONSTANT {{ value.item() }}
+CONSTANT {{ value|data2const -}}
 {% elif how == "layered constant" %}
 {% for layer in value -%}
 CONSTANT {{ layer.item() }}

flopy4/mf6/config.py (1 addition, 1 deletion)

@@ -1,3 +1,3 @@
 # TODO use https://environ-config.readthedocs.io/en/stable/?
 
-SPARSE_THRESHOLD = 1000
+SPARSE_THRESHOLD = 100000

flopy4/mf6/converter.py (78 additions, 37 deletions)

@@ -87,6 +87,54 @@ def _path_to_record(field_name: str, path_value: Path) -> tuple:
     return (field_name.upper(), "FILEOUT", str(path_value))
 
 
+def _user_dims(value, field_value):
+    # terrible hack to convert flat nodes dimension to 3d structured dims.
+    # long term solution for this is to use a custom xarray index. filters
+    # should then have access to all dimensions needed.
+    dims_ = set(field_value.dims).copy()
+    parent = value.parent  # type: ignore
+    if parent is None:
+        # TODO for standalone packages
+        return field_value
+
+    if "nper" in dims_:
+        dims_.remove("nper")
+        shape = (
+            field_value.sizes["nper"],
+            parent.dims["nlay"],
+            parent.dims["nrow"],
+            parent.dims["ncol"],
+        )
+        dims = ("nper", "nlay", "nrow", "ncol")
+        coords = {
+            "nper": field_value.coords["nper"],
+            "nlay": range(parent.dims["nlay"]),
+            "nrow": range(parent.dims["nrow"]),
+            "ncol": range(parent.dims["ncol"]),
+        }
+    else:
+        shape = (
+            parent.dims["nlay"],
+            parent.dims["nrow"],
+            parent.dims["ncol"],
+        )
+        dims = ("nlay", "nrow", "ncol")
+        coords = {
+            "nlay": range(parent.dims["nlay"]),
+            "nrow": range(parent.dims["nrow"]),
+            "ncol": range(parent.dims["ncol"]),
+        }
+
+    if dims_ == {"nodes"}:
+        field_value = xr.DataArray(
+            field_value.data.reshape(shape),
+            dims=dims,
+            coords=coords,
+        )
+
+    return field_value
+
+
 def unstructure_component(value: Component) -> dict[str, Any]:
     blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key))  # type: ignore
     blocks: dict[str, dict[str, Any]] = {}
@@ -161,47 +209,27 @@ def unstructure_component(value: Component) -> dict[str, Any]:
                     dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"]
                 )
                 if has_spatial_dims:
-                    # terrible hack to convert flat nodes dimension to 3d structured dims.
-                    # long term solution for this is to use a custom xarray index. filters
-                    # should then have access to all dimensions needed.
-                    dims_ = set(field_value.dims).copy()
-                    dims_.remove("nper")
-                    if dims_ == {"nodes"}:
-                        parent = value.parent  # type: ignore
-                        field_value = xr.DataArray(
-                            field_value.data.reshape(
-                                (
-                                    field_value.sizes["nper"],
-                                    parent.dims["nlay"],
-                                    parent.dims["nrow"],
-                                    parent.dims["ncol"],
-                                )
-                            ),
-                            dims=("nper", "nlay", "nrow", "ncol"),
-                            coords={
-                                "nper": field_value.coords["nper"],
-                                "nlay": range(parent.dims["nlay"]),
-                                "nrow": range(parent.dims["nrow"]),
-                                "ncol": range(parent.dims["ncol"]),
-                            },
-                            name=field_value.name,
-                        )
+                    field_value = _user_dims(value, field_value)
 
                     period_data[field_name] = {
                         kper: field_value.isel(nper=kper)
                         for kper in range(field_value.sizes["nper"])
                     }
                 else:
-                    if (
-                        # TODO: refactor
-                        # field_name == "save_budget"
-                        # or field_name == "save_head"
-                        # or field_name == "print_budget"
-                        # or field_name == "print_head"
-                        np.issubdtype(field_value.dtype, np.str_)
-                    ):
+                    if value.__class__.__name__ == "Oc":
                         period_data[field_name] = {
-                            kper: field_value[kper] for kper in range(field_value.sizes["nper"])
+                            kper: field_value.values[kper]
+                            for kper in range(field_value.sizes["nper"])
+                            if field_value.values[kper] is not None
+                        }
+                    elif np.issubdtype(field_value.dtype, np.str_):
+                        fname = field_name
+                        if value.__class__.__name__ == "Sto":
+                            fname = field_name.replace("_", "-")
+                        period_data[fname] = {
+                            kper: field_value[kper]
+                            for kper in range(field_value.sizes["nper"])
+                            if field_value[kper] is not None
                         }
                     else:
                         if block_name not in period_data:
@@ -212,6 +240,8 @@ def unstructure_component(value: Component) -> dict[str, Any]:
                 if isinstance(field_value, bool):
                     if field_value:
                         blocks[block_name][field_name] = field_value
+                elif isinstance(field_value, xr.DataArray):
+                    blocks[block_name][field_name] = _user_dims(value, field_value)
                 else:
                     blocks[block_name][field_name] = field_value
 
@@ -223,9 +253,20 @@ def unstructure_component(value: Component) -> dict[str, Any]:
 
     for arr_name, periods in period_data.items():
         for kper, arr in periods.items():
-            if kper not in period_blocks:
-                period_blocks[kper] = {}
-            period_blocks[kper][arr_name] = arr
+            if isinstance(arr, xr.DataArray):
+                max = arr.max()
+                if max == arr.min() and max == FILL_DNODATA:
+                    # don't write empty period blocks unless
+                    # to intentionally reset data
+                    pass
+                else:
+                    if kper not in period_blocks:
+                        period_blocks[kper] = {}
+                    period_blocks[kper][arr_name] = arr
+            else:
+                if kper not in period_blocks:
+                    period_blocks[kper] = {}
+                period_blocks[kper][arr_name] = arr.upper()
 
     for kper, block in period_blocks.items():
         dataset = xr.Dataset(block)
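
The extracted _user_dims helper reshapes a flat "nodes" array into structured (nlay, nrow, ncol) dimensions taken from the parent model, with or without a leading nper dimension, and is now also applied to non-period DataArray fields. A simplified sketch of the core reshape, with the parent grid sizes passed in directly since the parent/Component machinery isn't reproduced here:

import numpy as np
import xarray as xr

def reshape_nodes(field_value: xr.DataArray, nlay: int, nrow: int, ncol: int) -> xr.DataArray:
    # same idea as _user_dims: a flat "nodes" dimension becomes (nlay, nrow, ncol)
    if set(field_value.dims) == {"nodes"}:
        return xr.DataArray(
            field_value.data.reshape((nlay, nrow, ncol)),
            dims=("nlay", "nrow", "ncol"),
            coords={
                "nlay": range(nlay),
                "nrow": range(nrow),
                "ncol": range(ncol),
            },
        )
    return field_value

flat = xr.DataArray(np.arange(24.0), dims=("nodes",))
structured = reshape_nodes(flat, nlay=2, nrow=3, ncol=4)
print(dict(structured.sizes))  # {'nlay': 2, 'nrow': 3, 'ncol': 4}

The rewritten aggregation loop at the end also skips per-period arrays that are uniformly FILL_DNODATA, so empty period blocks are only written when data is intentionally reset, and string-valued period entries are uppercased before writing.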

flopy4/mf6/gwf/__init__.py (5 additions, 3 deletions)

@@ -13,12 +13,13 @@
 from flopy4.mf6.gwf.ic import Ic
 from flopy4.mf6.gwf.npf import Npf
 from flopy4.mf6.gwf.oc import Oc
+from flopy4.mf6.gwf.sto import Sto
 from flopy4.mf6.gwf.wel import Wel
 from flopy4.mf6.model import Model
 from flopy4.mf6.spec import field
 from flopy4.mf6.utils import open_cbc, open_hds
 
-__all__ = ["Gwf", "Chd", "Dis", "Drn", "Ic", "Npf", "Oc", "Wel"]
+__all__ = ["Gwf", "Chd", "Dis", "Drn", "Ic", "Npf", "Oc", "Sto", "Wel"]
 
 
 def convert_grid(value):
@@ -66,11 +67,12 @@ def budget(self):
     nc_filerecord: Optional[Path] = field(block="options", default=None)
     dis: Dis = field(converter=convert_grid, block="packages")
     ic: Ic = field(block="packages")
-    oc: Oc = field(block="packages")
+    oc: Oc = field(block="packages", default=None)
     npf: Npf = field(block="packages")
+    sto: Sto = field(block="packages", default=None)
     chd: list[Chd] = field(block="packages")
-    wel: list[Wel] = field(block="packages")
     drn: list[Drn] = field(block="packages")
+    wel: list[Wel] = field(block="packages")
     output: Output = attrs.field(
         default=attrs.Factory(lambda self: Gwf.Output(self), takes_self=True)
     )

flopy4/mf6/gwf/oc.py (4 additions, 4 deletions)

@@ -55,31 +55,31 @@ class Period:
     save_head: Optional[NDArray[np.object_]] = array(
         object,
         block="period",
-        default="all",
+        default=None,
         dims=("nper",),
         converter=Converter(dict_to_array, takes_self=True, takes_field=True),
         format="keystring",
     )
     save_budget: Optional[NDArray[np.object_]] = array(
         object,
         block="period",
-        default="all",
+        default=None,
         dims=("nper",),
         converter=Converter(dict_to_array, takes_self=True, takes_field=True),
         format="keystring",
     )
     print_head: Optional[NDArray[np.object_]] = array(
         object,
         block="period",
-        default="all",
+        default=None,
         dims=("nper",),
         converter=Converter(dict_to_array, takes_self=True, takes_field=True),
         format="keystring",
     )
     print_budget: Optional[NDArray[np.object_]] = array(
         object,
         block="period",
-        default="all",
+        default=None,
         dims=("nper",),
         converter=Converter(dict_to_array, takes_self=True, takes_field=True),
         format="keystring",

flopy4/mf6/gwf/sto.py (1 addition, 7 deletions)

@@ -39,13 +39,7 @@ class Sto(Package):
         default=0.15,
         converter=Converter(dict_to_array, takes_self=True, takes_field=True),
     )
-    steady_state: Optional[NDArray[np.bool_]] = array(
-        block="period",
-        dims=("nper",),
-        default=None,
-        converter=Converter(dict_to_array, takes_self=True, takes_field=True),
-    )
-    transient: Optional[NDArray[np.bool_]] = array(
+    storage: Optional[NDArray[np.str_]] = array(
         block="period",
         dims=("nper",),
         default=None,
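
The two boolean period arrays are replaced by a single string-valued storage field, which maps more directly onto the MODFLOW 6 STO period block where each period is flagged STEADY-STATE or TRANSIENT. A rough sketch of how such a per-period array could be rendered (the exact value spelling the new field expects isn't shown in this commit, so treat the strings below as assumptions):

import numpy as np

# hypothetical per-period storage settings; None means no PERIOD block is written
storage = np.array(["steady-state", None, "transient"], dtype=object)

for kper, value in enumerate(storage):
    if value is None:
        continue
    print(f"BEGIN PERIOD {kper + 1}")
    print(f"  {value.upper()}")  # the converter uppercases string period values
    print("END PERIOD\n")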

test/test_codec.py (7 additions, 7 deletions)

@@ -253,11 +253,11 @@ def test_dumps_drn():
     assert len(period2_lines) == 3
 
     # node elev cond
-    assert "1 1 5 10.0 1.0" in dumped  # Period 1: (0,0,4)
-    assert "2 5 1 8.0 2.0" in dumped  # Period 1: (1,4,0)
-    assert "1 2 2 12.0 1.5" in dumped  # Period 2: (0,1,1)
-    assert "1 3 4 9.0 0.8" in dumped  # Period 2: (0,2,3)
-    assert "2 4 3 7.0 2.2" in dumped  # Period 2: (1,3,2)
+    assert "1 1 5 10.00000000 1.00000000" in dumped  # Period 1: (0,0,4)
+    assert "2 5 1 8.00000000 2.00000000" in dumped  # Period 1: (1,4,0)
+    assert "1 2 2 12.00000000 1.50000000" in dumped  # Period 2: (0,1,1)
+    assert "1 3 4 9.00000000 0.80000000" in dumped  # Period 2: (0,2,3)
+    assert "2 4 3 7.00000000 2.20000000" in dumped  # Period 2: (1,3,2)
     assert "1e+30" not in dumped
     assert "1.0e+30" not in dumped
 
@@ -334,8 +334,8 @@ def test_dumps_wel_with_aux():
 
     assert len(lines) == 2
     # node q aux_value
-    assert "1 2 3 -75.0 1.0" in dumped  # (0,1,2) -> node 8, q=-75.0, aux=1.0
-    assert "2 4 5 -25.0 2.0" in dumped  # (1,3,4) -> node 45, q=-25.0, aux=2.0
+    assert "1 2 3 -75.00000000 1.00000000" in dumped  # (0,1,2) -> node 8, q=-75.0, aux=1.0
+    assert "2 4 5 -25.00000000 2.00000000" in dumped  # (1,3,4) -> node 45, q=-25.0, aux=2.0
     assert "1e+30" not in dumped
     assert "1.0e+30" not in dumped
