Skip to content

Commit 2a029b0

Browse files
authored
implement setter for stress period data property (#280)
followup on #277 which began with a read-only property, give it a setter. dims are found on 1) the parent if exists, 2) other arrays if present, otherwise an error is raised. just supports structured grids for now. all still rough, needs cleanup at some point
1 parent ed6b868 commit 2a029b0

File tree

3 files changed

+348
-18
lines changed

3 files changed

+348
-18
lines changed

flopy4/mf6/converter/structure.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def structure_keyword(value, field) -> str | None:
1616
return field.name if value else None
1717

1818

19-
def _resolve_dimensions(self_, field) -> tuple[list[str], list[int], dict]:
19+
def _resolve_dimensions(
20+
self_, field, *, dims: dict | None = None
21+
) -> tuple[list[str], list[int], dict]:
2022
"""
2123
Get expected dimensions, shape, and resolved dimension values.
2224
@@ -26,6 +28,9 @@ def _resolve_dimensions(self_, field) -> tuple[list[str], list[int], dict]:
2628
Parent object containing dimension context
2729
field : object
2830
Field specification with dims, dtype, default
31+
dims : dict, optional
32+
Explicit dimension sizes to use. If provided, takes precedence over
33+
dims from parent or self_.__dict__.
2934
3035
Returns
3136
-------
@@ -42,10 +47,15 @@ def _resolve_dimensions(self_, field) -> tuple[list[str], list[int], dict]:
4247
raise ValueError(f"Field {field} missing dims")
4348

4449
# Resolve dims from model context
45-
explicit_dims = self_.__dict__.get("dims", {})
50+
# Priority: 1) explicit dims parameter, 2) self_.__dict__, 3) parent
4651
inherited_dims = dict(self_.parent.data.dims) if self_.parent else {}
52+
explicit_dims = self_.__dict__.get("dims", {})
4753
dim_dict = inherited_dims | explicit_dims
4854

55+
# Override with explicitly provided dims (highest priority)
56+
if dims is not None:
57+
dim_dict.update(dims)
58+
4959
# Check object attributes directly for dimension values
5060
# These override inherited dims (important during initialization when dims are passed as kwargs)
5161
for dim_name in field.dims:
@@ -482,7 +492,13 @@ def _parse_dict_format(
482492

483493

484494
def structure_array(
485-
value, self_, field, *, return_xarray: bool = False, sparse_threshold: int | None = None
495+
value,
496+
self_,
497+
field,
498+
*,
499+
return_xarray: bool = False,
500+
sparse_threshold: int | None = None,
501+
dims: dict | None = None,
486502
) -> xr.DataArray | NDArray | sparse.COO:
487503
"""
488504
Convert various array representations to structured arrays.
@@ -507,14 +523,17 @@ def structure_array(
507523
If True, return xr.DataArray; otherwise return raw array (for backward compatibility)
508524
sparse_threshold : int | None
509525
Override default sparse threshold for COO vs dense
526+
dims : dict | None
527+
Explicit dimension sizes (e.g., {'nper': 10, 'nodes': 100}).
528+
If provided, takes precedence over dims from parent or self_.
510529
511530
Returns
512531
-------
513532
xr.DataArray | np.ndarray | sparse.COO
514533
Structured array with proper shape and metadata
515534
"""
516535
# Resolve dimensions
517-
dims, shape, dim_dict = _resolve_dimensions(self_, field)
536+
dims_names, shape, dim_dict = _resolve_dimensions(self_, field, dims=dims)
518537
threshold = sparse_threshold if sparse_threshold is not None else SPARSE_THRESHOLD
519538

520539
# Handle different input types
@@ -526,7 +545,7 @@ def structure_array(
526545

527546
if isinstance(value, dict):
528547
# Parse dict format with fill-forward logic
529-
parsed_dict = _parse_dict_format(value, dims, tuple(shape), dim_dict, field, self_)
548+
parsed_dict = _parse_dict_format(value, dims_names, tuple(shape), dim_dict, field, self_)
530549

531550
# Build array using sparse or dense approach
532551
if np.prod(shape) > threshold:
@@ -553,15 +572,15 @@ def structure_array(
553572
)
554573
value_data = row[-1]
555574
nn = get_nn(cellid, **dim_dict)
556-
if "nper" in dims:
575+
if "nper" in dims_names:
557576
coords_dict[(key, nn)] = value_data
558577
else:
559578
coords_dict[(nn,)] = value_data
560579
elif isinstance(val, dict):
561580
# Nested dict: {cellid: value}
562581
for cellid, v in val.items():
563582
nn = get_nn(cellid, **dim_dict)
564-
if "nper" in dims:
583+
if "nper" in dims_names:
565584
coords_dict[(key, nn)] = v
566585
else:
567586
coords_dict[(nn,)] = v
@@ -602,7 +621,7 @@ def structure_array(
602621
val = parsed_dict[key]
603622

604623
# Determine fill range (current key to next key or end)
605-
if "nper" in dims:
624+
if "nper" in dims_names:
606625
next_key = (
607626
sorted_keys[idx + 1]
608627
if idx + 1 < len(sorted_keys)
@@ -629,27 +648,27 @@ def structure_array(
629648
)
630649
value_data = row[-1]
631650
nn = get_nn(cellid, **dim_dict)
632-
if "nper" in dims:
651+
if "nper" in dims_names:
633652
result[kper, nn] = value_data
634653
else:
635654
result[nn] = value_data
636655
elif isinstance(val, dict):
637656
# Nested dict: {cellid: value}
638657
for cellid, v in val.items():
639658
nn = get_nn(cellid, **dim_dict)
640-
if "nper" in dims:
659+
if "nper" in dims_names:
641660
result[kper, nn] = v
642661
else:
643662
result[nn] = v
644663
elif isinstance(val, np.ndarray):
645664
# Array value
646-
if "nper" in dims:
665+
if "nper" in dims_names:
647666
result[kper] = val
648667
else:
649668
result = val
650669
elif isinstance(val, xr.DataArray):
651670
# xarray value
652-
if "nper" in dims:
671+
if "nper" in dims_names:
653672
result[kper] = val.values
654673
else:
655674
result = val.values
@@ -667,15 +686,15 @@ def structure_array(
667686

668687
elif isinstance(value, list):
669688
# List format
670-
result = _parse_list_format(value, dims, tuple(shape), field)
689+
result = _parse_list_format(value, dims_names, tuple(shape), field)
671690

672691
elif isinstance(value, (xr.DataArray, np.ndarray)):
673692
# Duck array - validate and reshape if needed
674-
result = _validate_duck_array(value, dims, tuple(shape), dim_dict)
693+
result = _validate_duck_array(value, dims_names, tuple(shape), dim_dict)
675694

676695
# Handle time fill-forward
677-
if "nper" in dims and "nper" in dim_dict:
678-
result = _fill_forward_time(result, dims, dim_dict["nper"])
696+
if "nper" in dims_names and "nper" in dim_dict:
697+
result = _fill_forward_time(result, dims_names, dim_dict["nper"])
679698

680699
elif isinstance(value, (int, float)):
681700
# Scalar - broadcast to full shape
@@ -689,10 +708,10 @@ def structure_array(
689708
if return_xarray and not isinstance(result, xr.DataArray):
690709
# Build coordinates
691710
xr_coords: dict[str, Any] = {}
692-
for dim in dims:
711+
for dim in dims: # type: ignore
693712
if dim in dim_dict:
694713
xr_coords[dim] = np.arange(dim_dict[dim])
695714

696-
result = _to_xarray(result, dims, xr_coords)
715+
result = _to_xarray(result, dims, xr_coords) # type: ignore
697716

698717
return result

flopy4/mf6/package.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,104 @@ def stress_period_data(self) -> pd.DataFrame:
194194
df = df.groupby(coord_columns, as_index=False).first()
195195

196196
return df
197+
198+
@stress_period_data.setter
199+
def stress_period_data(self, value: pd.DataFrame) -> None:
200+
"""
201+
Set stress period data from a DataFrame.
202+
203+
Parameters
204+
----------
205+
value : pd.DataFrame
206+
DataFrame with columns: 'kper' (stress period), spatial coordinates
207+
(either 'layer'/'row'/'col' or 'node'), and field value columns.
208+
209+
Examples
210+
--------
211+
>>> # Modify existing package data
212+
>>> chd = Chd(parent=gwf, head={0: {(0, 0, 0): 1.0}})
213+
>>> df = chd.stress_period_data
214+
>>> df['head'] = df['head'] * 2 # Double all values
215+
>>> chd.stress_period_data = df # Apply changes
216+
217+
>>> # Create new data from scratch
218+
>>> df = pd.DataFrame({
219+
... 'kper': [0, 0, 1],
220+
... 'layer': [0, 0, 0],
221+
... 'row': [0, 5, 0],
222+
... 'col': [0, 5, 5],
223+
... 'head': [10.0, 8.0, 9.0]
224+
... })
225+
>>> chd.stress_period_data = df
226+
"""
227+
import xarray as xr
228+
from xattree import get_xatspec
229+
230+
from flopy4.mf6.converter.structure import structure_array
231+
232+
if not isinstance(value, pd.DataFrame):
233+
raise TypeError(f"Expected DataFrame, got {type(value)}")
234+
235+
# Get xattree field specifications
236+
spec = get_xatspec(type(self)).flat
237+
238+
# Find all period block fields
239+
period_fields = []
240+
field_objects = {}
241+
for field_name, field_spec in spec.items():
242+
if field_spec.metadata.get("block") == "period" and hasattr(field_spec, "dims"): # type: ignore
243+
period_fields.append(field_name)
244+
field_objects[field_name] = field_spec
245+
246+
if not period_fields:
247+
raise TypeError("No period block fields found in package")
248+
249+
# Check which fields are present in the DataFrame
250+
available_fields = [f for f in period_fields if f in value.columns]
251+
if not available_fields:
252+
raise ValueError(
253+
f"DataFrame must contain at least one period field column. "
254+
f"Expected one of {period_fields}, got {value.columns.tolist()}"
255+
)
256+
257+
# Build dimension context for the converter
258+
# Priority: 1) parent model dims, 2) existing array data
259+
dim_dict = {}
260+
261+
# 1. Get dims from parent if available (most common case)
262+
if hasattr(self, "parent") and self.parent is not None and hasattr(self.parent, "data"):
263+
dim_dict.update(dict(self.parent.data.dims))
264+
265+
# 2. Extract dimensions from existing field data
266+
for field_name in period_fields:
267+
field_data = getattr(self, field_name, None)
268+
if field_data is not None and isinstance(field_data, xr.DataArray):
269+
# xarray stores dimension sizes
270+
dim_dict.update(dict(field_data.sizes))
271+
break # One field is enough to get dimensions
272+
273+
# 3. Check if DataFrame requires structured grid dims (nrow, ncol, nlay)
274+
# but they're not available - provide helpful error
275+
has_structured_coords = all(col in value.columns for col in ["layer", "row", "col"])
276+
if has_structured_coords:
277+
missing_dims = [d for d in ["nrow", "ncol", "nlay"] if d not in dim_dict]
278+
if missing_dims:
279+
raise ValueError(
280+
f"DataFrame has structured coordinates (layer/row/col) but package "
281+
f"is missing required dimensions: {missing_dims}. "
282+
f"Attach the package to a parent model with these dimensions, or use "
283+
f"node-based coordinates in the DataFrame instead."
284+
)
285+
286+
# Update each field present in the DataFrame
287+
# Pass dims explicitly to converter - no __dict__ manipulation needed
288+
for field_name in available_fields:
289+
field_obj = field_objects[field_name]
290+
291+
# Call converter with explicit dims parameter
292+
converted_value = structure_array(
293+
value, self, field_obj, dims=dim_dict if dim_dict else None
294+
)
295+
296+
# Set the attribute, which will trigger on_setattr hooks (e.g., update_maxbound)
297+
setattr(self, field_name, converted_value)

0 commit comments

Comments
 (0)