Skip to content

Commit 0a0f19a

Browse files
author
wpbonelli
committed
generalize maxbound sync
1 parent 7f93165 commit 0a0f19a

File tree

5 files changed

+74
-224
lines changed

5 files changed

+74
-224
lines changed

flopy4/mf6/attr_hooks.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Attribute hooks for attrs on_setattr callbacks."""
2+
3+
import numpy as np
4+
from attrs import fields
5+
6+
from flopy4.mf6.constants import FILL_DNODATA
7+
8+
9+
def update_maxbound(instance, attribute, new_value):
10+
"""
11+
Generalized function to update maxbound when period block arrays change.
12+
13+
This function automatically finds all period block arrays in the instance
14+
and calculates maxbound based on the maximum number of non-default values
15+
across all arrays.
16+
17+
Args:
18+
instance: The package instance
19+
attribute: The attribute being set (from attrs on_setattr)
20+
new_value: The new value being set
21+
22+
Returns:
23+
The new_value (unchanged)
24+
"""
25+
26+
period_arrays = []
27+
instance_fields = fields(instance.__class__)
28+
for field in instance_fields:
29+
if field.metadata and field.metadata.get("block") == "period" and "dims" in field.metadata:
30+
period_arrays.append(field.name)
31+
32+
maxbound_values = []
33+
for array_name in period_arrays:
34+
if attribute and attribute.name == array_name:
35+
array_val = new_value
36+
else:
37+
array_val = getattr(instance, array_name, None)
38+
39+
if array_val is not None:
40+
array_data = (
41+
array_val if array_val.data.shape == array_val.shape else array_val.todense()
42+
)
43+
44+
if array_data.dtype.kind in ["U", "S"]: # String arrays
45+
non_default_count = len(np.where(array_data != "")[0])
46+
else: # Numeric arrays
47+
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])
48+
49+
maxbound_values.append(non_default_count)
50+
if maxbound_values:
51+
instance.maxbound = max(maxbound_values)
52+
53+
return new_value

flopy4/mf6/gwf/chd.py

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,59 +6,12 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.constants import FILL_DNODATA
9+
from flopy4.mf6.attr_hooks import update_maxbound
1010
from flopy4.mf6.converters import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
1313

1414

15-
def _update_maxbound(instance, attribute, new_value):
16-
"""Update maxbound when period block arrays change."""
17-
if hasattr(instance, "_updating_maxbound"):
18-
return new_value
19-
20-
# Calculate maxbound from all relevant arrays
21-
maxbound_values = []
22-
23-
# Check head array
24-
head_val = (
25-
new_value if attribute and attribute.name == "head" else getattr(instance, "head", None)
26-
)
27-
if head_val is not None:
28-
head = head_val if head_val.data.shape == head_val.shape else head_val.todense()
29-
maxbound_values.append(len(np.where(head != FILL_DNODATA)[0]))
30-
31-
# Check aux array
32-
aux_val = new_value if attribute and attribute.name == "aux" else getattr(instance, "aux", None)
33-
if aux_val is not None:
34-
aux = aux_val if aux_val.data.shape == aux_val.shape else aux_val.todense()
35-
maxbound_values.append(len(np.where(aux != FILL_DNODATA)[0]))
36-
37-
# Check boundname array
38-
boundname_val = (
39-
new_value
40-
if attribute and attribute.name == "boundname"
41-
else getattr(instance, "boundname", None)
42-
)
43-
if boundname_val is not None:
44-
boundname = (
45-
boundname_val
46-
if boundname_val.data.shape == boundname_val.shape
47-
else boundname_val.todense()
48-
)
49-
maxbound_values.append(len(np.where(boundname != "")[0]))
50-
51-
# Update maxbound if we have values
52-
if maxbound_values:
53-
instance._updating_maxbound = True
54-
try:
55-
instance.maxbound = max(maxbound_values)
56-
finally:
57-
delattr(instance, "_updating_maxbound")
58-
59-
return new_value
60-
61-
6215
@xattree
6316
class Chd(Package):
6417
multi_package: ClassVar[bool] = True
@@ -81,7 +34,7 @@ class Chd(Package):
8134
default=None,
8235
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
8336
reader="urword",
84-
on_setattr=_update_maxbound,
37+
on_setattr=update_maxbound,
8538
)
8639
aux: Optional[NDArray[np.float64]] = array(
8740
block="period",
@@ -92,7 +45,7 @@ class Chd(Package):
9245
default=None,
9346
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
9447
reader="urword",
95-
on_setattr=_update_maxbound,
48+
on_setattr=update_maxbound,
9649
)
9750
boundname: Optional[NDArray[np.str_]] = array(
9851
block="period",
@@ -103,10 +56,9 @@ class Chd(Package):
10356
default=None,
10457
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
10558
reader="urword",
106-
on_setattr=_update_maxbound,
59+
on_setattr=update_maxbound,
10760
)
10861

10962
def __attrs_post_init__(self):
110-
# Trigger maxbound calculation on initialization
11163
if self.head is not None or self.aux is not None or self.boundname is not None:
112-
_update_maxbound(self, None, None)
64+
update_maxbound(self, None, None)

flopy4/mf6/gwf/drn.py

Lines changed: 6 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,67 +6,12 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.constants import FILL_DNODATA
9+
from flopy4.mf6.attr_hooks import update_maxbound
1010
from flopy4.mf6.converters import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
1313

1414

15-
def _update_maxbound(instance, attribute, new_value):
16-
"""Update maxbound when period block arrays change."""
17-
if hasattr(instance, "_updating_maxbound"):
18-
return new_value
19-
20-
# Calculate maxbound from all relevant arrays
21-
maxbound_values = []
22-
23-
# Check elev array
24-
elev_val = (
25-
new_value if attribute and attribute.name == "elev" else getattr(instance, "elev", None)
26-
)
27-
if elev_val is not None:
28-
elev = elev_val if elev_val.data.shape == elev_val.shape else elev_val.todense()
29-
maxbound_values.append(len(np.where(elev != FILL_DNODATA)[0]))
30-
31-
# Check cond array
32-
cond_val = (
33-
new_value if attribute and attribute.name == "cond" else getattr(instance, "cond", None)
34-
)
35-
if cond_val is not None:
36-
cond = cond_val if cond_val.data.shape == cond_val.shape else cond_val.todense()
37-
maxbound_values.append(len(np.where(cond != FILL_DNODATA)[0]))
38-
39-
# Check aux array
40-
aux_val = new_value if attribute and attribute.name == "aux" else getattr(instance, "aux", None)
41-
if aux_val is not None:
42-
aux = aux_val if aux_val.data.shape == aux_val.shape else aux_val.todense()
43-
maxbound_values.append(len(np.where(aux != FILL_DNODATA)[0]))
44-
45-
# Check boundname array
46-
boundname_val = (
47-
new_value
48-
if attribute and attribute.name == "boundname"
49-
else getattr(instance, "boundname", None)
50-
)
51-
if boundname_val is not None:
52-
boundname = (
53-
boundname_val
54-
if boundname_val.data.shape == boundname_val.shape
55-
else boundname_val.todense()
56-
)
57-
maxbound_values.append(len(np.where(boundname != "")[0]))
58-
59-
# Update maxbound if we have values
60-
if maxbound_values:
61-
instance._updating_maxbound = True
62-
try:
63-
instance.maxbound = max(maxbound_values)
64-
finally:
65-
delattr(instance, "_updating_maxbound")
66-
67-
return new_value
68-
69-
7015
@xattree
7116
class Drn(Package):
7217
multi_package: ClassVar[bool] = True
@@ -88,15 +33,15 @@ class Drn(Package):
8833
default=None,
8934
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
9035
reader="urword",
91-
on_setattr=_update_maxbound,
36+
on_setattr=update_maxbound,
9237
)
9338
cond: Optional[NDArray[np.float64]] = array(
9439
block="period",
9540
dims=("nper", "nnodes"),
9641
default=None,
9742
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
9843
reader="urword",
99-
on_setattr=_update_maxbound,
44+
on_setattr=update_maxbound,
10045
)
10146
aux: Optional[NDArray[np.float64]] = array(
10247
block="period",
@@ -107,7 +52,7 @@ class Drn(Package):
10752
default=None,
10853
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
10954
reader="urword",
110-
on_setattr=_update_maxbound,
55+
on_setattr=update_maxbound,
11156
)
11257
boundname: Optional[NDArray[np.str_]] = array(
11358
block="period",
@@ -118,15 +63,14 @@ class Drn(Package):
11863
default=None,
11964
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
12065
reader="urword",
121-
on_setattr=_update_maxbound,
66+
on_setattr=update_maxbound,
12267
)
12368

12469
def __attrs_post_init__(self):
125-
# Trigger maxbound calculation on initialization
12670
if (
12771
self.elev is not None
12872
or self.cond is not None
12973
or self.aux is not None
13074
or self.boundname is not None
13175
):
132-
_update_maxbound(self, None, None)
76+
update_maxbound(self, None, None)

flopy4/mf6/gwf/rch.py

Lines changed: 5 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,65 +6,12 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.constants import FILL_DNODATA
9+
from flopy4.mf6.attr_hooks import update_maxbound
1010
from flopy4.mf6.converters import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
1313

1414

15-
def _update_maxbound(instance, attribute, new_value):
16-
"""Update maxbound when period block arrays change."""
17-
if hasattr(instance, "_updating_maxbound"):
18-
return new_value
19-
20-
# Calculate maxbound from all relevant arrays
21-
maxbound_values = []
22-
23-
# Check recharge array
24-
recharge_val = (
25-
new_value
26-
if attribute and attribute.name == "recharge"
27-
else getattr(instance, "recharge", None)
28-
)
29-
if recharge_val is not None:
30-
recharge = (
31-
recharge_val
32-
if recharge_val.data.shape == recharge_val.shape
33-
else recharge_val.todense()
34-
)
35-
maxbound_values.append(len(np.where(recharge != FILL_DNODATA)[0]))
36-
37-
# Check aux array
38-
aux_val = new_value if attribute and attribute.name == "aux" else getattr(instance, "aux", None)
39-
if aux_val is not None:
40-
aux = aux_val if aux_val.data.shape == aux_val.shape else aux_val.todense()
41-
maxbound_values.append(len(np.where(aux != FILL_DNODATA)[0]))
42-
43-
# Check boundname array
44-
boundname_val = (
45-
new_value
46-
if attribute and attribute.name == "boundname"
47-
else getattr(instance, "boundname", None)
48-
)
49-
if boundname_val is not None:
50-
boundname = (
51-
boundname_val
52-
if boundname_val.data.shape == boundname_val.shape
53-
else boundname_val.todense()
54-
)
55-
maxbound_values.append(len(np.where(boundname != "")[0]))
56-
57-
# Update maxbound if we have values
58-
if maxbound_values:
59-
instance._updating_maxbound = True
60-
try:
61-
instance.maxbound = max(maxbound_values)
62-
finally:
63-
delattr(instance, "_updating_maxbound")
64-
65-
return new_value
66-
67-
6815
@xattree
6916
class Rch(Package):
7017
multi_package: ClassVar[bool] = True
@@ -87,7 +34,7 @@ class Rch(Package):
8734
default=None,
8835
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
8936
reader="urword",
90-
on_setattr=_update_maxbound,
37+
on_setattr=update_maxbound,
9138
)
9239
aux: Optional[NDArray[np.float64]] = array(
9340
block="period",
@@ -98,7 +45,7 @@ class Rch(Package):
9845
default=None,
9946
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
10047
reader="urword",
101-
on_setattr=_update_maxbound,
48+
on_setattr=update_maxbound,
10249
)
10350
boundname: Optional[NDArray[np.str_]] = array(
10451
block="period",
@@ -109,10 +56,9 @@ class Rch(Package):
10956
default=None,
11057
converter=Converter(dict_to_array, takes_self=True, takes_field=True),
11158
reader="urword",
112-
on_setattr=_update_maxbound,
59+
on_setattr=update_maxbound,
11360
)
11461

11562
def __attrs_post_init__(self):
116-
# Trigger maxbound calculation on initialization
11763
if self.recharge is not None or self.aux is not None or self.boundname is not None:
118-
_update_maxbound(self, None, None)
64+
update_maxbound(self, None, None)

0 commit comments

Comments
 (0)