Skip to content

Commit 1ce8632

Browse files
committed
just use a converter, so gwf can accept a dis or a grid, and sim can accept a modeltime or a tdis
1 parent bdde703 commit 1ce8632

File tree

8 files changed

+399
-341
lines changed

8 files changed

+399
-341
lines changed

flopy4/mf6/converters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import sparse
55
from numpy.typing import NDArray
6-
from xattree import _get_xatspec
6+
from xattree import get_xatspec
77

88
from flopy4.mf6.config import SPARSE_THRESHOLD
99
from flopy4.mf6.constants import FILL_DNODATA
@@ -16,8 +16,8 @@ def convert_array(value, self_, field) -> NDArray:
1616
return value
1717

1818
# get spec
19-
spec = _get_xatspec(type(self_))
20-
field = spec.arrays[field.name]
19+
spec = get_xatspec(type(self_))
20+
field = spec[field.name]
2121
if not field.dims:
2222
raise ValueError(f"Field {field} missing dims")
2323

flopy4/mf6/gwf/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from pathlib import Path
22
from typing import Optional
33

4-
import attrs
54
from attrs import define
65
from flopy.discretization.grid import Grid
76
from xattree import field, xattree
@@ -18,12 +17,11 @@
1817

1918
@xattree
2019
class Gwf(Model):
21-
dis: Dis = field()
20+
dis: Dis = field(converter=lambda grid: Dis.from_grid(grid))
2221
ic: Ic = field()
2322
oc: Oc = field()
2423
npf: Npf = field()
2524
chd: list[Chd] = field()
26-
grid: Grid = attrs.field(default=None)
2725

2826
@define
2927
class NewtonOptions:
@@ -46,3 +44,7 @@ class NewtonOptions:
4644
nc_filerecord: Optional[Path] = field(
4745
default=None, metadata={"block": "options"}
4846
)
47+
48+
@property
49+
def grid(self) -> Grid:
50+
return self.dis.to_grid()

flopy4/mf6/gwf/dis.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from attrs import Converter
3+
from flopy.discretization.structuredgrid import StructuredGrid
34
from numpy.typing import NDArray
45
from xattree import array, dim, field, xattree
56

@@ -82,3 +83,49 @@ class Dis(Package):
8283

8384
def __attrs_post_init__(self):
8485
self.nnodes = self.ncol * self.nrow * self.nlay
86+
87+
def to_grid(self) -> StructuredGrid:
88+
"""
89+
Convert the discretization to a `StructuredGrid`.
90+
91+
Returns
92+
-------
93+
StructuredGrid
94+
A `StructuredGrid` with the same dimensions and data as the `Dis`.
95+
"""
96+
return StructuredGrid(
97+
nlay=self.nlay,
98+
nrow=self.nrow,
99+
ncol=self.ncol,
100+
delr=self.delr,
101+
delc=self.delc,
102+
top=self.top,
103+
botm=self.botm,
104+
idomain=self.idomain,
105+
)
106+
107+
@classmethod
108+
def from_grid(cls, grid: StructuredGrid) -> "Dis":
109+
"""
110+
Create a discretization from a `StructuredGrid`.
111+
112+
Parameters
113+
----------
114+
grid : StructuredGrid
115+
A structured grid.
116+
117+
Returns
118+
-------
119+
Dis
120+
A discretization with the same dimensions and data as the grid.
121+
"""
122+
return Dis(
123+
nlay=grid.nlay,
124+
nrow=grid.nrow,
125+
ncol=grid.ncol,
126+
delr=grid.delr,
127+
delc=grid.delc,
128+
top=grid.top,
129+
botm=grid.botm,
130+
idomain=grid.idomain,
131+
)

flopy4/mf6/simulation.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import attrs
2-
from flopy.discretization.modeltime import ModelTime
31
from xattree import field, xattree
42

53
from flopy4.mf6.component import Component
@@ -14,5 +12,4 @@ class Simulation(Component):
1412
models: dict[str, Model] = field()
1513
exchanges: dict[str, Exchange] = field()
1614
solutions: dict[str, Solution] = field()
17-
tdis: Tdis = field()
18-
time: ModelTime = attrs.field(default=None)
15+
tdis: Tdis = field(converter=lambda time: Tdis.from_time(time))

flopy4/mf6/tdis.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from attrs import Converter, define
6+
from flopy.discretization.modeltime import ModelTime
67
from numpy.typing import NDArray
78
from xattree import ROOT, array, dim, field, xattree
89

@@ -48,3 +49,26 @@ class PeriodData:
4849
metadata={"block": "perioddata"},
4950
converter=Converter(convert_array, takes_self=True, takes_field=True),
5051
)
52+
53+
def to_time(self) -> ModelTime:
54+
"""Convert to a `ModelTime` object."""
55+
return ModelTime(
56+
nper=self.nper,
57+
time_units=self.time_units,
58+
start_date_time=self.start_date_time,
59+
perlen=self.perlen,
60+
nstp=self.nstp,
61+
tsmult=self.tsmult,
62+
)
63+
64+
@classmethod
65+
def from_time(cls, time: ModelTime) -> "Tdis":
66+
"""Create a time discretization from a `ModelTime`."""
67+
return cls(
68+
nper=time.nper,
69+
time_units=time.time_units,
70+
start_date_time=time.start_datetime,
71+
perlen=time.perlen,
72+
nstp=time.nstp,
73+
tsmult=time.tsmult,
74+
)

pixi.lock

Lines changed: 220 additions & 200 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

test/test_component.py

Lines changed: 29 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -49,44 +49,12 @@ def test_init_gwf_explicit_dims():
4949
)
5050

5151
assert isinstance(gwf.data, DataTree)
52-
assert gwf.dis is dis
52+
assert gwf.dis is not dis # dimension order switched.. is this ok?
5353
assert gwf.ic is ic
5454
assert gwf.oc is oc
5555
assert gwf.npf is npf
5656
assert gwf.chd[0] is chd
57-
assert gwf.data.dis is dis.data
58-
assert gwf.data.ic is ic.data
59-
assert gwf.data.oc is oc.data
60-
assert gwf.data.npf is npf.data
61-
assert np.array_equal(npf.k, np.ones(4))
62-
assert np.array_equal(npf.data.k, np.ones(4))
63-
64-
65-
@pytest.mark.skip(reason="TODO")
66-
def test_init_gwf_from_grid():
67-
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
68-
grid = StructuredGrid(nlay=1, nrow=2, ncol=2)
69-
dis = Dis(grid=grid)
70-
ic = Ic(grid=grid)
71-
oc = Oc(grid=grid)
72-
npf = Npf(grid=grid)
73-
chd = Chd(grid=grid)
74-
gwf = Gwf(
75-
dis=dis,
76-
ic=ic,
77-
oc=oc,
78-
npf=npf,
79-
chd=[chd],
80-
grid=grid,
81-
)
82-
83-
assert isinstance(gwf.data, DataTree)
84-
assert gwf.dis is dis
85-
assert gwf.ic is ic
86-
assert gwf.oc is oc
87-
assert gwf.npf is npf
88-
assert gwf.chd[0] is chd
89-
assert gwf.data.dis is dis.data
57+
assert gwf.data.dis is not dis.data
9058
assert gwf.data.ic is ic.data
9159
assert gwf.data.oc is oc.data
9260
assert gwf.data.npf is npf.data
@@ -138,7 +106,7 @@ def test_init_gwf_dis_first():
138106
chd = Chd(parent=gwf, strict=False)
139107

140108
assert isinstance(gwf.data, DataTree)
141-
assert gwf.dis is dis
109+
assert gwf.dis is not dis
142110
assert gwf.ic is ic
143111
assert gwf.oc is oc
144112
assert gwf.npf is npf
@@ -147,6 +115,25 @@ def test_init_gwf_dis_first():
147115
assert np.array_equal(npf.data.k, np.ones(4))
148116

149117

118+
def test_init_gwf_dis_first_with_grid():
119+
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
120+
gwf = Gwf(dis=grid)
121+
dis = gwf.dis
122+
ic = Ic(parent=gwf)
123+
oc = Oc(parent=gwf, strict=False)
124+
npf = Npf(parent=gwf)
125+
chd = Chd(parent=gwf, strict=False)
126+
127+
assert isinstance(gwf.data, DataTree)
128+
assert gwf.dis is dis
129+
assert gwf.ic is ic
130+
assert gwf.oc is oc
131+
assert gwf.npf is npf
132+
assert gwf.chd[0] is chd
133+
assert np.array_equal(npf.k, np.ones(100))
134+
assert np.array_equal(npf.data.k, np.ones(100))
135+
136+
150137
def test_init_gwf_top_down_misaligned():
151138
grid = StructuredGrid(nlay=1, nrow=10, ncol=10)
152139
dims = {
@@ -198,7 +185,7 @@ def test_init_sim_explicit_dims():
198185
assert isinstance(sim.data, DataTree)
199186
assert sim.data.tdis is tdis.data
200187
assert sim.data.gwf is gwf.data
201-
assert gwf.dis is dis
188+
assert gwf.dis is not dis # gwf.dis has inherited dim nper
202189
assert gwf.ic is ic
203190
assert gwf.oc is oc
204191
assert gwf.npf is npf
@@ -219,35 +206,16 @@ def test_init_big_sim():
219206
# if size over threshold, arrays should be sparse
220207
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
221208
grid = StructuredGrid(nlay=1, nrow=100, ncol=100)
222-
dims = {
223-
"nlay": grid.nlay,
224-
"nrow": grid.nrow,
225-
"ncol": grid.ncol,
226-
}
227-
dis = Dis(**dims)
228-
dims["nper"] = time.nper
229-
dims["nnodes"] = grid.nnodes
230-
ic = Ic(dims=dims)
231-
oc = Oc(dims=dims)
232-
npf = Npf(dims=dims)
233-
chd = Chd(dims=dims, head={"*": {(0, 0, 0): 1.0, (0, 99, 99): 0.0}})
234-
gwf = Gwf(
235-
dis=dis,
236-
ic=ic,
237-
oc=oc,
238-
npf=npf,
239-
chd=[chd],
240-
dims=dims,
241-
)
242-
tdis = Tdis(dims=dims)
243-
sim = Simulation(tdis=tdis, models={"gwf": gwf})
209+
sim = Simulation(tdis=time)
210+
gwf = Gwf(parent=sim, dis=grid)
211+
ic = Ic(parent=gwf)
212+
oc = Oc(parent=gwf)
213+
npf = Npf(parent=gwf)
214+
chd = Chd(parent=gwf, head={"*": {(0, 0, 0): 1.0, (0, 99, 99): 0.0}})
244215

245-
assert sim.tdis is tdis
246216
assert sim.models["gwf"] is gwf
247217
assert isinstance(sim.data, DataTree)
248-
assert sim.data.tdis is tdis.data
249218
assert sim.data.gwf is gwf.data
250-
assert gwf.dis is dis
251219
assert gwf.ic is ic
252220
assert gwf.oc is oc
253221
assert gwf.npf is npf

0 commit comments

Comments
 (0)