Skip to content

Commit ed6b868

Browse files
authored
reshaping no longer necessary (#279)
get rid of reshaping in TWRI even for grid and layer array package flavors
1 parent 8cf6664 commit ed6b868

File tree

7 files changed

+125
-35
lines changed

7 files changed

+125
-35
lines changed

docs/dev/array-converter-design.md

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,15 @@ def _reshape_grid(
251251
"""
252252
Perform structured↔flat grid conversion.
253253
254-
Handles:
254+
Handles full 3D grids (nodes dimension):
255255
- (nlay, nrow, ncol) → (nodes,)
256256
- (nper, nlay, nrow, ncol) → (nper, nodes)
257-
- Preserves xarray metadata if input is xarray
257+
258+
Handles per-layer 2D arrays (ncpl dimension):
259+
- (nrow, ncol) → (ncpl,)
260+
- (nper, nrow, ncol) → (nper, ncpl)
261+
262+
Preserves xarray metadata if input is xarray
258263
"""
259264
```
260265

@@ -539,19 +544,54 @@ coords = {coord_name: self._coords[coord_name], "x": self._coords["x"]}
539544

540545
**Location**: `flopy4/mf6/utils/grid.py` lines 261-284
541546

547+
#### 4. Missing Converters on Grid-Based Packages
548+
**Issue**: Grid-based ("g") and array-based ("a") package variants (Chdg, Welg, Drng, Rcha) were missing `converter=Converter(structure_array, ...)` on their period block array fields. Without converters, xattree validates dimension counts before reshape logic runs, causing errors when passing structured arrays.
549+
550+
**Solution**: Added converters to all grid-based and array-based package fields:
551+
- `Chdg.head`: Added converter for automatic reshaping
552+
- `Welg.q`: Added converter for automatic reshaping
553+
- `Drng.elev` and `Drng.cond`: Added converters for automatic reshaping
554+
- `Rcha.recharge`: Added converter for automatic reshaping
555+
556+
**Location**: `flopy4/mf6/gwf/chdg.py`, `welg.py`, `drng.py`, `rcha.py`
557+
558+
#### 5. Per-Layer Array Dimension (ncpl)
559+
**Issue**: Array-based packages use `ncpl` ("number of cells per layer") dimension for 2D per-layer data. The reshape detection only handled `nodes` (3D full grids) but not `ncpl` (2D per-layer), causing errors like:
560+
```
561+
ValueError: Shape mismatch: (3, 15, 15) vs (3, 225)
562+
```
563+
564+
**Solution**: Extended `_detect_grid_reshape` to handle `ncpl` dimension:
565+
```python
566+
# Handle 'ncpl' dimension (cells per layer, 2D per-layer arrays)
567+
if "ncpl" in expected_dims and has_structured_2d:
568+
# Case: (nrow, ncol) → (ncpl,)
569+
# Case: (nper, nrow, ncol) → (nper, ncpl)
570+
```
571+
572+
This allows automatic reshaping for array-based packages like Rcha.
573+
574+
**Location**: `flopy4/mf6/converter/structure.py` lines 113-127
575+
542576
### Validation Benefits
543577

544578
The stricter validation in the new converter caught the StructuredGrid bug that had existed undetected. While this temporarily broke tests, it exposed a real issue that would have caused problems downstream. This demonstrates the value of proper input validation.
545579

546580
### User-Facing Improvements
547581

548582
Users can now:
549-
1. Pass structured arrays `(nlay, nrow, ncol)` directly - automatic reshaping to `(nodes,)`
583+
1. Pass structured arrays directly - automatic reshaping for all dimensions:
584+
- 3D full grids: `(nlay, nrow, ncol)``(nodes,)`
585+
- 2D per-layer: `(nrow, ncol)``(ncpl,)`
586+
- Time-varying 3D: `(nper, nlay, nrow, ncol)``(nper, nodes)`
587+
- Time-varying 2D: `(nper, nrow, ncol)``(nper, ncpl)`
550588
2. Use DataFrames from `package.stress_period_data` to initialize new packages
551589
3. Mix value types within dicts (scalars, arrays, xarrays, DataFrames)
552590
4. Rely on strict validation to catch dimension mismatches early
553591

554-
Example - no manual reshaping needed:
592+
Examples - no manual reshaping needed:
593+
594+
**Example 1: Standard packages (3D grid arrays)**
555595
```python
556596
# Before (manual reshape required)
557597
icelltype = np.stack([np.full((nrow, ncol), val) for val in [1, 0, 0]])
@@ -561,3 +601,29 @@ npf = Npf(icelltype=icelltype.reshape((nodes,)), ...)
561601
icelltype = np.stack([np.full((nrow, ncol), val) for val in [1, 0, 0]])
562602
npf = Npf(icelltype=icelltype, ...) # Automatically reshaped to (nodes,)
563603
```
604+
605+
**Example 2: Grid-based packages (time-varying 3D arrays)**
606+
```python
607+
# Before (manual reshape required)
608+
head = np.full((nper, nlay, nrow, ncol), FILL_DNODATA)
609+
head[0, :2, :, 0] = 0.0
610+
chdg = Chdg(head=head.reshape(nper, -1), ...)
611+
612+
# After (automatic reshape)
613+
head = np.full((nper, nlay, nrow, ncol), FILL_DNODATA)
614+
head[0, :2, :, 0] = 0.0
615+
chdg = Chdg(head=head, ...) # Automatically reshaped to (nper, nodes)
616+
```
617+
618+
**Example 3: Array-based packages (time-varying 2D per-layer arrays)**
619+
```python
620+
# Before (manual reshape required)
621+
recharge = np.full((nper, nrow, ncol), FILL_DNODATA)
622+
recharge[0, ...] = 3.0e-8
623+
rcha = Rcha(recharge=recharge.reshape(nper, -1), ...)
624+
625+
# After (automatic reshape)
626+
recharge = np.full((nper, nrow, ncol), FILL_DNODATA)
627+
recharge[0, ...] = 3.0e-8
628+
rcha = Rcha(recharge=recharge, ...) # Automatically reshaped to (nper, ncpl)
629+
```

docs/examples/twri.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def plot_head(head, workspace):
101101
rch_rate = np.full((nlay, nrow, ncol), flopy4.mf6.constants.FILL_DNODATA)
102102
rate = np.repeat(np.expand_dims(rch_rate, axis=0), repeats=nper, axis=0)
103103
rate[0, 0, ...] = 3.0e-8
104-
rch = flopy4.mf6.gwf.Rch(recharge=rate.reshape(nper, -1), dims=dims)
104+
rch = flopy4.mf6.gwf.Rch(recharge=rate, dims=dims)
105105

106106
# Output control
107107
# TODO: show both ways to set up the Oc package, strings
@@ -211,7 +211,7 @@ def plot_head(head, workspace):
211211
print_input=True,
212212
print_flows=True,
213213
save_flows=True,
214-
head=head.reshape(nper, -1),
214+
head=head,
215215
dims=dims,
216216
)
217217

@@ -225,8 +225,8 @@ def plot_head(head, workspace):
225225
print_input=True,
226226
print_flows=True,
227227
save_flows=True,
228-
elev=elev.reshape(nper, -1),
229-
cond=cond.reshape(nper, -1),
228+
elev=elev,
229+
cond=cond,
230230
dims=dims,
231231
)
232232

@@ -235,14 +235,14 @@ def plot_head(head, workspace):
235235
for layer, row, col in wel_nodes:
236236
q[0, layer, row, col] = wel_q
237237
welg = flopy4.mf6.gwf.Welg(
238-
q=q.reshape(nper, -1),
238+
q=q,
239239
dims=dims,
240240
)
241241

242242
# recharge
243243
recharge = np.repeat(np.expand_dims(LAYER_NODATA, axis=0), repeats=nper, axis=0)
244244
recharge[0, ...] = 3.0e-8
245-
rcha = flopy4.mf6.gwf.Rcha(recharge=recharge.reshape(nper, -1), dims=dims)
245+
rcha = flopy4.mf6.gwf.Rcha(recharge=recharge, dims=dims)
246246

247247
# remove list based inputs
248248
# TODO: show variations on removing packages

flopy4/mf6/converter/structure.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,34 +86,45 @@ def _detect_grid_reshape(
8686
target_shape : tuple | None
8787
Target shape for reshape, or None
8888
"""
89-
# Check if we expect flat 'nodes' dimension
90-
if "nodes" not in expected_dims:
91-
return False, None
92-
9389
# Get expected shape
9490
expected_shape = tuple(dim_dict.get(d, d) for d in expected_dims)
9591

9692
# Check if value has structured dimensions
97-
has_structured = "nlay" in dim_dict and "nrow" in dim_dict and "ncol" in dim_dict
98-
99-
if not has_structured:
100-
return False, None
101-
102-
nlay = dim_dict["nlay"]
103-
nrow = dim_dict["nrow"]
104-
ncol = dim_dict["ncol"]
105-
nodes = dim_dict.get("nodes", nlay * nrow * ncol)
106-
107-
# Check for structured→flat conversion
108-
# Case 1: (nlay, nrow, ncol) → (nodes,)
109-
if value_shape == (nlay, nrow, ncol) and expected_shape == (nodes,):
110-
return True, (nodes,)
111-
112-
# Case 2: (nper, nlay, nrow, ncol) → (nper, nodes)
113-
if "nper" in expected_dims:
114-
nper = dim_dict["nper"]
115-
if value_shape == (nper, nlay, nrow, ncol) and expected_shape == (nper, nodes):
116-
return True, (nper, nodes)
93+
has_structured_3d = "nlay" in dim_dict and "nrow" in dim_dict and "ncol" in dim_dict
94+
has_structured_2d = "nrow" in dim_dict and "ncol" in dim_dict
95+
96+
# Handle 'nodes' dimension (full 3D grid)
97+
if "nodes" in expected_dims and has_structured_3d:
98+
nlay = dim_dict["nlay"]
99+
nrow = dim_dict["nrow"]
100+
ncol = dim_dict["ncol"]
101+
nodes = dim_dict.get("nodes", nlay * nrow * ncol)
102+
103+
# Case 1: (nlay, nrow, ncol) → (nodes,)
104+
if value_shape == (nlay, nrow, ncol) and expected_shape == (nodes,):
105+
return True, (nodes,)
106+
107+
# Case 2: (nper, nlay, nrow, ncol) → (nper, nodes)
108+
if "nper" in expected_dims:
109+
nper = dim_dict["nper"]
110+
if value_shape == (nper, nlay, nrow, ncol) and expected_shape == (nper, nodes):
111+
return True, (nper, nodes)
112+
113+
# Handle 'ncpl' dimension (cells per layer, 2D per-layer arrays)
114+
if "ncpl" in expected_dims and has_structured_2d:
115+
nrow = dim_dict["nrow"]
116+
ncol = dim_dict["ncol"]
117+
ncpl = dim_dict.get("ncpl", nrow * ncol)
118+
119+
# Case 3: (nrow, ncol) → (ncpl,)
120+
if value_shape == (nrow, ncol) and expected_shape == (ncpl,):
121+
return True, (ncpl,)
122+
123+
# Case 4: (nper, nrow, ncol) → (nper, ncpl)
124+
if "nper" in expected_dims:
125+
nper = dim_dict["nper"]
126+
if value_shape == (nper, nrow, ncol) and expected_shape == (nper, ncpl):
127+
return True, (nper, ncpl)
117128

118129
return False, None
119130

@@ -360,7 +371,7 @@ def _parse_dataframe(
360371
if has_structured:
361372
cellid = (int(row["layer"]), int(row["row"]), int(row["col"]))
362373
else:
363-
cellid = (int(row["node"]),) # type: ignore
374+
cellid = (int(row["node"]),) # type: ignore
364375

365376
# Extract field value
366377
value = row[field_name]

flopy4/mf6/gwf/chdg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from pathlib import Path
22
from typing import ClassVar, Optional
33

4+
import attrs
45
import numpy as np
56
from numpy.typing import NDArray
67
from xattree import xattree
78

9+
from flopy4.mf6.converter import structure_array
810
from flopy4.mf6.package import Package
911
from flopy4.mf6.spec import array, field, path
1012
from flopy4.mf6.utils.grid import update_maxbound
@@ -33,6 +35,7 @@ class Chdg(Package):
3335
"nodes",
3436
),
3537
default=None,
38+
converter=attrs.Converter(structure_array, takes_self=True, takes_field=True),
3639
on_setattr=update_maxbound,
3740
)
3841
aux: Optional[NDArray[np.float64]] = array(

flopy4/mf6/gwf/drng.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from pathlib import Path
22
from typing import ClassVar, Optional
33

4+
import attrs
45
import numpy as np
56
from numpy.typing import NDArray
67
from xattree import xattree
78

9+
from flopy4.mf6.converter import structure_array
810
from flopy4.mf6.package import Package
911
from flopy4.mf6.spec import array, field, path
1012
from flopy4.mf6.utils.grid import update_maxbound
@@ -34,6 +36,7 @@ class Drng(Package):
3436
"nodes",
3537
),
3638
default=None,
39+
converter=attrs.Converter(structure_array, takes_self=True, takes_field=True),
3740
on_setattr=update_maxbound,
3841
)
3942
cond: Optional[NDArray[np.float64]] = array(
@@ -43,6 +46,7 @@ class Drng(Package):
4346
"nodes",
4447
),
4548
default=None,
49+
converter=attrs.Converter(structure_array, takes_self=True, takes_field=True),
4650
on_setattr=update_maxbound,
4751
)
4852
aux: Optional[NDArray[np.float64]] = array(

flopy4/mf6/gwf/rcha.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from pathlib import Path
22
from typing import ClassVar, Optional
33

4+
import attrs
45
import numpy as np
56
from numpy.typing import NDArray
67
from xattree import xattree
78

9+
from flopy4.mf6.converter import structure_array
810
from flopy4.mf6.package import Package
911
from flopy4.mf6.spec import array, field, path
1012
from flopy4.utils import to_path
@@ -41,6 +43,7 @@ class Rcha(Package):
4143
"ncpl",
4244
),
4345
default=None,
46+
converter=attrs.Converter(structure_array, takes_self=True, takes_field=True),
4447
)
4548
aux: Optional[NDArray[np.float64]] = array(
4649
block="period",

flopy4/mf6/gwf/welg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from pathlib import Path
22
from typing import ClassVar, Optional
33

4+
import attrs
45
import numpy as np
56
from numpy.typing import NDArray
67
from xattree import xattree
78

9+
from flopy4.mf6.converter import structure_array
810
from flopy4.mf6.package import Package
911
from flopy4.mf6.spec import array, field, path
1012
from flopy4.mf6.utils.grid import update_maxbound
@@ -39,6 +41,7 @@ class Welg(Package):
3941
"nodes",
4042
),
4143
default=None,
44+
converter=attrs.Converter(structure_array, takes_self=True, takes_field=True),
4245
on_setattr=update_maxbound,
4346
)
4447
aux: Optional[NDArray[np.float64]] = array(

0 commit comments

Comments
 (0)