Skip to content

Commit 4709fa1

Browse files
authored
move maxbound update logic to base component (#183)
1 parent e2359da commit 4709fa1

File tree

6 files changed

+89
-82
lines changed

6 files changed

+89
-82
lines changed

flopy4/mf6/attr_hooks.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

flopy4/mf6/component.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,67 @@
33
from pathlib import Path
44
from typing import ClassVar
55

6+
import numpy as np
7+
from attrs import fields
68
from modflow_devtools.dfn import Dfn, Field
79
from xattree import xattree
810

11+
from flopy4.mf6.constants import FILL_DNODATA
912
from flopy4.mf6.spec import field, fields_dict, to_dfn_field
1013
from flopy4.uio import IO, Loader, Writer
1114

15+
16+
def update_maxbound(instance, attribute, new_value):
17+
"""
18+
Generalized function to update maxbound when period block arrays change.
19+
20+
This function automatically finds all period block arrays in the instance
21+
and calculates maxbound based on the maximum number of non-default values
22+
across all arrays.
23+
24+
Args:
25+
instance: The package instance
26+
attribute: The attribute being set (from attrs on_setattr)
27+
new_value: The new value being set
28+
29+
Returns:
30+
The new_value (unchanged)
31+
"""
32+
33+
period_arrays = []
34+
instance_fields = fields(instance.__class__)
35+
for f in instance_fields:
36+
if (
37+
f.metadata
38+
and f.metadata.get("block") == "period"
39+
and f.metadata.get("xattree", {}).get("dims")
40+
):
41+
period_arrays.append(f.name)
42+
43+
maxbound_values = []
44+
for array_name in period_arrays:
45+
if attribute and attribute.name == array_name:
46+
array_val = new_value
47+
else:
48+
array_val = getattr(instance, array_name, None)
49+
50+
if array_val is not None:
51+
array_data = (
52+
array_val if array_val.data.shape == array_val.shape else array_val.todense()
53+
)
54+
55+
if array_data.dtype.kind in ["U", "S"]: # String arrays
56+
non_default_count = len(np.where(array_data != "")[0])
57+
else: # Numeric arrays
58+
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])
59+
60+
maxbound_values.append(non_default_count)
61+
if maxbound_values:
62+
instance.maxbound = max(maxbound_values)
63+
64+
return new_value
65+
66+
1267
COMPONENTS = {}
1368
"""MF6 component registry."""
1469

@@ -50,6 +105,36 @@ def default_filename(self) -> str:
50105
cls_name = self.__class__.__name__.lower()
51106
return f"{name}.{cls_name}"
52107

108+
def __attrs_post_init__(self):
109+
"""
110+
Post-initialization hook for all components.
111+
112+
Automatically handles common post-init tasks like computing maxbound
113+
for components with period block arrays.
114+
"""
115+
self._update_maxbound_if_needed()
116+
117+
def _update_maxbound_if_needed(self):
118+
"""
119+
Update maxbound if this component has period block arrays.
120+
121+
This method checks if the component has any period block arrays defined
122+
and calls update_maxbound if needed. This generalizes the pattern that
123+
was previously repeated in multiple component classes.
124+
"""
125+
# Check if component has a maxbound field and period block arrays
126+
component_fields = fields(self.__class__)
127+
has_maxbound = any(f.name == "maxbound" for f in component_fields)
128+
has_period_arrays = any(
129+
f.metadata
130+
and f.metadata.get("block") == "period"
131+
and f.metadata.get("xattree", {}).get("dims")
132+
for f in component_fields
133+
)
134+
135+
if has_maxbound and has_period_arrays:
136+
update_maxbound(self, None, None)
137+
53138
@classmethod
54139
def __attrs_init_subclass__(cls):
55140
COMPONENTS[cls.__name__.lower()] = cls

flopy4/mf6/gwf/chd.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.attr_hooks import update_maxbound
9+
from flopy4.mf6.component import update_maxbound
1010
from flopy4.mf6.converters import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
@@ -58,7 +58,3 @@ class Chd(Package):
5858
reader="urword",
5959
on_setattr=update_maxbound,
6060
)
61-
62-
def __attrs_post_init__(self):
63-
if self.head is not None or self.aux is not None or self.boundname is not None:
64-
update_maxbound(self, None, None)

flopy4/mf6/gwf/drn.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.attr_hooks import update_maxbound
9+
from flopy4.mf6.component import update_maxbound
1010
from flopy4.mf6.converters import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
@@ -65,12 +65,3 @@ class Drn(Package):
6565
reader="urword",
6666
on_setattr=update_maxbound,
6767
)
68-
69-
def __attrs_post_init__(self):
70-
if (
71-
self.elev is not None
72-
or self.cond is not None
73-
or self.aux is not None
74-
or self.boundname is not None
75-
):
76-
update_maxbound(self, None, None)

flopy4/mf6/gwf/rch.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.attr_hooks import update_maxbound
9+
from flopy4.mf6.component import update_maxbound
1010
from flopy4.mf6.converters import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
@@ -58,7 +58,3 @@ class Rch(Package):
5858
reader="urword",
5959
on_setattr=update_maxbound,
6060
)
61-
62-
def __attrs_post_init__(self):
63-
if self.recharge is not None or self.aux is not None or self.boundname is not None:
64-
update_maxbound(self, None, None)

flopy4/mf6/gwf/wel.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numpy.typing import NDArray
77
from xattree import xattree
88

9-
from flopy4.mf6.attr_hooks import update_maxbound
9+
from flopy4.mf6.component import update_maxbound
1010
from flopy4.mf6.converters import dict_to_array
1111
from flopy4.mf6.package import Package
1212
from flopy4.mf6.spec import array, field
@@ -60,7 +60,3 @@ class Wel(Package):
6060
reader="urword",
6161
on_setattr=update_maxbound,
6262
)
63-
64-
def __attrs_post_init__(self):
65-
if self.q is not None or self.aux is not None or self.boundname is not None:
66-
update_maxbound(self, None, None)

0 commit comments

Comments
 (0)