|
3 | 3 | from pathlib import Path |
4 | 4 | from typing import ClassVar |
5 | 5 |
|
| 6 | +import numpy as np |
| 7 | +from attrs import fields |
6 | 8 | from modflow_devtools.dfn import Dfn, Field |
7 | 9 | from xattree import xattree |
8 | 10 |
|
| 11 | +from flopy4.mf6.constants import FILL_DNODATA |
9 | 12 | from flopy4.mf6.spec import field, fields_dict, to_dfn_field |
10 | 13 | from flopy4.uio import IO, Loader, Writer |
11 | 14 |
|
| 15 | + |
| 16 | +def update_maxbound(instance, attribute, new_value): |
| 17 | + """ |
| 18 | + Generalized function to update maxbound when period block arrays change. |
| 19 | +
|
| 20 | + This function automatically finds all period block arrays in the instance |
| 21 | + and calculates maxbound based on the maximum number of non-default values |
| 22 | + across all arrays. |
| 23 | +
|
| 24 | + Args: |
| 25 | + instance: The package instance |
| 26 | + attribute: The attribute being set (from attrs on_setattr) |
| 27 | + new_value: The new value being set |
| 28 | +
|
| 29 | + Returns: |
| 30 | + The new_value (unchanged) |
| 31 | + """ |
| 32 | + |
| 33 | + period_arrays = [] |
| 34 | + instance_fields = fields(instance.__class__) |
| 35 | + for f in instance_fields: |
| 36 | + if ( |
| 37 | + f.metadata |
| 38 | + and f.metadata.get("block") == "period" |
| 39 | + and f.metadata.get("xattree", {}).get("dims") |
| 40 | + ): |
| 41 | + period_arrays.append(f.name) |
| 42 | + |
| 43 | + maxbound_values = [] |
| 44 | + for array_name in period_arrays: |
| 45 | + if attribute and attribute.name == array_name: |
| 46 | + array_val = new_value |
| 47 | + else: |
| 48 | + array_val = getattr(instance, array_name, None) |
| 49 | + |
| 50 | + if array_val is not None: |
| 51 | + array_data = ( |
| 52 | + array_val if array_val.data.shape == array_val.shape else array_val.todense() |
| 53 | + ) |
| 54 | + |
| 55 | + if array_data.dtype.kind in ["U", "S"]: # String arrays |
| 56 | + non_default_count = len(np.where(array_data != "")[0]) |
| 57 | + else: # Numeric arrays |
| 58 | + non_default_count = len(np.where(array_data != FILL_DNODATA)[0]) |
| 59 | + |
| 60 | + maxbound_values.append(non_default_count) |
| 61 | + if maxbound_values: |
| 62 | + instance.maxbound = max(maxbound_values) |
| 63 | + |
| 64 | + return new_value |
| 65 | + |
| 66 | + |
12 | 67 | COMPONENTS = {} |
13 | 68 | """MF6 component registry.""" |
14 | 69 |
|
@@ -50,6 +105,36 @@ def default_filename(self) -> str: |
50 | 105 | cls_name = self.__class__.__name__.lower() |
51 | 106 | return f"{name}.{cls_name}" |
52 | 107 |
|
| 108 | + def __attrs_post_init__(self): |
| 109 | + """ |
| 110 | + Post-initialization hook for all components. |
| 111 | +
|
| 112 | + Automatically handles common post-init tasks like computing maxbound |
| 113 | + for components with period block arrays. |
| 114 | + """ |
| 115 | + self._update_maxbound_if_needed() |
| 116 | + |
| 117 | + def _update_maxbound_if_needed(self): |
| 118 | + """ |
| 119 | + Update maxbound if this component has period block arrays. |
| 120 | +
|
| 121 | + This method checks if the component has any period block arrays defined |
| 122 | + and calls update_maxbound if needed. This generalizes the pattern that |
| 123 | + was previously repeated in multiple component classes. |
| 124 | + """ |
| 125 | + # Check if component has a maxbound field and period block arrays |
| 126 | + component_fields = fields(self.__class__) |
| 127 | + has_maxbound = any(f.name == "maxbound" for f in component_fields) |
| 128 | + has_period_arrays = any( |
| 129 | + f.metadata |
| 130 | + and f.metadata.get("block") == "period" |
| 131 | + and f.metadata.get("xattree", {}).get("dims") |
| 132 | + for f in component_fields |
| 133 | + ) |
| 134 | + |
| 135 | + if has_maxbound and has_period_arrays: |
| 136 | + update_maxbound(self, None, None) |
| 137 | + |
53 | 138 | @classmethod |
54 | 139 | def __attrs_init_subclass__(cls): |
55 | 140 | COMPONENTS[cls.__name__.lower()] = cls |
|
0 commit comments