Skip to content

Commit e2bdbb0

Browse files
committed
move update_maxbound to utils module
1 parent 6278a42 commit e2bdbb0

File tree

6 files changed

+60
-58
lines changed

6 files changed

+60
-58
lines changed

flopy4/mf6/component.py

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,69 +3,17 @@
33
from pathlib import Path
44
from typing import Any, ClassVar
55

6-
import numpy as np
76
from attrs import fields
87
from modflow_devtools.dfn import Dfn, Field
98
from packaging.version import Version
109
from xattree import asdict as xattree_asdict
1110
from xattree import xattree
1211

13-
from flopy4.mf6.constants import FILL_DNODATA, MF6
12+
from flopy4.mf6.constants import MF6
1413
from flopy4.mf6.spec import field, fields_dict, to_field
14+
from flopy4.mf6.utils.grid_utils import update_maxbound
1515
from flopy4.uio import IO, Loader, Writer
1616

17-
18-
def update_maxbound(instance, attribute, new_value):
19-
"""
20-
Generalized function to update maxbound when period block arrays change.
21-
22-
This function automatically finds all period block arrays in the instance
23-
and calculates maxbound based on the maximum number of non-default values
24-
across all arrays.
25-
26-
Args:
27-
instance: The package instance
28-
attribute: The attribute being set (from attrs on_setattr)
29-
new_value: The new value being set
30-
31-
Returns:
32-
The new_value (unchanged)
33-
"""
34-
35-
period_arrays = []
36-
instance_fields = fields(instance.__class__)
37-
for f in instance_fields:
38-
if (
39-
f.metadata
40-
and f.metadata.get("block") == "period"
41-
and f.metadata.get("xattree", {}).get("dims")
42-
):
43-
period_arrays.append(f.name)
44-
45-
maxbound_values = []
46-
for array_name in period_arrays:
47-
if attribute and attribute.name == array_name:
48-
array_val = new_value
49-
else:
50-
array_val = getattr(instance, array_name, None)
51-
52-
if array_val is not None:
53-
array_data = (
54-
array_val if array_val.data.shape == array_val.shape else array_val.todense()
55-
)
56-
57-
if array_data.dtype.kind in ["U", "S"]: # String arrays
58-
non_default_count = len(np.where(array_data != "")[0])
59-
else: # Numeric arrays
60-
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])
61-
62-
maxbound_values.append(non_default_count)
63-
if maxbound_values:
64-
instance.maxbound = max(maxbound_values)
65-
66-
return new_value
67-
68-
6917
COMPONENTS = {}
7018
"""MF6 component registry."""
7119

flopy4/mf6/gwf/chd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.component import update_maxbound
109
from flopy4.mf6.constants import LENBOUNDNAME
1110
from flopy4.mf6.converter import dict_to_array
1211
from flopy4.mf6.package import Package
1312
from flopy4.mf6.spec import array, field, path
13+
from flopy4.mf6.utils.grid_utils import update_maxbound
1414
from flopy4.utils import to_path
1515

1616

flopy4/mf6/gwf/drn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.component import update_maxbound
109
from flopy4.mf6.constants import LENBOUNDNAME
1110
from flopy4.mf6.converter import dict_to_array
1211
from flopy4.mf6.package import Package
1312
from flopy4.mf6.spec import array, field, path
13+
from flopy4.mf6.utils.grid_utils import update_maxbound
1414
from flopy4.utils import to_path
1515

1616

flopy4/mf6/gwf/rch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.component import update_maxbound
109
from flopy4.mf6.constants import LENBOUNDNAME
1110
from flopy4.mf6.converter import dict_to_array
1211
from flopy4.mf6.package import Package
1312
from flopy4.mf6.spec import array, field, path
13+
from flopy4.mf6.utils.grid_utils import update_maxbound
1414
from flopy4.utils import to_path
1515

1616

flopy4/mf6/gwf/wel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.component import update_maxbound
109
from flopy4.mf6.constants import LENBOUNDNAME
1110
from flopy4.mf6.converter import dict_to_array
1211
from flopy4.mf6.package import Package
1312
from flopy4.mf6.spec import array, field, path
13+
from flopy4.mf6.utils.grid_utils import update_maxbound
1414
from flopy4.utils import to_path
1515

1616

flopy4/mf6/utils/grid_utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from typing import Any
33

44
import numpy as np
5+
from attrs import fields
56
from flopy.discretization import StructuredGrid
67

8+
from flopy4.mf6.constants import FILL_DNODATA
9+
710

811
def get_coords(grid: StructuredGrid) -> dict[str, Any]:
912
# unpack tuples
@@ -31,3 +34,54 @@ def get_coords(grid: StructuredGrid) -> dict[str, Any]:
3134
coords["dy"] = ("y", dy)
3235
coords["layer"] = np.arange(1, grid.nlay + 1)
3336
return coords
37+
38+
39+
def update_maxbound(instance, attribute, new_value):
40+
"""
41+
Generalized function to update maxbound when period block arrays change.
42+
43+
This function automatically finds all period block arrays in the instance
44+
and calculates maxbound based on the maximum number of non-default values
45+
across all arrays.
46+
47+
Args:
48+
instance: The package instance
49+
attribute: The attribute being set (from attrs on_setattr)
50+
new_value: The new value being set
51+
52+
Returns:
53+
The new_value (unchanged)
54+
"""
55+
56+
period_arrays = []
57+
instance_fields = fields(instance.__class__)
58+
for f in instance_fields:
59+
if (
60+
f.metadata
61+
and f.metadata.get("block") == "period"
62+
and f.metadata.get("xattree", {}).get("dims")
63+
):
64+
period_arrays.append(f.name)
65+
66+
maxbound_values = []
67+
for array_name in period_arrays:
68+
if attribute and attribute.name == array_name:
69+
array_val = new_value
70+
else:
71+
array_val = getattr(instance, array_name, None)
72+
73+
if array_val is not None:
74+
array_data = (
75+
array_val if array_val.data.shape == array_val.shape else array_val.todense()
76+
)
77+
78+
if array_data.dtype.kind in ["U", "S"]: # String arrays
79+
non_default_count = len(np.where(array_data != "")[0])
80+
else: # Numeric arrays
81+
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])
82+
83+
maxbound_values.append(non_default_count)
84+
if maxbound_values:
85+
instance.maxbound = max(maxbound_values)
86+
87+
return new_value

0 commit comments

Comments
 (0)