Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ jobs:
- name: Run ruff
run: pixi run lint

- name: Run mypy
run: pixi run mypy flopy4

build:
name: Build
runs-on: ubuntu-latest
Expand Down
101 changes: 0 additions & 101 deletions flopy4/mf6/__init__.py
Original file line number Diff line number Diff line change
@@ -1,101 +0,0 @@
from abc import ABC
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
from attrs import define
from numpy.typing import NDArray
from xattree import ROOT, array, dim, field, xattree

__all__ = [
"Component",
"Package",
"Model",
"Simulation",
"Solution",
"Exchange",
"COMPONENTS",
]

COMPONENTS = {}
"""MF6 component registry."""


class Component(ABC):
@classmethod
def __attrs_init_subclass__(cls):
COMPONENTS[cls.__name__.lower()] = cls


@define
class Package(Component):
pass


@define
class Model(Component):
pass


@define
class Solution(Package):
pass


@define
class Exchange(Package):
exgtype: type = field()
exgfile: Path = field()
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)


@xattree
class Tdis(Package):
@define
class PeriodData:
perlen: float
nstp: int
tsmult: float

nper: int = dim(
name="per",
default=1,
scope=ROOT,
metadata={"block": "dimensions"},
)
time_units: Optional[str] = field(
default=None, metadata={"block": "options"}
)
start_date_time: Optional[datetime] = field(
default=None, metadata={"block": "options"}
)
# perioddata: NDArray[np.object_] = array(
# PeriodData,
# dims=("per",),
# metadata={"block": "perioddata"},
# )
perlen: NDArray[np.floating] = array(
default=1.0,
dims=("per",),
metadata={"block": "perioddata"},
)
nstp: NDArray[np.integer] = array(
default=1,
dims=("per",),
metadata={"block": "perioddata"},
)
tsmult: NDArray[np.floating] = array(
default=1.0,
dims=("per",),
metadata={"block": "perioddata"},
)


@xattree
class Simulation(Component):
models: dict[str, Model] = field()
exchanges: dict[str, Exchange] = field()
solutions: dict[str, Solution] = field()
tdis: Tdis = field()
10 changes: 10 additions & 0 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC

COMPONENTS = {}
"""MF6 component registry."""


class Component(ABC):
@classmethod
def __attrs_init_subclass__(cls):
COMPONENTS[cls.__name__.lower()] = cls
4 changes: 4 additions & 0 deletions flopy4/mf6/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

FILL_DEFAULT = np.nan
FILL_DNODATA = 1e30
66 changes: 66 additions & 0 deletions flopy4/mf6/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@

import numpy as np
from numpy.typing import NDArray
from xattree import _get_xatspec

from flopy4.mf6.constants import FILL_DNODATA


def convert_array(value, self_, field) -> NDArray:
if not isinstance(value, dict):
# if not a dict, assume it's a numpy array
# and let xarray deal with it if it isn't
return value

# get spec
spec = _get_xatspec(type(self_))
field = spec.arrays[field.name]
if not field.dims:
raise ValueError(f"Field {field} missing dims")

# resolve dims
explicit_dims = self_.__dict__.get("dims", {})
inherited_dims = self_.parent.data.dims if self_.parent else {}
dims = inherited_dims | explicit_dims
shape = [dims.get(d, d) for d in field.dims]
unresolved = [d for d in shape if isinstance(d, str)]
if any(unresolved):
raise ValueError(f"Couldn't resolve dims: {unresolved}")

# create array
a = np.full(shape, fill_value=FILL_DNODATA, dtype=field.dtype)

def _get_nn(cellid):
match len(cellid):
case 1:
return cellid[0]
case 2:
k, j = cellid
return k * dims["ncpl"] + j
case 3:
k, i, j = cellid
return k * dims["row"] * dims["col"] + i * dims["col"] + j
case _:
raise ValueError(f"Invalid cellid: {cellid}")

# populate array. TODO: is there a way to do this
# without hardcoding awareness of kper and cellid?
if "per" in dims:
for kper, period in value.items():
if kper == "*":
kper = 0
match len(shape):
case 1:
a[kper] = value
case _:
for cellid, v in period.items():
nn = _get_nn(cellid)
a[kper, nn] = v
if kper == "*":
break
else:
for cellid, v in value.items():
nn = _get_nn(cellid)
a[nn] = v

return a
15 changes: 15 additions & 0 deletions flopy4/mf6/exchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pathlib import Path
from typing import Optional

from attrs import define
from xattree import field

from flopy4.mf6.package import Package


@define
class Exchange(Package):
exgtype: type = field()
exgfile: Path = field()
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)
2 changes: 1 addition & 1 deletion flopy4/mf6/gwf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from attrs import define
from xattree import field, xattree

from flopy4.mf6 import Model
from flopy4.mf6.gwf.chd import Chd
from flopy4.mf6.gwf.dis import Dis
from flopy4.mf6.gwf.ic import Ic
from flopy4.mf6.gwf.npf import Npf
from flopy4.mf6.gwf.oc import Oc
from flopy4.mf6.model import Model

__all__ = ["Gwf", "Chd", "Dis", "Ic", "Npf", "Oc"]

Expand Down
50 changes: 13 additions & 37 deletions flopy4/mf6/gwf/chd.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,16 @@
from pathlib import Path
from typing import Optional

import attrs
import numpy as np
from attrs import define
from attrs import Converter, define
from numpy.typing import NDArray
from xattree import _get_xatspec, array, field, xattree
from xattree import array, field, xattree

from flopy4.mf6 import Package
from flopy4.mf6.converters import convert_array
from flopy4.mf6.package import Package

dnodata = 1e30


def _get_nn(ncol, nrow, k, i, j):
return k * nrow * ncol + i * ncol + j


def _convert_array(value, self_, field):
if not isinstance(value, dict):
return value

inherited_dims = self_.__dict__.get("dims", {})
spec = _get_xatspec(type(self_))
field = spec.arrays["head"]
shape = field.dims
if not shape:
raise ValueError()
dims = [inherited_dims.get(d, d) for d in shape]
# TODO pull out dtype from annotation
a = np.full(dims, fill_value=dnodata, dtype=np.float64)
for kper, period in value.items():
if kper == "*":
kper = 0
for cellid, v in period.items():
nn = _get_nn(inherited_dims["col"], inherited_dims["row"], *cellid)
a[kper, nn] = v
return a


@xattree(multi="list")
@xattree
class Chd(Package):
@define(slots=False)
class Steps:
Expand Down Expand Up @@ -75,9 +47,7 @@ class Steps:
),
default=None,
metadata={"block": "period"},
converter=attrs.Converter(
_convert_array, takes_self=True, takes_field=True
),
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
aux: Optional[NDArray[np.floating]] = array(
dims=(
Expand All @@ -86,6 +56,7 @@ class Steps:
),
default=None,
metadata={"block": "period"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
boundname: Optional[NDArray[np.str_]] = array(
dims=(
Expand All @@ -94,7 +65,12 @@ class Steps:
),
default=None,
metadata={"block": "period"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
steps: Optional[NDArray[np.object_]] = array(
Steps, dims=("per", "node"), default=None, metadata={"block": "period"}
Steps,
dims=("per", "node"),
default=None,
metadata={"block": "period"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
9 changes: 8 additions & 1 deletion flopy4/mf6/gwf/dis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from attrs import Converter
from numpy.typing import NDArray
from xattree import array, dim, field, xattree

from flopy4.mf6 import Package
from flopy4.mf6.converters import convert_array
from flopy4.mf6.package import Package


@xattree
Expand Down Expand Up @@ -46,26 +48,31 @@ class Dis(Package):
dims=("col",),
default=1.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
delc: NDArray[np.floating] = array(
dims=("row",),
default=1.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
top: NDArray[np.floating] = array(
dims=("col", "row"),
default=1.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
botm: NDArray[np.floating] = array(
dims=("col", "row", "lay"),
default=0.0,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
idomain: NDArray[np.integer] = array(
dims=("col", "row", "lay"),
default=1,
metadata={"block": "griddata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
nnodes: int = dim(name="node", scope="gwf", init=False)

Expand Down
5 changes: 4 additions & 1 deletion flopy4/mf6/gwf/ic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
from attrs import Converter
from numpy.typing import NDArray
from xattree import array, field, xattree

from flopy4.mf6 import Package
from flopy4.mf6.converters import convert_array
from flopy4.mf6.package import Package


@xattree
Expand All @@ -11,6 +13,7 @@ class Ic(Package):
dims=("node",),
default=1.0,
metadata={"block": "packagedata"},
converter=Converter(convert_array, takes_self=True, takes_field=True),
)
export_array_ascii: bool = field(
default=False, metadata={"block": "options"}
Expand Down
Loading
Loading