Skip to content

Commit d0b4aa4

Browse files
committed
add threshold, separate test
1 parent f23139b commit d0b4aa4

File tree

5 files changed

+2144
-2112
lines changed

5 files changed

+2144
-2112
lines changed

flopy4/mf6/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# TODO use https://environ-config.readthedocs.io/en/stable/?
2+
3+
SPARSE_THRESHOLD = 1000

flopy4/mf6/converters.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.typing import NDArray
66
from xattree import _get_xatspec
77

8+
from flopy4.mf6.config import SPARSE_THRESHOLD
89
from flopy4.mf6.constants import FILL_DNODATA
910

1011

@@ -29,7 +30,29 @@ def convert_array(value, self_, field) -> NDArray:
2930
if any(unresolved):
3031
raise ValueError(f"Couldn't resolve dims: {unresolved}")
3132

32-
a: dict[Tuple[Any, ...], Any] = dict()
33+
if np.prod(shape) > SPARSE_THRESHOLD:
34+
a: dict[Tuple[Any, ...], Any] = dict()
35+
36+
def set_(arr, val, *ind):
37+
arr[tuple(ind)] = val
38+
39+
def final(arr):
40+
coords = np.array(list(map(list, zip(*arr.keys()))))
41+
return sparse.COO(
42+
coords,
43+
list(arr.values()),
44+
shape=shape,
45+
fill_value=field.default or FILL_DNODATA,
46+
)
47+
else:
48+
a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore
49+
50+
def set_(arr, val, *ind):
51+
arr[ind] = val
52+
53+
def final(arr):
54+
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
55+
return arr
3356

3457
def _get_nn(cellid):
3558
match len(cellid):
@@ -52,22 +75,19 @@ def _get_nn(cellid):
5275
kper = 0
5376
match len(shape):
5477
case 1:
55-
a[(kper,)] = period
78+
set_(a, period, kper)
79+
# a[(kper,)] = period
5680
case _:
5781
for cellid, v in period.items():
5882
nn = _get_nn(cellid)
59-
a[(kper, nn)] = v
83+
set_(a, v, kper, nn)
84+
# a[(kper, nn)] = v
6085
if kper == "*":
6186
break
6287
else:
6388
for cellid, v in value.items():
6489
nn = _get_nn(cellid)
65-
a[(nn,)] = v
90+
set_(a, v, nn)
91+
# a[(nn,)] = v
6692

67-
coords = np.array(list(map(list, zip(*a.keys()))))
68-
return sparse.COO(
69-
coords,
70-
list(a.values()),
71-
shape=shape,
72-
fill_value=field.default or FILL_DNODATA,
73-
)
93+
return final(a)

0 commit comments

Comments
 (0)