Skip to content

Commit af12022

Browse files
authored
factor out get_nn utility function (#162)
1 parent de75439 commit af12022

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

flopy4/adapters.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,18 @@ def get_cellid(nn: int, grid: Grid) -> tuple[int, ...]:
142142
return (nn,)
143143
case _:
144144
raise TypeError(f"Unsupported grid type: {type(grid)}")
145+
146+
147+
def get_nn(cellid, **kwargs):
148+
ndim = len(cellid)
149+
match ndim:
150+
case 1:
151+
return cellid[0]
152+
case 2:
153+
k, j = cellid
154+
return k * kwargs["ncpl"] + j
155+
case 3:
156+
k, i, j = cellid
157+
return k * kwargs["nrow"] * kwargs["ncol"] + i * kwargs["ncol"] + j
158+
case _:
159+
raise ValueError(f"Invalid cellid: {cellid}")

flopy4/mf6/codec/converter.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from xarray import DataArray
88
from xattree import get_xatspec
99

10-
from flopy4.adapters import get_cellid
10+
from flopy4.adapters import get_cellid, get_nn
1111
from flopy4.mf6.component import Component
1212
from flopy4.mf6.config import SPARSE_THRESHOLD
1313
from flopy4.mf6.constants import FILL_DNODATA
@@ -66,19 +66,6 @@ def final(arr):
6666
arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA
6767
return arr
6868

69-
def _get_nn(cellid):
70-
match len(cellid):
71-
case 1:
72-
return cellid[0]
73-
case 2:
74-
k, j = cellid
75-
return k * dims["ncpl"] + j
76-
case 3:
77-
k, i, j = cellid
78-
return k * dims["nrow"] * dims["ncol"] + i * dims["ncol"] + j
79-
case _:
80-
raise ValueError(f"Invalid cellid: {cellid}")
81-
8269
# populate array. TODO: is there a way to do this
8370
# without hardcoding awareness of kper and cellid?
8471
if "nper" in dims:
@@ -90,13 +77,13 @@ def _get_nn(cellid):
9077
set_(a, period, kper)
9178
case _:
9279
for cellid, v in period.items():
93-
nn = _get_nn(cellid)
80+
nn = get_nn(cellid, **dims)
9481
set_(a, v, kper, nn)
9582
if kper == "*":
9683
break
9784
else:
9885
for cellid, v in value.items():
99-
nn = _get_nn(cellid)
86+
nn = get_nn(cellid, **dims)
10087
set_(a, v, nn)
10188
return final(a)
10289

0 commit comments

Comments
 (0)