Skip to content

Commit 3a3a0e2

Browse files
mjrenomjreno
authored andcommitted
baseline oc test
1 parent 4191dbb commit 3a3a0e2

File tree

2 files changed

+36
-23
lines changed

2 files changed

+36
-23
lines changed

flopy4/mf6/converter/unstructure.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,31 @@ def _hack_structured_grid_dims(
9090
)
9191

9292

93-
def _hack_period_non_numeric(name: str, value: xr.DataArray) -> tuple[str, dict[int, str]]:
93+
def _hack_period_non_numeric(name: str, value: xr.DataArray) -> dict[str, dict[int, Any]]:
94+
from flopy4.mf6.gwf import Oc
95+
9496
fname = ""
9597
data = {}
9698
match value.dtype:
9799
case np.bool:
98-
fname = name
99-
data = {kper: "" for kper in range(value.sizes["nper"]) if value.values[kper]}
100+
fname = name # type: ignore
101+
dat = {kper: "" for kper in range(value.sizes["nper"]) if value.values[kper]}
102+
data[fname] = dat
100103
case np.dtypes.StringDType():
101104
fname = name.replace("_", " ")
102-
data = {kper: value.values[kper] for kper in range(value.sizes["nper"])}
105+
dat = {kper: value.values[kper] for kper in range(value.sizes["nper"])}
106+
data[fname] = dat
107+
case object():
108+
if isinstance(value.values[0], Oc.PrintSaveSetting):
109+
for rec in value.values[0].printrecord:
110+
if rec.steps.all:
111+
dat = {kper: "all" for kper in range(value.sizes["nper"])}
112+
key = f"PRINT {rec.rtype}"
113+
data[key] = dat
114+
# for rec in value.values[0].saverecord:
115+
# print("SaveRecord")
103116

104-
return fname, data
117+
return data
105118

106119

107120
def unstructure_component(value: Component) -> dict[str, Any]:
@@ -173,11 +186,12 @@ def unstructure_component(value: Component) -> dict[str, Any]:
173186
)
174187
if block_name == "period":
175188
if not np.issubdtype(field_value.dtype, np.number):
176-
n, v = _hack_period_non_numeric(field_name, field_value)
177-
period_data[n] = v
189+
dat = _hack_period_non_numeric(field_name, field_value)
190+
for n, v in dat.items():
191+
period_data[n] = v
178192
else:
179193
period_data[field_name] = {
180-
kper: field_value.isel(nper=kper)
194+
kper: field_value.isel(nper=kper) # type: ignore
181195
for kper in range(field_value.sizes["nper"])
182196
}
183197
else:
@@ -195,17 +209,19 @@ def unstructure_component(value: Component) -> dict[str, Any]:
195209

196210
# sort kper order
197211
period_blocks = dict(sorted(period_blocks.items()))
212+
print(period_blocks)
198213

199214
# setup indexed period blocks, combine arrays into datasets
200215
for kper, block in period_blocks.items():
201-
arr_name = list(block.keys())[0]
202-
match block[arr_name]:
203-
case str():
204-
blocks[f"period {kper + 1}"] = {arr_name: block[arr_name]}
205-
case xr.DataArray():
206-
blocks[f"period {kper + 1}"] = {
207-
"period": xr.Dataset(block, coords=block[arr_name].coords)
208-
}
216+
blocks[f"period {kper + 1}"] = {}
217+
for arr_name, val in block.items():
218+
match block[arr_name]:
219+
case str():
220+
blocks[f"period {kper + 1}"][arr_name] = val
221+
case xr.DataArray():
222+
blocks[f"period {kper + 1}"]["period"] = xr.Dataset(
223+
block, coords=block[arr_name].coords
224+
)
209225

210226
# combine "perioddata" block arrays (tdis, ats) into datasets
211227
# so they render as lists. temp hack TODO do this generically

test/test_mf6_codec.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from pprint import pprint
44

5-
import pytest
6-
75
from flopy4.mf6.codec import dumps, loads
86
from flopy4.mf6.converter import COMPONENT_CONVERTER
97

@@ -79,7 +77,6 @@ def test_dumps_sto():
7977
pprint(loaded)
8078

8179

82-
@pytest.mark.xfail(reason="nested type unstructuring not yet supported")
8380
def test_dumps_oc():
8481
from flopy4.mf6.gwf import Oc
8582

@@ -102,10 +99,10 @@ def test_dumps_oc():
10299
dumped = dumps(COMPONENT_CONVERTER.unstructure(oc))
103100
print("OC dump:")
104101
print(dumped)
105-
assert "save head all" in dumped
106-
assert "save budget all" in dumped
107-
assert "print head all" in dumped
108-
assert "print budget all" in dumped
102+
assert "SAVE HEAD all" in dumped
103+
assert "SAVE BUDGET all" in dumped
104+
assert "PRINT HEAD all" in dumped
105+
assert "PRINT BUDGET all" in dumped
109106
assert dumped
110107

111108
loaded = loads(dumped)

0 commit comments

Comments
 (0)