Skip to content

Commit 8f94b9b

Browse files
committed
just use a converter, so gwf can accept a dis or a grid, and sim can accept a modeltime or a tdis
1 parent 41ae04e commit 8f94b9b

File tree

7 files changed

+179
-137
lines changed

7 files changed

+179
-137
lines changed

flopy4/mf6/converters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import sparse
55
from numpy.typing import NDArray
6-
from xattree import _get_xatspec
6+
from xattree import get_xatspec
77

88
from flopy4.mf6.config import SPARSE_THRESHOLD
99
from flopy4.mf6.constants import FILL_DNODATA
@@ -16,8 +16,8 @@ def convert_array(value, self_, field) -> NDArray:
1616
return value
1717

1818
# get spec
19-
spec = _get_xatspec(type(self_))
20-
field = spec.arrays[field.name]
19+
spec = get_xatspec(type(self_))
20+
field = spec[field.name]
2121
if not field.dims:
2222
raise ValueError(f"Field {field} missing dims")
2323

flopy4/mf6/gwf/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def budget(self):
4040
merge_to_dataset=True,
4141
)
4242

43-
dis: Dis = field()
43+
dis: Dis = field(converter=lambda grid: Dis.from_grid(grid))
4444
ic: Ic = field()
4545
oc: Oc = field()
4646
npf: Npf = field()
@@ -71,3 +71,7 @@ class NewtonOptions:
7171
nc_filerecord: Optional[Path] = field(
7272
default=None, metadata={"block": "options"}
7373
)
74+
75+
@property
76+
def grid(self) -> Grid:
77+
return self.dis.to_grid()

flopy4/mf6/gwf/dis.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from attrs import Converter
3+
from flopy.discretization.structuredgrid import StructuredGrid
34
from numpy.typing import NDArray
45
from xattree import array, dim, field, xattree
56

@@ -82,3 +83,49 @@ class Dis(Package):
8283

8384
def __attrs_post_init__(self):
8485
self.nnodes = self.ncol * self.nrow * self.nlay
86+
87+
def to_grid(self) -> StructuredGrid:
88+
"""
89+
Convert the discretization to a `StructuredGrid`.
90+
91+
Returns
92+
-------
93+
StructuredGrid
94+
A `StructuredGrid` with the same dimensions and data as the `Dis`.
95+
"""
96+
return StructuredGrid(
97+
nlay=self.nlay,
98+
nrow=self.nrow,
99+
ncol=self.ncol,
100+
delr=self.delr,
101+
delc=self.delc,
102+
top=self.top,
103+
botm=self.botm,
104+
idomain=self.idomain,
105+
)
106+
107+
@classmethod
108+
def from_grid(cls, grid: StructuredGrid) -> "Dis":
109+
"""
110+
Create a discretization from a `StructuredGrid`.
111+
112+
Parameters
113+
----------
114+
grid : StructuredGrid
115+
A structured grid.
116+
117+
Returns
118+
-------
119+
Dis
120+
A discretization with the same dimensions and data as the grid.
121+
"""
122+
return Dis(
123+
nlay=grid.nlay,
124+
nrow=grid.nrow,
125+
ncol=grid.ncol,
126+
delr=grid.delr,
127+
delc=grid.delc,
128+
top=grid.top,
129+
botm=grid.botm,
130+
idomain=grid.idomain,
131+
)

flopy4/mf6/simulation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,5 @@ class Simulation(Component):
1717
models: dict[str, Model] = field()
1818
exchanges: dict[str, Exchange] = field()
1919
solutions: dict[str, Solution] = field()
20-
tdis: Tdis = field()
2120
sim_ws: Path = field(default=None)
22-
time: ModelTime = attrs.field(default=None)
21+
tdis: Tdis = field(converter=lambda time: Tdis.from_time(time))

flopy4/mf6/tdis.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from attrs import Converter, define
6+
from flopy.discretization.modeltime import ModelTime
67
from numpy.typing import NDArray
78
from xattree import ROOT, array, dim, field, xattree
89

@@ -48,3 +49,26 @@ class PeriodData:
4849
metadata={"block": "perioddata"},
4950
converter=Converter(convert_array, takes_self=True, takes_field=True),
5051
)
52+
53+
def to_time(self) -> ModelTime:
54+
"""Convert to a `ModelTime` object."""
55+
return ModelTime(
56+
nper=self.nper,
57+
time_units=self.time_units,
58+
start_date_time=self.start_date_time,
59+
perlen=self.perlen,
60+
nstp=self.nstp,
61+
tsmult=self.tsmult,
62+
)
63+
64+
@classmethod
65+
def from_time(cls, time: ModelTime) -> "Tdis":
66+
"""Create a time discretization from a `ModelTime`."""
67+
return cls(
68+
nper=time.nper,
69+
time_units=time.time_units,
70+
start_date_time=time.start_datetime,
71+
perlen=time.perlen,
72+
nstp=time.nstp,
73+
tsmult=time.tsmult,
74+
)

test/test_component.py

Lines changed: 29 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -49,44 +49,12 @@ def test_init_gwf_explicit_dims():
4949
)
5050

5151
assert isinstance(gwf.data, DataTree)
52-
assert gwf.dis is dis
52+
assert gwf.dis is not dis # dimension order switched.. is this ok?
5353
assert gwf.ic is ic
5454
assert gwf.oc is oc
5555
assert gwf.npf is npf
5656
assert gwf.chd[0] is chd
57-
assert gwf.data.dis is dis.data
58-
assert gwf.data.ic is ic.data
59-
assert gwf.data.oc is oc.data
60-
assert gwf.data.npf is npf.data
61-
assert np.array_equal(npf.k, np.ones(4))
62-
assert np.array_equal(npf.data.k, np.ones(4))
63-
64-
65-
@pytest.mark.skip(reason="TODO")
66-
def test_init_gwf_from_grid():
67-
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
68-
grid = StructuredGrid(nlay=1, nrow=2, ncol=2)
69-
dis = Dis(grid=grid)
70-
ic = Ic(grid=grid)
71-
oc = Oc(grid=grid)
72-
npf = Npf(grid=grid)
73-
chd = Chd(grid=grid)
74-
gwf = Gwf(
75-
dis=dis,
76-
ic=ic,
77-
oc=oc,
78-
npf=npf,
79-
chd=[chd],
80-
grid=grid,
81-
)
82-
83-
assert isinstance(gwf.data, DataTree)
84-
assert gwf.dis is dis
85-
assert gwf.ic is ic
86-
assert gwf.oc is oc
87-
assert gwf.npf is npf
88-
assert gwf.chd[0] is chd
89-
assert gwf.data.dis is dis.data
57+
assert gwf.data.dis is not dis.data
9058
assert gwf.data.ic is ic.data
9159
assert gwf.data.oc is oc.data
9260
assert gwf.data.npf is npf.data
@@ -138,7 +106,7 @@ def test_init_gwf_dis_first():
138106
chd = Chd(parent=gwf, strict=False)
139107

140108
assert isinstance(gwf.data, DataTree)
141-
assert gwf.dis is dis
109+
assert gwf.dis is not dis
142110
assert gwf.ic is ic
143111
assert gwf.oc is oc
144112
assert gwf.npf is npf
@@ -147,6 +115,25 @@ def test_init_gwf_dis_first():
147115
assert np.array_equal(npf.data.k, np.ones(4))
148116

149117

118+
def test_init_gwf_dis_first_with_grid():
119+
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
120+
gwf = Gwf(dis=grid)
121+
dis = gwf.dis
122+
ic = Ic(parent=gwf)
123+
oc = Oc(parent=gwf, strict=False)
124+
npf = Npf(parent=gwf)
125+
chd = Chd(parent=gwf, strict=False)
126+
127+
assert isinstance(gwf.data, DataTree)
128+
assert gwf.dis is dis
129+
assert gwf.ic is ic
130+
assert gwf.oc is oc
131+
assert gwf.npf is npf
132+
assert gwf.chd[0] is chd
133+
assert np.array_equal(npf.k, np.ones(100))
134+
assert np.array_equal(npf.data.k, np.ones(100))
135+
136+
150137
def test_init_gwf_top_down_misaligned():
151138
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
152139
dims = {
@@ -198,7 +185,7 @@ def test_init_sim_explicit_dims():
198185
assert isinstance(sim.data, DataTree)
199186
assert sim.data.tdis is tdis.data
200187
assert sim.data.gwf is gwf.data
201-
assert gwf.dis is dis
188+
assert gwf.dis is not dis # gwf.dis has inherited dim nper
202189
assert gwf.ic is ic
203190
assert gwf.oc is oc
204191
assert gwf.npf is npf
@@ -219,35 +206,16 @@ def test_init_big_sim():
219206
# if size over threshold, arrays should be sparse
220207
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
221208
grid = StructuredGrid(nlay=1, nrow=100, ncol=100)
222-
dims = {
223-
"nlay": grid.nlay,
224-
"nrow": grid.nrow,
225-
"ncol": grid.ncol,
226-
}
227-
dis = Dis(**dims)
228-
dims["nper"] = time.nper
229-
dims["nnodes"] = grid.nnodes
230-
ic = Ic(dims=dims)
231-
oc = Oc(dims=dims)
232-
npf = Npf(dims=dims)
233-
chd = Chd(dims=dims, head={"*": {(0, 0, 0): 1.0, (0, 99, 99): 0.0}})
234-
gwf = Gwf(
235-
dis=dis,
236-
ic=ic,
237-
oc=oc,
238-
npf=npf,
239-
chd=[chd],
240-
dims=dims,
241-
)
242-
tdis = Tdis(dims=dims)
243-
sim = Simulation(tdis=tdis, models={"gwf": gwf})
209+
sim = Simulation(tdis=time)
210+
gwf = Gwf(parent=sim, dis=grid)
211+
ic = Ic(parent=gwf)
212+
oc = Oc(parent=gwf)
213+
npf = Npf(parent=gwf)
214+
chd = Chd(parent=gwf, head={"*": {(0, 0, 0): 1.0, (0, 99, 99): 0.0}})
244215

245-
assert sim.tdis is tdis
246216
assert sim.models["gwf"] is gwf
247217
assert isinstance(sim.data, DataTree)
248-
assert sim.data.tdis is tdis.data
249218
assert sim.data.gwf is gwf.data
250-
assert gwf.dis is dis
251219
assert gwf.ic is ic
252220
assert gwf.oc is oc
253221
assert gwf.npf is npf

0 commit comments

Comments
 (0)