diff --git a/flopy4/mf6/converter/unstructure.py b/flopy4/mf6/converter/unstructure.py index 4047a483..74faa582 100644 --- a/flopy4/mf6/converter/unstructure.py +++ b/flopy4/mf6/converter/unstructure.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any +import numpy as np import xarray as xr import xattree from modflow_devtools.dfn.schema.block import block_sort_key @@ -89,6 +90,51 @@ def _hack_structured_grid_dims( ) +def _hack_period_non_numeric(name: str, value: xr.DataArray) -> dict[str, dict[int, Any]]: + from flopy4.mf6.gwf import Oc + + def oc_setting_data(rec): + dat = {} + if rec.steps.first: + dat = {kper: "first" for kper in range(value.sizes["nper"])} + elif rec.steps.last: + dat = {kper: "last" for kper in range(value.sizes["nper"])} + elif rec.steps.steps: + steps = " ".join(str(x - 1) for x in rec.steps.steps) + dat = {kper: f"steps {steps}" for kper in range(value.sizes["nper"])} + elif rec.steps.all: + # check last as this defaults to True + dat = {kper: "all" for kper in range(value.sizes["nper"])} + + return dat + + data = {} + match value.dtype: + case np.bool: + dat = {kper: "" for kper in range(value.sizes["nper"]) if value.values[kper]} # type: ignore + data[name] = dat + case np.dtypes.StringDType(): + fname = name.replace("_", " ") + dat = {kper: value.values[kper] for kper in range(value.sizes["nper"])} + data[fname] = dat + case object(): + if isinstance(value.values[0], Oc.PrintSaveSetting): + if hasattr(value.values[0], "printrecord") and isinstance( + value.values[0].printrecord, list + ): + for rec in value.values[0].printrecord: + key = f"{rec.print} {rec.rtype}" + data[key] = oc_setting_data(rec) + if hasattr(value.values[0], "saverecord") and isinstance( + value.values[0].saverecord, list + ): + for rec in value.values[0].saverecord: # type: ignore + key = f"{rec.save} {rec.rtype}" # type: ignore + data[key] = oc_setting_data(rec) + + return data + + def unstructure_component(value: Component) -> dict[str, Any]: blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore blocks: dict[str, dict[str, Any]] = {} @@ -157,10 +203,15 @@ def unstructure_component(value: Component) -> dict[str, Any]: structured_grid_dims=value.parent.data.dims, # type: ignore ) if block_name == "period": - period_data[field_name] = { - kper: field_value.isel(nper=kper) - for kper in range(field_value.sizes["nper"]) - } + if not np.issubdtype(field_value.dtype, np.number): + dat = _hack_period_non_numeric(field_name, field_value) + for n, v in dat.items(): + period_data[n] = v + else: + period_data[field_name] = { + kper: field_value.isel(nper=kper) # type: ignore + for kper in range(field_value.sizes["nper"]) + } else: blocks[block_name][field_name] = field_value @@ -174,11 +225,20 @@ def unstructure_component(value: Component) -> dict[str, Any]: period_blocks[kper] = {} period_blocks[kper][arr_name] = arr + # sort kper order + period_blocks = dict(sorted(period_blocks.items())) + # setup indexed period blocks, combine arrays into datasets for kper, block in period_blocks.items(): - blocks[f"period {kper + 1}"] = { - "period": xr.Dataset(block, coords=block[arr_name].coords) - } + blocks[f"period {kper + 1}"] = {} + for arr_name, val in block.items(): + match block[arr_name]: + case str(): + blocks[f"period {kper + 1}"][arr_name] = val + case xr.DataArray(): + blocks[f"period {kper + 1}"]["period"] = xr.Dataset( + block, coords=block[arr_name].coords + ) # combine "perioddata" block arrays (tdis, ats) into datasets # so they render as lists. temp hack TODO do this generically diff --git a/test/test_mf6_codec.py b/test/test_mf6_codec.py index 55a367d7..1e1e623a 100644 --- a/test/test_mf6_codec.py +++ b/test/test_mf6_codec.py @@ -2,8 +2,6 @@ from pprint import pprint -import pytest - from flopy4.mf6.codec import dumps, loads from flopy4.mf6.converter import COMPONENT_CONVERTER @@ -57,7 +55,31 @@ def test_dumps_ic(): pprint(loaded) -@pytest.mark.xfail(reason="nested type unstructuring not yet supported") +def test_dumps_sto(): + from flopy4.mf6.gwf import Dis, Gwf, Sto + + dis = Dis() + gwf = Gwf(dis=dis) + sto = Sto( + dims={"nper": 3}, + parent=gwf, + steady_state=[False, True, False], + transient=[True, False, True], + ) + + dumped = dumps(COMPONENT_CONVERTER.unstructure(sto)) + print("STO dump:") + print(dumped) + assert "BEGIN PERIOD 1\n TRANSIENT" in dumped + assert "BEGIN PERIOD 2\n STEADY_STATE" in dumped + assert "BEGIN PERIOD 3\n TRANSIENT" in dumped + assert dumped + + loaded = loads(dumped) + print("STO load:") + pprint(loaded) + + def test_dumps_oc(): from flopy4.mf6.gwf import Oc @@ -80,10 +102,45 @@ def test_dumps_oc(): dumped = dumps(COMPONENT_CONVERTER.unstructure(oc)) print("OC dump:") print(dumped) - assert "save head all" in dumped - assert "save budget all" in dumped - assert "print head all" in dumped - assert "print budget all" in dumped + assert "SAVE HEAD all" in dumped + assert "SAVE BUDGET all" in dumped + assert "PRINT HEAD all" in dumped + assert "PRINT BUDGET all" in dumped + assert dumped + + loaded = loads(dumped) + print("OC load:") + pprint(loaded) + + +def test_dumps_oc2(): + from flopy4.mf6.gwf import Oc + + oc = Oc( + dims={"nper": 1}, + budget_file="test.bud", + head_file="test.hds", + perioddata={ + 0: Oc.PrintSaveSetting( + printrecord=[ + Oc.PrintRecord("head", Oc.Steps(first=True)), + Oc.PrintRecord("budget", Oc.Steps(steps=(2, 3, 5))), + ], + saverecord=[ + Oc.SaveRecord("head", Oc.Steps(last=True)), + Oc.SaveRecord("budget", Oc.Steps(first=True)), + ], + ) + }, + ) + + dumped = dumps(COMPONENT_CONVERTER.unstructure(oc)) + print("OC dump:") + print(dumped) + assert "SAVE HEAD last" in dumped + assert "SAVE BUDGET first" in dumped + assert "PRINT HEAD first" in dumped + assert "PRINT BUDGET steps 1 2 4" in dumped assert dumped loaded = loads(dumped)