Skip to content

Commit 0db2281

Browse files
authored
Sparse converter (#106)
1 parent ccc51ea commit 0db2281

File tree

8 files changed

+2302
-2105
lines changed

8 files changed

+2302
-2105
lines changed

.vscode/settings.json

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
{
2+
"editor.formatOnSave": true,
3+
"files.insertFinalNewline": true,
24
"python.testing.pytestArgs": [
35
"test"
46
],
57
"python.testing.unittestEnabled": false,
68
"python.testing.pytestEnabled": true,
79
"[python]": {
8-
"editor.formatOnSave": true,
910
"editor.defaultFormatter": "charliermarsh.ruff",
1011
"editor.codeActionsOnSave": {
1112
"source.fixAll": "explicit"
1213
}
14+
},
15+
"mypy-type-checker.importStrategy": "fromEnvironment",
16+
"files.exclude": {
17+
"**/.git": true,
18+
"**/.svn": true,
19+
"**/.hg": true,
20+
"**/.DS_Store": true,
21+
"**/Thumbs.db": true,
22+
".pixi": true,
23+
".ruff_cache": true,
24+
".pytest_cache": true
1325
}
14-
}
26+
}

docs/examples/quickstart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# check CHD
3030
assert chd.data["head"][0, 0].item() == 1.0
3131
assert chd.data["head"][0, 99].item() == 0.0
32-
assert np.allclose(chd.data["head"][:, 1:99], np.full(98, 1e30))
32+
assert np.allclose(chd.data["head"][:, 1:99].data.todense(), np.full(98, 1e30))
3333

3434
# TODO: xarray index aliasing nlay/ncol/nrow to k/i/j?
3535
# assert chd.data["head"].loc(dict(k=0, i=0, j=0)) == 1.

flopy4/mf6/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# TODO use https://environ-config.readthedocs.io/en/stable/?
2+
3+
SPARSE_THRESHOLD = 1000

flopy4/mf6/converters.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from typing import Any, Tuple
2+
13
import numpy as np
4+
import sparse
25
from numpy.typing import NDArray
36
from xattree import _get_xatspec
47

8+
from flopy4.mf6.config import SPARSE_THRESHOLD
59
from flopy4.mf6.constants import FILL_DNODATA
610

711

@@ -26,11 +30,29 @@ def convert_array(value, self_, field) -> NDArray:
2630
if any(unresolved):
2731
raise ValueError(f"Couldn't resolve dims: {unresolved}")
2832

29-
# create array
30-
# TDOD: support other fill values, configurable by field?
31-
a = np.full(
32-
shape, fill_value=field.default or FILL_DNODATA
33-
) # , dtype=field.dtype)
33+
if np.prod(shape) > SPARSE_THRESHOLD:
34+
a: dict[Tuple[Any, ...], Any] = dict()
35+
36+
def set_(arr, val, *ind):
37+
arr[tuple(ind)] = val
38+
39+
def final(arr):
40+
coords = np.array(list(map(list, zip(*arr.keys()))))
41+
return sparse.COO(
42+
coords,
43+
list(arr.values()),
44+
shape=shape,
45+
fill_value=field.default or FILL_DNODATA,
46+
)
47+
else:
48+
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
49+
50+
def set_(arr, val, *ind):
51+
arr[ind] = val
52+
53+
def final(arr):
54+
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
55+
return arr
3456

3557
def _get_nn(cellid):
3658
match len(cellid):
@@ -53,16 +75,19 @@ def _get_nn(cellid):
5375
kper = 0
5476
match len(shape):
5577
case 1:
56-
a[kper] = period
78+
set_(a, period, kper)
79+
# a[(kper,)] = period
5780
case _:
5881
for cellid, v in period.items():
5982
nn = _get_nn(cellid)
60-
a[kper, nn] = v
83+
set_(a, v, kper, nn)
84+
# a[(kper, nn)] = v
6185
if kper == "*":
6286
break
6387
else:
6488
for cellid, v in value.items():
6589
nn = _get_nn(cellid)
66-
a[nn] = v
90+
set_(a, v, nn)
91+
# a[(nn,)] = v
6792

68-
return a
93+
return final(a)

pixi.lock

Lines changed: 1591 additions & 1524 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,19 @@ classifiers = [
3434
]
3535
requires-python = ">=3.11"
3636
dependencies = [
37-
"attrs", # todo: bounds?
38-
"cattrs", # todo: bounds?
37+
"attrs", # todo: bounds?
38+
"cattrs", # todo: bounds?
3939
"flopy",
4040
"Jinja2>=3.0",
4141
"numpy>=1.20.3",
4242
"pandas>=2.0.0",
4343
"toml>=0.10",
4444
"networkx>=3.4.2,<4",
45-
"xarray[parallel,io]>=2024.11.0,<2025",
45+
"xarray[parallel,io]>=2024.11.0",
4646
"scipy>=1.14.1,<2",
4747
"modflow-devtools[dfn] @ git+https://github.com/MODFLOW-USGS/modflow-devtools.git",
4848
"xattree @ git+https://github.com/modflowpy/xattree.git",
49+
"sparse>=0.15.5,<1",
4950
]
5051
dynamic = ["version"]
5152

@@ -63,7 +64,7 @@ test = [
6364
"pytest!=8.1.0",
6465
"pytest-dotenv",
6566
"pytest-xdist",
66-
"pytest-benchmark"
67+
"pytest-benchmark",
6768
]
6869
build = ["build", "twine"]
6970

@@ -143,4 +144,4 @@ install = { cmd = "pre-commit install --install-hooks" }
143144
[tool.mypy]
144145
mypy_path = "flopy4"
145146
ignore_missing_imports = true
146-
warn_unreachable = true
147+
warn_unreachable = true

test/test_component.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,62 @@ def test_init_sim_explicit_dims():
140140
assert np.array_equal(sim.models["gwf"].npf.data.k, np.ones(100))
141141
assert chd.head[0, 0] == 1.0
142142
assert chd.head[0, 99] == 0.0
143-
assert np.array_equal(chd.head[0, 1:99], np.full((98,), FILL_DNODATA))
144-
assert np.array_equal(chd.head, chd.data.head)
145-
assert np.array_equal(chd.head, sim.models["gwf"].chd[0].data.head)
143+
assert np.array_equal(chd.head[0, 1:99].data, np.full((98,), FILL_DNODATA))
144+
assert np.array_equal(chd.head.data, chd.data.head.data)
145+
assert np.array_equal(
146+
chd.head.data,
147+
sim.models["gwf"].chd[0].data.head.data,
148+
)
149+
150+
151+
def test_init_big_sim():
152+
# if size over threshold, arrays should be sparse
153+
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
154+
grid = StructuredGrid(nlay=1, nrow=100, ncol=100)
155+
dims = {
156+
"nlay": grid.nlay,
157+
"nrow": grid.nrow,
158+
"ncol": grid.ncol,
159+
}
160+
dis = Dis(**dims)
161+
dims["nper"] = time.nper
162+
dims["nnodes"] = grid.nnodes
163+
ic = Ic(dims=dims)
164+
oc = Oc(dims=dims)
165+
npf = Npf(dims=dims)
166+
chd = Chd(dims=dims, head={"*": {(0, 0, 0): 1.0, (0, 99, 99): 0.0}})
167+
gwf = Gwf(
168+
dis=dis,
169+
ic=ic,
170+
oc=oc,
171+
npf=npf,
172+
chd=[chd],
173+
dims=dims,
174+
)
175+
tdis = Tdis(dims=dims)
176+
sim = Simulation(tdis=tdis, models={"gwf": gwf})
177+
178+
assert sim.tdis is tdis
179+
assert sim.models["gwf"] is gwf
180+
assert isinstance(sim.data, DataTree)
181+
assert sim.data.tdis is tdis.data
182+
assert sim.data.gwf is gwf.data
183+
assert gwf.dis is dis
184+
assert gwf.ic is ic
185+
assert gwf.oc is oc
186+
assert gwf.npf is npf
187+
assert gwf.chd[0] is chd
188+
assert np.array_equal(sim.models["gwf"].npf.k, np.ones(10000))
189+
assert np.array_equal(sim.models["gwf"].npf.data.k, np.ones(10000))
190+
assert chd.head[0, 0] == 1.0
191+
assert chd.head[0, 9999] == 0.0
192+
assert np.array_equal(
193+
chd.head[0, 1:9999].data.todense(), np.full((9998,), FILL_DNODATA)
194+
)
195+
assert np.array_equal(
196+
chd.head.data.todense(), chd.data.head.data.todense()
197+
)
198+
assert np.array_equal(
199+
chd.head.data.todense(),
200+
sim.models["gwf"].chd[0].data.head.data.todense(),
201+
)

0 commit comments

Comments
 (0)