Skip to content

Commit a471903

Browse files
authored
add grid and time fields to components (#124)
Initial step towards #120, add grid property to Gwf and time property to Simulation, and use a converter so the user can pass grid/modeltime to gwf/sim instead of dis/tdis, respectively
1 parent 20bfbc1 commit a471903

File tree

9 files changed

+242
-111
lines changed

9 files changed

+242
-111
lines changed

docs/examples/quickstart.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,5 @@
5555
ax.grid(which="both", color="white")
5656
head.plot.imshow(ax=ax)
5757
head.plot.contour(ax=ax, levels=[0.2, 0.4, 0.6, 0.8], linewidths=3.0)
58-
budget.plot.quiver(
59-
x="x", y="y", u="npf-qx", v="npf-qy", ax=ax, color="white"
60-
)
58+
budget.plot.quiver(x="x", y="y", u="npf-qx", v="npf-qy", ax=ax, color="white")
6159
fig.savefig(ws / "quickstart.png")

flopy4/mf6/converters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import sparse
55
from numpy.typing import NDArray
6-
from xattree import _get_xatspec
6+
from xattree import get_xatspec
77

88
from flopy4.mf6.config import SPARSE_THRESHOLD
99
from flopy4.mf6.constants import FILL_DNODATA
@@ -16,8 +16,8 @@ def convert_array(value, self_, field) -> NDArray:
1616
return value
1717

1818
# get spec
19-
spec = _get_xatspec(type(self_))
20-
field = spec.arrays[field.name]
19+
spec = get_xatspec(type(self_))
20+
field = spec[field.name]
2121
if not field.dims:
2222
raise ValueError(f"Field {field} missing dims")
2323

flopy4/mf6/gwf/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import imod
66
import xarray as xr
77
from attrs import define
8+
from flopy.discretization.grid import Grid
89
from xattree import field, xattree
910

1011
from flopy4.mf6.gwf.chd import Chd
@@ -18,6 +19,14 @@
1819
__all__ = ["Gwf", "Chd", "Dis", "Ic", "Npf", "Oc"]
1920

2021

22+
def convert_grid(value):
23+
if isinstance(value, Grid):
24+
return Dis.from_grid(value)
25+
if isinstance(value, Dis):
26+
return value
27+
raise TypeError(f"Expected Grid or Dis, got {type(value)}")
28+
29+
2130
@xattree
2231
class Gwf(Model):
2332
@define
@@ -39,7 +48,7 @@ def budget(self):
3948
merge_to_dataset=True,
4049
)
4150

42-
dis: Dis = field()
51+
dis: Dis = field(converter=convert_grid)
4352
ic: Ic = field()
4453
oc: Oc = field()
4554
npf: Npf = field()
@@ -69,3 +78,7 @@ class NewtonOptions:
6978
nc_filerecord: Optional[Path] = field(
7079
default=None, metadata={"block": "options"}
7180
)
81+
82+
@property
83+
def grid(self) -> Grid:
84+
return self.dis.to_grid()

flopy4/mf6/gwf/dis.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from attrs import Converter
3+
from flopy.discretization.structuredgrid import StructuredGrid
34
from numpy.typing import NDArray
45
from xattree import array, dim, field, xattree
56

@@ -82,3 +83,49 @@ class Dis(Package):
8283

8384
def __attrs_post_init__(self):
8485
self.nnodes = self.ncol * self.nrow * self.nlay
86+
87+
def to_grid(self) -> StructuredGrid:
88+
"""
89+
Convert the discretization to a `StructuredGrid`.
90+
91+
Returns
92+
-------
93+
StructuredGrid
94+
A `StructuredGrid` with the same dimensions and data as the `Dis`.
95+
"""
96+
return StructuredGrid(
97+
nlay=self.nlay,
98+
nrow=self.nrow,
99+
ncol=self.ncol,
100+
delr=self.delr,
101+
delc=self.delc,
102+
top=self.top,
103+
botm=self.botm,
104+
idomain=self.idomain,
105+
)
106+
107+
@classmethod
108+
def from_grid(cls, grid: StructuredGrid) -> "Dis":
109+
"""
110+
Create a discretization from a `StructuredGrid`.
111+
112+
Parameters
113+
----------
114+
grid : StructuredGrid
115+
A structured grid.
116+
117+
Returns
118+
-------
119+
Dis
120+
A discretization with the same dimensions and data as the grid.
121+
"""
122+
return Dis(
123+
nlay=grid.nlay,
124+
nrow=grid.nrow,
125+
ncol=grid.ncol,
126+
delr=grid.delr,
127+
delc=grid.delc,
128+
top=grid.top,
129+
botm=grid.botm,
130+
idomain=grid.idomain,
131+
)

flopy4/mf6/simulation.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22

3+
from flopy.discretization.modeltime import ModelTime
34
from xattree import field, xattree
45

56
from flopy4.mf6.component import Component
@@ -9,10 +10,22 @@
910
from flopy4.mf6.tdis import Tdis
1011

1112

13+
def convert_time(value):
14+
if isinstance(value, ModelTime):
15+
return Tdis.from_time(value)
16+
if isinstance(value, Tdis):
17+
return value
18+
raise TypeError(f"Expected ModelTime or Tdis, got {type(value)}")
19+
20+
1221
@xattree
1322
class Simulation(Component):
1423
models: dict[str, Model] = field()
1524
exchanges: dict[str, Exchange] = field()
1625
solutions: dict[str, Solution] = field()
17-
tdis: Tdis = field()
26+
tdis: Tdis = field(converter=convert_time)
1827
sim_ws: Path = field(default=None)
28+
29+
@property
30+
def time(self) -> ModelTime:
31+
return self.tdis.to_time()

flopy4/mf6/tdis.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from attrs import Converter, define
6+
from flopy.discretization.modeltime import ModelTime
67
from numpy.typing import NDArray
78
from xattree import ROOT, array, dim, field, xattree
89

@@ -48,3 +49,26 @@ class PeriodData:
4849
metadata={"block": "perioddata"},
4950
converter=Converter(convert_array, takes_self=True, takes_field=True),
5051
)
52+
53+
def to_time(self) -> ModelTime:
54+
"""Convert to a `ModelTime` object."""
55+
return ModelTime(
56+
nper=self.nper,
57+
time_units=self.time_units,
58+
start_date_time=self.start_date_time,
59+
perlen=self.perlen,
60+
nstp=self.nstp,
61+
tsmult=self.tsmult,
62+
)
63+
64+
@classmethod
65+
def from_time(cls, time: ModelTime) -> "Tdis":
66+
"""Create a time discretization from a `ModelTime`."""
67+
return cls(
68+
nper=time.nper,
69+
time_units=time.time_units,
70+
start_date_time=time.start_datetime,
71+
perlen=time.perlen,
72+
nstp=time.nstp,
73+
tsmult=time.tsmult,
74+
)

0 commit comments

Comments
 (0)