Skip to content

Commit 7f93165

Browse files
author
wpbonelli
committed
mypy fixes
1 parent 1db60ef commit 7f93165

File tree

8 files changed

+1199
-802
lines changed

8 files changed

+1199
-802
lines changed

flopy4/mf6/converter.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
from pathlib import Path
33
from typing import Any
44

5-
import numpy as np
6-
import xarray as xr
75
import xattree
86
from cattrs import Converter
97

108
from flopy4.mf6.component import Component
11-
from flopy4.mf6.constants import FILL_DNODATA
129
from flopy4.mf6.spec import get_blocks
1310

1411

@@ -26,7 +23,7 @@ def _transform_path_to_record(field_name: str, path_value: Path) -> tuple:
2623
def unstructure_component(value: Component) -> dict[str, Any]:
2724
data = xattree.asdict(value)
2825
blockspec = get_blocks(value.dfn)
29-
blocks = {}
26+
blocks: dict[str, dict[str, Any]] = {}
3027
for block_name, block in blockspec.items():
3128
blocks[block_name] = {}
3229
for field_name in block.keys():

flopy4/mf6/filters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def keystring2list_multifield(field_arrays: dict, period_idx: int):
241241
return
242242

243243
# get period slice
244-
period_slices = {}
244+
period_slices: dict[str, Any] = {}
245245
for field_name, field_array in field_arrays.items():
246246
if isinstance(field_array, xr.DataArray):
247247
period_data = field_array.isel(nper=period_idx)
@@ -250,7 +250,7 @@ def keystring2list_multifield(field_arrays: dict, period_idx: int):
250250
period_slices[field_name] = field_array[period_idx]
251251

252252
# Find all locations where at least one field has meaningful data
253-
combined_mask = None
253+
combined_mask: Any = None
254254
for field_name, period_data in period_slices.items():
255255
meaningful_mask = (
256256
(period_data != 0) & (period_data != FILL_DNODATA) & ~np.isnan(period_data)

flopy4/mf6/gwf/chd.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import ClassVar, Optional
33

44
import numpy as np
5-
from attrs import Converter, setters
5+
from attrs import Converter
66
from numpy.typing import NDArray
77
from xattree import xattree
88

@@ -14,38 +14,48 @@
1414

1515
def _update_maxbound(instance, attribute, new_value):
1616
"""Update maxbound when period block arrays change."""
17-
if hasattr(instance, '_updating_maxbound'):
17+
if hasattr(instance, "_updating_maxbound"):
1818
return new_value
19-
19+
2020
# Calculate maxbound from all relevant arrays
2121
maxbound_values = []
22-
22+
2323
# Check head array
24-
head_val = new_value if attribute and attribute.name == 'head' else getattr(instance, 'head', None)
24+
head_val = (
25+
new_value if attribute and attribute.name == "head" else getattr(instance, "head", None)
26+
)
2527
if head_val is not None:
2628
head = head_val if head_val.data.shape == head_val.shape else head_val.todense()
2729
maxbound_values.append(len(np.where(head != FILL_DNODATA)[0]))
28-
29-
# Check aux array
30-
aux_val = new_value if attribute and attribute.name == 'aux' else getattr(instance, 'aux', None)
30+
31+
# Check aux array
32+
aux_val = new_value if attribute and attribute.name == "aux" else getattr(instance, "aux", None)
3133
if aux_val is not None:
3234
aux = aux_val if aux_val.data.shape == aux_val.shape else aux_val.todense()
3335
maxbound_values.append(len(np.where(aux != FILL_DNODATA)[0]))
34-
36+
3537
# Check boundname array
36-
boundname_val = new_value if attribute and attribute.name == 'boundname' else getattr(instance, 'boundname', None)
38+
boundname_val = (
39+
new_value
40+
if attribute and attribute.name == "boundname"
41+
else getattr(instance, "boundname", None)
42+
)
3743
if boundname_val is not None:
38-
boundname = boundname_val if boundname_val.data.shape == boundname_val.shape else boundname_val.todense()
44+
boundname = (
45+
boundname_val
46+
if boundname_val.data.shape == boundname_val.shape
47+
else boundname_val.todense()
48+
)
3949
maxbound_values.append(len(np.where(boundname != "")[0]))
40-
50+
4151
# Update maxbound if we have values
4252
if maxbound_values:
4353
instance._updating_maxbound = True
4454
try:
4555
instance.maxbound = max(maxbound_values)
4656
finally:
47-
delattr(instance, '_updating_maxbound')
48-
57+
delattr(instance, "_updating_maxbound")
58+
4959
return new_value
5060

5161

flopy4/mf6/gwf/drn.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import ClassVar, Optional
33

44
import numpy as np
5-
from attrs import Converter, setters
5+
from attrs import Converter
66
from numpy.typing import NDArray
77
from xattree import xattree
88

@@ -14,44 +14,56 @@
1414

1515
def _update_maxbound(instance, attribute, new_value):
1616
"""Update maxbound when period block arrays change."""
17-
if hasattr(instance, '_updating_maxbound'):
17+
if hasattr(instance, "_updating_maxbound"):
1818
return new_value
19-
19+
2020
# Calculate maxbound from all relevant arrays
2121
maxbound_values = []
22-
22+
2323
# Check elev array
24-
elev_val = new_value if attribute and attribute.name == 'elev' else getattr(instance, 'elev', None)
24+
elev_val = (
25+
new_value if attribute and attribute.name == "elev" else getattr(instance, "elev", None)
26+
)
2527
if elev_val is not None:
2628
elev = elev_val if elev_val.data.shape == elev_val.shape else elev_val.todense()
2729
maxbound_values.append(len(np.where(elev != FILL_DNODATA)[0]))
28-
30+
2931
# Check cond array
30-
cond_val = new_value if attribute and attribute.name == 'cond' else getattr(instance, 'cond', None)
32+
cond_val = (
33+
new_value if attribute and attribute.name == "cond" else getattr(instance, "cond", None)
34+
)
3135
if cond_val is not None:
3236
cond = cond_val if cond_val.data.shape == cond_val.shape else cond_val.todense()
3337
maxbound_values.append(len(np.where(cond != FILL_DNODATA)[0]))
34-
35-
# Check aux array
36-
aux_val = new_value if attribute and attribute.name == 'aux' else getattr(instance, 'aux', None)
38+
39+
# Check aux array
40+
aux_val = new_value if attribute and attribute.name == "aux" else getattr(instance, "aux", None)
3741
if aux_val is not None:
3842
aux = aux_val if aux_val.data.shape == aux_val.shape else aux_val.todense()
3943
maxbound_values.append(len(np.where(aux != FILL_DNODATA)[0]))
40-
44+
4145
# Check boundname array
42-
boundname_val = new_value if attribute and attribute.name == 'boundname' else getattr(instance, 'boundname', None)
46+
boundname_val = (
47+
new_value
48+
if attribute and attribute.name == "boundname"
49+
else getattr(instance, "boundname", None)
50+
)
4351
if boundname_val is not None:
44-
boundname = boundname_val if boundname_val.data.shape == boundname_val.shape else boundname_val.todense()
52+
boundname = (
53+
boundname_val
54+
if boundname_val.data.shape == boundname_val.shape
55+
else boundname_val.todense()
56+
)
4557
maxbound_values.append(len(np.where(boundname != "")[0]))
46-
58+
4759
# Update maxbound if we have values
4860
if maxbound_values:
4961
instance._updating_maxbound = True
5062
try:
5163
instance.maxbound = max(maxbound_values)
5264
finally:
53-
delattr(instance, '_updating_maxbound')
54-
65+
delattr(instance, "_updating_maxbound")
66+
5567
return new_value
5668

5769

@@ -111,5 +123,10 @@ class Drn(Package):
111123

112124
def __attrs_post_init__(self):
113125
# Trigger maxbound calculation on initialization
114-
if self.elev is not None or self.cond is not None or self.aux is not None or self.boundname is not None:
126+
if (
127+
self.elev is not None
128+
or self.cond is not None
129+
or self.aux is not None
130+
or self.boundname is not None
131+
):
115132
_update_maxbound(self, None, None)

flopy4/mf6/gwf/rch.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import ClassVar, Optional
33

44
import numpy as np
5-
from attrs import Converter, setters
5+
from attrs import Converter
66
from numpy.typing import NDArray
77
from xattree import xattree
88

@@ -14,38 +14,54 @@
1414

1515
def _update_maxbound(instance, attribute, new_value):
1616
"""Update maxbound when period block arrays change."""
17-
if hasattr(instance, '_updating_maxbound'):
17+
if hasattr(instance, "_updating_maxbound"):
1818
return new_value
19-
19+
2020
# Calculate maxbound from all relevant arrays
2121
maxbound_values = []
22-
22+
2323
# Check recharge array
24-
recharge_val = new_value if attribute and attribute.name == 'recharge' else getattr(instance, 'recharge', None)
24+
recharge_val = (
25+
new_value
26+
if attribute and attribute.name == "recharge"
27+
else getattr(instance, "recharge", None)
28+
)
2529
if recharge_val is not None:
26-
recharge = recharge_val if recharge_val.data.shape == recharge_val.shape else recharge_val.todense()
30+
recharge = (
31+
recharge_val
32+
if recharge_val.data.shape == recharge_val.shape
33+
else recharge_val.todense()
34+
)
2735
maxbound_values.append(len(np.where(recharge != FILL_DNODATA)[0]))
28-
29-
# Check aux array
30-
aux_val = new_value if attribute and attribute.name == 'aux' else getattr(instance, 'aux', None)
36+
37+
# Check aux array
38+
aux_val = new_value if attribute and attribute.name == "aux" else getattr(instance, "aux", None)
3139
if aux_val is not None:
3240
aux = aux_val if aux_val.data.shape == aux_val.shape else aux_val.todense()
3341
maxbound_values.append(len(np.where(aux != FILL_DNODATA)[0]))
34-
42+
3543
# Check boundname array
36-
boundname_val = new_value if attribute and attribute.name == 'boundname' else getattr(instance, 'boundname', None)
44+
boundname_val = (
45+
new_value
46+
if attribute and attribute.name == "boundname"
47+
else getattr(instance, "boundname", None)
48+
)
3749
if boundname_val is not None:
38-
boundname = boundname_val if boundname_val.data.shape == boundname_val.shape else boundname_val.todense()
50+
boundname = (
51+
boundname_val
52+
if boundname_val.data.shape == boundname_val.shape
53+
else boundname_val.todense()
54+
)
3955
maxbound_values.append(len(np.where(boundname != "")[0]))
40-
56+
4157
# Update maxbound if we have values
4258
if maxbound_values:
4359
instance._updating_maxbound = True
4460
try:
4561
instance.maxbound = max(maxbound_values)
4662
finally:
47-
delattr(instance, '_updating_maxbound')
48-
63+
delattr(instance, "_updating_maxbound")
64+
4965
return new_value
5066

5167

flopy4/mf6/gwf/wel.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,38 +14,46 @@
1414

1515
def _update_maxbound(instance, attribute, new_value):
1616
"""Update maxbound when period block arrays change."""
17-
if hasattr(instance, '_updating_maxbound'):
17+
if hasattr(instance, "_updating_maxbound"):
1818
return new_value
19-
19+
2020
# Calculate maxbound from all relevant arrays
2121
maxbound_values = []
22-
22+
2323
# Check q array
24-
q_val = new_value if attribute and attribute.name == 'q' else getattr(instance, 'q', None)
24+
q_val = new_value if attribute and attribute.name == "q" else getattr(instance, "q", None)
2525
if q_val is not None:
2626
q = q_val if q_val.data.shape == q_val.shape else q_val.todense()
2727
maxbound_values.append(len(np.where(q != FILL_DNODATA)[0]))
28-
29-
# Check aux array
30-
aux_val = new_value if attribute and attribute.name == 'aux' else getattr(instance, 'aux', None)
28+
29+
# Check aux array
30+
aux_val = new_value if attribute and attribute.name == "aux" else getattr(instance, "aux", None)
3131
if aux_val is not None:
3232
aux = aux_val if aux_val.data.shape == aux_val.shape else aux_val.todense()
3333
maxbound_values.append(len(np.where(aux != FILL_DNODATA)[0]))
34-
34+
3535
# Check boundname array
36-
boundname_val = new_value if attribute and attribute.name == 'boundname' else getattr(instance, 'boundname', None)
36+
boundname_val = (
37+
new_value
38+
if attribute and attribute.name == "boundname"
39+
else getattr(instance, "boundname", None)
40+
)
3741
if boundname_val is not None:
38-
boundname = boundname_val if boundname_val.data.shape == boundname_val.shape else boundname_val.todense()
42+
boundname = (
43+
boundname_val
44+
if boundname_val.data.shape == boundname_val.shape
45+
else boundname_val.todense()
46+
)
3947
maxbound_values.append(len(np.where(boundname != "")[0]))
40-
48+
4149
# Update maxbound if we have values
4250
if maxbound_values:
4351
instance._updating_maxbound = True
4452
try:
4553
instance.maxbound = max(maxbound_values)
4654
finally:
47-
delattr(instance, '_updating_maxbound')
48-
55+
delattr(instance, "_updating_maxbound")
56+
4957
return new_value
5058

5159

flopy4/mf6/spec.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,11 @@ def is_list_field(field: Field) -> bool:
290290
def is_list_block(block: Block) -> bool:
291291
return (
292292
len(block) == 1
293-
and (field := next(iter(block.values())))["type"] == "recarray"
294-
and field["reader"] != "readarray"
295-
) or (all(f["type"] == "recarray" and f["reader"] != "readarray" for f in block.values()))
293+
and (field := next(iter(block.values()))).metadata.get("type") == "recarray"
294+
and field.metadata.get("reader") != "readarray"
295+
) or (
296+
all(
297+
f.metadata.get("type") == "recarray" and f.metadata.get("reader") != "readarray"
298+
for f in block.values()
299+
)
300+
)

0 commit comments

Comments
 (0)