|
3 | 3 | from pathlib import Path |
4 | 4 | from typing import Any |
5 | 5 |
|
| 6 | +import numpy as np |
| 7 | +import sparse |
6 | 8 | import xarray as xr |
7 | 9 | import xattree |
8 | 10 | from attrs import define |
9 | 11 | from cattrs import Converter |
| 12 | +from numpy.typing import NDArray |
| 13 | +from xattree import get_xatspec |
10 | 14 |
|
| 15 | +from flopy4.adapters import get_nn |
11 | 16 | from flopy4.mf6.component import Component |
| 17 | +from flopy4.mf6.config import SPARSE_THRESHOLD |
| 18 | +from flopy4.mf6.constants import FILL_DNODATA |
12 | 19 | from flopy4.mf6.context import Context |
13 | 20 | from flopy4.mf6.exchange import Exchange |
14 | 21 | from flopy4.mf6.model import Model |
@@ -148,15 +155,15 @@ def unstructure_component(value: Component) -> dict[str, Any]: |
148 | 155 | blocks[block_name][field_name] = tuple(field_value.values.tolist()) |
149 | 156 | elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims: |
150 | 157 | has_spatial_dims = any( |
151 | | - dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nnodes"] |
| 158 | + dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nodes"] |
152 | 159 | ) |
153 | 160 | if has_spatial_dims: |
154 | 161 | # terrible hack to convert flat nodes dimension to 3d structured dims. |
155 | 162 | # long term solution for this is to use a custom xarray index. filters |
156 | 163 | # should then have access to all dimensions needed. |
157 | 164 | dims_ = set(field_value.dims).copy() |
158 | 165 | dims_.remove("nper") |
159 | | - if dims_ == {"nnodes"}: |
| 166 | + if dims_ == {"nodes"}: |
160 | 167 | parent = value.parent # type: ignore |
161 | 168 | field_value = xr.DataArray( |
162 | 169 | field_value.data.reshape( |
@@ -228,3 +235,73 @@ def _make_converter() -> Converter: |
228 | 235 |
|
229 | 236 |
|
230 | 237 | COMPONENT_CONVERTER = _make_converter() |
| 238 | + |
| 239 | + |
| 240 | +def dict_to_array(value, self_, field) -> NDArray: |
| 241 | + """ |
| 242 | + Convert a sparse dictionary representation of an array to a |
| 243 | + dense numpy array or a sparse COO array. |
| 244 | + """ |
| 245 | + |
| 246 | + if not isinstance(value, dict): |
| 247 | + # if not a dict, assume it's a numpy array |
| 248 | + # and let xarray deal with it if it isn't |
| 249 | + return value |
| 250 | + |
| 251 | + spec = get_xatspec(type(self_)).flat |
| 252 | + field = spec[field.name] |
| 253 | + if not field.dims: |
| 254 | + raise ValueError(f"Field {field} missing dims") |
| 255 | + |
| 256 | + # resolve dims |
| 257 | + explicit_dims = self_.__dict__.get("dims", {}) |
| 258 | + inherited_dims = dict(self_.parent.data.dims) if self_.parent else {} |
| 259 | + dims = inherited_dims | explicit_dims |
| 260 | + shape = [dims.get(d, d) for d in field.dims] |
| 261 | + unresolved = [d for d in shape if isinstance(d, str)] |
| 262 | + if any(unresolved): |
| 263 | + raise ValueError(f"Couldn't resolve dims: {unresolved}") |
| 264 | + |
| 265 | + if np.prod(shape) > SPARSE_THRESHOLD: |
| 266 | + a: dict[tuple[Any, ...], Any] = dict() |
| 267 | + |
| 268 | + def set_(arr, val, *ind): |
| 269 | + arr[tuple(ind)] = val |
| 270 | + |
| 271 | + def final(arr): |
| 272 | + coords = np.array(list(map(list, zip(*arr.keys())))) |
| 273 | + return sparse.COO( |
| 274 | + coords, |
| 275 | + list(arr.values()), |
| 276 | + shape=shape, |
| 277 | + fill_value=field.default or FILL_DNODATA, |
| 278 | + ) |
| 279 | + else: |
| 280 | + a = np.full(shape, FILL_DNODATA, dtype=field.dtype) # type: ignore |
| 281 | + |
| 282 | + def set_(arr, val, *ind): |
| 283 | + arr[ind] = val |
| 284 | + |
| 285 | + def final(arr): |
| 286 | + arr[arr == FILL_DNODATA] = field.default or FILL_DNODATA |
| 287 | + return arr |
| 288 | + |
| 289 | + if "nper" in dims: |
| 290 | + for kper, period in value.items(): |
| 291 | + if kper == "*": |
| 292 | + kper = 0 |
| 293 | + match len(shape): |
| 294 | + case 1: |
| 295 | + set_(a, period, kper) |
| 296 | + case _: |
| 297 | + for cellid, v in period.items(): |
| 298 | + nn = get_nn(cellid, **dims) |
| 299 | + set_(a, v, kper, nn) |
| 300 | + if kper == "*": |
| 301 | + break |
| 302 | + else: |
| 303 | + for cellid, v in value.items(): |
| 304 | + nn = get_nn(cellid, **dims) |
| 305 | + set_(a, v, nn) |
| 306 | + |
| 307 | + return final(a) |
0 commit comments