Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 67 additions & 7 deletions flopy4/mf6/converter/unstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any

import numpy as np
import xarray as xr
import xattree
from modflow_devtools.dfn.schema.block import block_sort_key
Expand Down Expand Up @@ -89,6 +90,51 @@ def _hack_structured_grid_dims(
)


def _hack_period_non_numeric(name: str, value: xr.DataArray) -> dict[str, dict[int, Any]]:
from flopy4.mf6.gwf import Oc

def oc_setting_data(rec):
dat = {}
if rec.steps.first:
dat = {kper: "first" for kper in range(value.sizes["nper"])}
elif rec.steps.last:
dat = {kper: "last" for kper in range(value.sizes["nper"])}
elif rec.steps.steps:
steps = " ".join(str(x - 1) for x in rec.steps.steps)
dat = {kper: f"steps {steps}" for kper in range(value.sizes["nper"])}
elif rec.steps.all:
# check last as this defaults to True
dat = {kper: "all" for kper in range(value.sizes["nper"])}

return dat

data = {}
match value.dtype:
case np.bool:
dat = {kper: "" for kper in range(value.sizes["nper"]) if value.values[kper]} # type: ignore
data[name] = dat
case np.dtypes.StringDType():
fname = name.replace("_", " ")
dat = {kper: value.values[kper] for kper in range(value.sizes["nper"])}
data[fname] = dat
case object():
if isinstance(value.values[0], Oc.PrintSaveSetting):
if hasattr(value.values[0], "printrecord") and isinstance(
value.values[0].printrecord, list
):
for rec in value.values[0].printrecord:
key = f"{rec.print} {rec.rtype}"
data[key] = oc_setting_data(rec)
if hasattr(value.values[0], "saverecord") and isinstance(
value.values[0].saverecord, list
):
for rec in value.values[0].saverecord: # type: ignore
key = f"{rec.save} {rec.rtype}" # type: ignore
data[key] = oc_setting_data(rec)

return data


def unstructure_component(value: Component) -> dict[str, Any]:
blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore
blocks: dict[str, dict[str, Any]] = {}
Expand Down Expand Up @@ -157,10 +203,15 @@ def unstructure_component(value: Component) -> dict[str, Any]:
structured_grid_dims=value.parent.data.dims, # type: ignore
)
if block_name == "period":
period_data[field_name] = {
kper: field_value.isel(nper=kper)
for kper in range(field_value.sizes["nper"])
}
if not np.issubdtype(field_value.dtype, np.number):
dat = _hack_period_non_numeric(field_name, field_value)
for n, v in dat.items():
period_data[n] = v
else:
period_data[field_name] = {
kper: field_value.isel(nper=kper) # type: ignore
for kper in range(field_value.sizes["nper"])
}
else:
blocks[block_name][field_name] = field_value

Expand All @@ -174,11 +225,20 @@ def unstructure_component(value: Component) -> dict[str, Any]:
period_blocks[kper] = {}
period_blocks[kper][arr_name] = arr

# sort kper order
period_blocks = dict(sorted(period_blocks.items()))

# setup indexed period blocks, combine arrays into datasets
for kper, block in period_blocks.items():
blocks[f"period {kper + 1}"] = {
"period": xr.Dataset(block, coords=block[arr_name].coords)
}
blocks[f"period {kper + 1}"] = {}
for arr_name, val in block.items():
match block[arr_name]:
case str():
blocks[f"period {kper + 1}"][arr_name] = val
case xr.DataArray():
blocks[f"period {kper + 1}"]["period"] = xr.Dataset(
block, coords=block[arr_name].coords
)

# combine "perioddata" block arrays (tdis, ats) into datasets
# so they render as lists. temp hack TODO do this generically
Expand Down
71 changes: 64 additions & 7 deletions test/test_mf6_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from pprint import pprint

import pytest

from flopy4.mf6.codec import dumps, loads
from flopy4.mf6.converter import COMPONENT_CONVERTER

Expand Down Expand Up @@ -57,7 +55,31 @@ def test_dumps_ic():
pprint(loaded)


@pytest.mark.xfail(reason="nested type unstructuring not yet supported")
def test_dumps_sto():
from flopy4.mf6.gwf import Dis, Gwf, Sto

dis = Dis()
gwf = Gwf(dis=dis)
sto = Sto(
dims={"nper": 3},
parent=gwf,
steady_state=[False, True, False],
transient=[True, False, True],
)

dumped = dumps(COMPONENT_CONVERTER.unstructure(sto))
print("STO dump:")
print(dumped)
assert "BEGIN PERIOD 1\n TRANSIENT" in dumped
assert "BEGIN PERIOD 2\n STEADY_STATE" in dumped
assert "BEGIN PERIOD 3\n TRANSIENT" in dumped
assert dumped

loaded = loads(dumped)
print("STO load:")
pprint(loaded)


def test_dumps_oc():
from flopy4.mf6.gwf import Oc

Expand All @@ -80,10 +102,45 @@ def test_dumps_oc():
dumped = dumps(COMPONENT_CONVERTER.unstructure(oc))
print("OC dump:")
print(dumped)
assert "save head all" in dumped
assert "save budget all" in dumped
assert "print head all" in dumped
assert "print budget all" in dumped
assert "SAVE HEAD all" in dumped
assert "SAVE BUDGET all" in dumped
assert "PRINT HEAD all" in dumped
assert "PRINT BUDGET all" in dumped
assert dumped

loaded = loads(dumped)
print("OC load:")
pprint(loaded)


def test_dumps_oc2():
from flopy4.mf6.gwf import Oc

oc = Oc(
dims={"nper": 1},
budget_file="test.bud",
head_file="test.hds",
perioddata={
0: Oc.PrintSaveSetting(
printrecord=[
Oc.PrintRecord("head", Oc.Steps(first=True)),
Oc.PrintRecord("budget", Oc.Steps(steps=(2, 3, 5))),
],
saverecord=[
Oc.SaveRecord("head", Oc.Steps(last=True)),
Oc.SaveRecord("budget", Oc.Steps(first=True)),
],
)
},
)

dumped = dumps(COMPONENT_CONVERTER.unstructure(oc))
print("OC dump:")
print(dumped)
assert "SAVE HEAD last" in dumped
assert "SAVE BUDGET first" in dumped
assert "PRINT HEAD first" in dumped
assert "PRINT BUDGET steps 1 2 4" in dumped
assert dumped

loaded = loads(dumped)
Expand Down
Loading