From 9950ea254c991337c20efe4e4a377b8778c70096 Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Tue, 15 Jul 2025 17:36:14 -0400 Subject: [PATCH] progress --- flopy4/mf6/attr_hooks.py | 6 +- flopy4/mf6/codec/writer/__init__.py | 14 +- .../mf6/codec/writer/templates/blocks.jinja | 6 +- .../mf6/codec/writer/templates/macros.jinja | 67 ++--- flopy4/mf6/converter.py | 77 +++-- flopy4/mf6/filters.py | 281 +++++++++--------- flopy4/mf6/gwf/oc.py | 22 +- flopy4/mf6/spec.py | 3 + test/test_codec.py | 71 ++--- test/test_filters.py | 97 ------ 10 files changed, 272 insertions(+), 372 deletions(-) delete mode 100644 test/test_filters.py diff --git a/flopy4/mf6/attr_hooks.py b/flopy4/mf6/attr_hooks.py index 04aac6a0..a341bfc4 100644 --- a/flopy4/mf6/attr_hooks.py +++ b/flopy4/mf6/attr_hooks.py @@ -26,7 +26,11 @@ def update_maxbound(instance, attribute, new_value): period_arrays = [] instance_fields = fields(instance.__class__) for field in instance_fields: - if field.metadata and field.metadata.get("block") == "period" and "dims" in field.metadata: + if ( + field.metadata + and field.metadata.get("block") == "period" + and field.metadata.get("xattree", {}).get("dims") + ): period_arrays.append(field.name) maxbound_values = [] diff --git a/flopy4/mf6/codec/writer/__init__.py b/flopy4/mf6/codec/writer/__init__.py index a6947f2c..616c077d 100644 --- a/flopy4/mf6/codec/writer/__init__.py +++ b/flopy4/mf6/codec/writer/__init__.py @@ -11,15 +11,13 @@ trim_blocks=True, lstrip_blocks=True, ) -_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks -_JINJA_ENV.filters["list_blocks"] = filters.list_blocks +_JINJA_ENV.filters["is_dataset"] = filters.is_dataset +_JINJA_ENV.filters["field_format"] = filters.field_format _JINJA_ENV.filters["array_how"] = filters.array_how _JINJA_ENV.filters["array_chunks"] = filters.array_chunks _JINJA_ENV.filters["array2string"] = filters.array2string -_JINJA_ENV.filters["field_type"] = filters.field_type -_JINJA_ENV.filters["array2list"] = filters.array2list -_JINJA_ENV.filters["keystring2list"] = filters.keystring2list -_JINJA_ENV.filters["keystring2list_multifield"] = filters.keystring2list_multifield +_JINJA_ENV.filters["data2list"] = filters.data2list +_JINJA_ENV.filters["data2keystring"] = filters.data2keystring _JINJA_TEMPLATE_NAME = "blocks.jinja" _PRINT_OPTIONS = { "precision": 4, @@ -31,11 +29,11 @@ def dumps(data) -> str: template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME) with np.printoptions(**_PRINT_OPTIONS): # type: ignore - return template.render(data=data) + return template.render(blocks=data) def dump(data, path: str | PathLike) -> None: template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME) - iterator = template.generate(data=data) + iterator = template.generate(blocks=data) with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore f.writelines(iterator) diff --git a/flopy4/mf6/codec/writer/templates/blocks.jinja b/flopy4/mf6/codec/writer/templates/blocks.jinja index 04f138f5..8e7923df 100644 --- a/flopy4/mf6/codec/writer/templates/blocks.jinja +++ b/flopy4/mf6/codec/writer/templates/blocks.jinja @@ -1,5 +1,5 @@ {% import 'macros.jinja' as macros with context %} -{% for block_name, block_value in (data|dict_blocks).items() %} +{% for block_name, block_value in blocks.items() %} BEGIN {{ block_name.upper() }} {% for field_name, field_value in block_value.items() if (field_value) is not none -%} {{ macros.field(field_name, field_value) }} @@ -7,7 +7,3 @@ BEGIN {{ block_name.upper() }} END {{ block_name.upper() }} {% endfor %} - -{% for block_name, block_value in (data|list_blocks).items() -%} -{{ macros.list(block_name, block_value, multi=block_name in ["period"]) }} -{%- endfor %} diff --git a/flopy4/mf6/codec/writer/templates/macros.jinja b/flopy4/mf6/codec/writer/templates/macros.jinja index d9c0ce12..e991219a 100644 --- a/flopy4/mf6/codec/writer/templates/macros.jinja +++ b/flopy4/mf6/codec/writer/templates/macros.jinja @@ -1,38 +1,40 @@ {% macro field(name, value) %} -{% set type = value|field_type %} -{% if type in ['keyword', 'integer', 'double precision', 'string'] %} +{% set format = value|field_format %} +{% if format in ['keyword', 'integer', 'double precision', 'string'] %} {{ scalar(name, value) }} -{% elif type == 'record' %} +{% elif format == 'record' %} {{ record(name, value) }} -{% elif type == 'keystring' %} +{% elif format == 'keystring' %} {{ keystring(name, value) }} -{% elif type == 'recarray' %} -{{ recarray(name, value, how=value|array_how) }} +{% elif format == 'array' %} +{{ array(name, value, how=value|array_how) }} +{% elif format == 'list' %} +{{ list(name, value) }} {% endif %} {% endmacro %} {% macro scalar(name, value) %} -{% set type = value|field_type %} -{% if value is not none %}{{ name.upper() }}{% if type != 'keyword' %} {{ value }}{% endif %}{% endif %} +{% set format = value|field_format %} +{% if value is not none %}{{ name.upper() }}{% if format != 'keyword' %} {{ value }}{% endif %}{% endif %} {% endmacro %} {% macro keystring(name, value) %} -{% for item in value.values() -%} -{{ field(item) }} -{%- endfor %} +{% for option in (value|data2keystring) -%} +{{ record("", option) }} +{% endfor %} {% endmacro %} {% macro record(name, value) %} -{% if value is mapping %} -{% for item in value.values() -%} -{{ item.name.upper() }} {{ field(item) }} +{%- if value is mapping %} +{% for field_name, field_value in value.items() -%} +{{ field_name.upper() }} {{ field(field_value) }} {%- endfor %} {% else %} {{ value|join(" ") }} -{% endif %} +{%- endif %} {% endmacro %} -{% macro recarray(name, value, how="internal") %} +{% macro array(name, value, how="internal") %} {{ name.upper() }}{% if "layered" in how %} LAYERED{% endif %} {% if how == "constant" %} @@ -51,37 +53,8 @@ OPEN/CLOSE {{ value }} {% endif %} {% endmacro %} -{% macro list(name, value, multi=False) %} -{% if multi %} -{# iterate through time periods and combine fields #} -{% set field_arrays = {} %} -{% for field_name, field_value in value.items() %} -{% if field_value is not none %} -{% set _ = field_arrays.update({field_name: field_value}) %} -{% endif %} -{% endfor %} - -{% set first_array = field_arrays.values() | list | first %} -{% if first_array is not none %} -{% set nper = first_array.shape[0] %} -{% for period_idx in range(nper) %} -BEGIN {{ name.upper() }} {{ period_idx + 1 }} -{% for row in (field_arrays|keystring2list_multifield(period_idx)) %} +{% macro list(name, value) %} +{% for row in (value|data2list) %} {{ row|join(" ") }} {% endfor %} -END {{ name.upper() }} {{ period_idx + 1 }} - -{% endfor %} -{% endif %} -{% else %} -{% for field_name, field_value in value.items() %} -{% if field_value is not none %} -BEGIN {{ name.upper() }} -{% for row in (field_value|array2list) %} -{{ row|join(" ") }} -{% endfor %} -END {{ name.upper() }} -{% endif %} -{% endfor %} -{% endif %} {% endmacro %} diff --git a/flopy4/mf6/converter.py b/flopy4/mf6/converter.py index eb25f5d7..a1808d53 100644 --- a/flopy4/mf6/converter.py +++ b/flopy4/mf6/converter.py @@ -2,21 +2,29 @@ from pathlib import Path from typing import Any +import xarray as xr import xattree from cattrs import Converter from flopy4.mf6.component import Component -from flopy4.mf6.spec import get_blocks +from flopy4.mf6.spec import fields_dict, get_blocks -def _transform_path_to_record(field_name: str, path_value: Path) -> tuple: - """Transform a Path field to its corresponding MF6 record format.""" - # Infer record structure from field name +def _attach_field_metadata( + dataset: xr.Dataset, component_type: type, field_names: list[str] +) -> None: + field_metadata = {} + component_fields = fields_dict(component_type) + for field_name in field_names: + if field_name in component_fields: + field_metadata[field_name] = component_fields[field_name].metadata + dataset.attrs["field_metadata"] = field_metadata + + +def _path_to_record(field_name: str, path_value: Path) -> tuple: if field_name.endswith("_file"): base_name = field_name.replace("_file", "").upper() return (base_name, "FILEOUT", str(path_value)) - - # Default fallback return (field_name.upper(), "FILEOUT", str(path_value)) @@ -26,27 +34,62 @@ def unstructure_component(value: Component) -> dict[str, Any]: blocks: dict[str, dict[str, Any]] = {} for block_name, block in blockspec.items(): blocks[block_name] = {} + period_data = {} + period_blocks = {} # type: ignore for field_name in block.keys(): field_value = data[field_name] - - # Transform Path fields to record format + # convert: + # - paths to records + # - datetime to ISO format + # - auxiliary fields to tuples + # - xarray DataArrays with 'nper' dimension to kper-sliced datasets + # (and split the period data into separate kper-indexed blocks) + # - other values to their original form if isinstance(field_value, Path) and field_value is not None: - field_value = _transform_path_to_record(field_name, field_value) - - # Transform datetime fields to string format + blocks[block_name][field_name] = _path_to_record(field_name, field_value) elif isinstance(field_value, datetime) and field_value is not None: - field_value = field_value.isoformat() - - # Transform auxiliary fields to tuple for single-line record format + blocks[block_name][field_name] = field_value.isoformat() elif ( field_name == "auxiliary" and hasattr(field_value, "values") and field_value is not None ): - field_value = tuple(field_value.values.tolist()) + blocks[block_name][field_name] = tuple(field_value.values.tolist()) + elif isinstance(field_value, xr.DataArray) and "nper" in field_value.dims: + has_spatial_dims = any( + dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "nnodes"] + ) + if has_spatial_dims: + period_data[field_name] = { + kper: field_value.isel(nper=kper) + for kper in range(field_value.sizes["nper"]) + } + else: + if block_name not in period_data: + period_data[block_name] = {} + period_data[block_name][field_name] = field_value # type: ignore + else: + if field_value is not None: + blocks[block_name][field_name] = field_value + + if block_name in period_data and isinstance(period_data[block_name], dict): + dataset = xr.Dataset(period_data[block_name]) + _attach_field_metadata(dataset, type(value), list(period_data[block_name].keys())) # type: ignore + blocks[block_name] = {block_name: dataset} + del period_data[block_name] + + for arr_name, periods in period_data.items(): + for kper, arr in periods.items(): + if kper not in period_blocks: + period_blocks[kper] = {} + period_blocks[kper][arr_name] = arr + + for kper, block in period_blocks.items(): + dataset = xr.Dataset(block) + _attach_field_metadata(dataset, type(value), list(block.keys())) + blocks[f"{block_name} {kper + 1}"] = {block_name: dataset} - blocks[block_name][field_name] = field_value - return blocks + return {name: block for name, block in blocks.items() if block} def _make_converter() -> Converter: diff --git a/flopy4/mf6/filters.py b/flopy4/mf6/filters.py index 75e1858c..95a05395 100644 --- a/flopy4/mf6/filters.py +++ b/flopy4/mf6/filters.py @@ -3,65 +3,32 @@ from typing import Any import numpy as np -import pandas as pd import xarray as xr from numpy.typing import NDArray from flopy4.mf6.constants import FILL_DNODATA -def is_list_block(block: dict) -> bool: - """ - Check if a block is a list block, which is a block that - contains only one recarray field using list input. - """ - meaningful_fields = {k: v for k, v in block.items() if v is not None} - if len(meaningful_fields) == 0: - return False - - # TODO: how to not hard-code these? - stress_fields = { - "head", - "q", - "elev", - "cond", - "rate", - "flux", - "concentration", - "stage", - "bhead", - "aux", - "boundname", - } - for field_name in meaningful_fields.keys(): - if field_name.lower() not in stress_fields: - return False - return True - - -def dict_blocks(data: dict) -> dict: - """ - Get dictionary blocks: blocks which can contain - one or more fields, as opposed to a list block, which - may only contain one recarray field, using list input. - """ - return { - name: block - for name, block in data.items() - if block is not None and not is_list_block(block) - } +def _is_keystring_format(dataset: xr.Dataset) -> bool: + """Check if dataset should use keystring format based on metadata.""" + field_metadata = dataset.attrs.get("field_metadata", {}) + return any(meta.get("format") == "keystring" for meta in field_metadata.values()) + + +def _is_tabular_time_format(dataset: xr.Dataset) -> bool: + """True if a dataset has multiple columns and only one dimension 'nper'.""" + return len(dataset.data_vars) > 1 and all( + "nper" in var.dims and len(var.dims) == 1 for var in dataset.data_vars.values() + ) -def list_blocks(data: dict) -> dict: - """Get list blocks, which contain only one recarray field.""" - return { - name: block for name, block in data.items() if block is not None and is_list_block(block) - } +def is_dataset(value: Any) -> bool: + return isinstance(value, xr.Dataset) -def field_type(value: Any) -> str: +def field_format(value: Any) -> str: """ - Get a field's type as defined by the MODFLOW 6 input definition language: + Get a field's formatting type as defined by the MF6 definition language: https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types """ if isinstance(value, bool): @@ -74,19 +41,35 @@ def field_type(value: Any) -> str: return "string" if isinstance(value, (dict, tuple)): return "record" - if isinstance(value, (list, np.ndarray, xr.DataArray)): - return "recarray" + if isinstance(value, xr.DataArray): + if value.dtype == "object": + return "list" + return "array" + if isinstance(value, (xr.Dataset, list)): + if isinstance(value, xr.Dataset): + if _is_keystring_format(value): + return "keystring" + if _is_tabular_time_format(value): + return "list" + return "list" return "keystring" +def has_time_dim(value: Any) -> bool: + return isinstance(value, xr.DataArray) and "nper" in value.dims + + def array_how(value: xr.DataArray) -> str: + # TODO + # - detect constant arrays? + # - above certain size, use external? return "internal" def array_chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = None): """ - Yield chunks from an array of up to 3 dimensions. If the - array is not already chunked, split it into chunks of the + Yield chunks from a dask-backed array of up to 3 dimensions. + If it's not already chunked, split it into chunks of the specified sizes, given as a dictionary mapping dimension names to chunk sizes. @@ -100,11 +83,11 @@ def array_chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = No of shape (i, j). - If the array is 1D or 2D, yield it as a single chunk. + + If the array is not a dask array, yield it as a single chunk. """ - # Check if it's a dask array (has .blocks attribute) if hasattr(value.data, "blocks"): - # Dask array - use chunking logic if value.chunks is None: if chunks is None: match value.ndim: @@ -125,7 +108,7 @@ def array_chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = No for chunk in value.data.blocks: yield np.squeeze(chunk.compute()) else: - # Regular numpy array - yield as single chunk + # regular array, single chunk yield np.squeeze(value.values) @@ -135,6 +118,8 @@ def array2string(value: NDArray) -> str: If the array is 1D, it is converted to a 1-line string, with elements separated by whitespace. If the array is 2D, each row becomes a line in the string. + + Used for writing array-based input to MF6 input files. """ buffer = StringIO() value = np.asarray(value) @@ -155,121 +140,129 @@ def array2string(value: NDArray) -> str: return buffer.getvalue().strip() -def array2list(value: xr.DataArray, include_zeros: bool = False): - """ - Generator that yields sparse (indices, value, *aux) tuples from a `DataArray`. - Iterates only over meaningful values (excludes zeros, NaN, and `FILL_DNODATA`). +def nonempty(arr: NDArray | xr.DataArray) -> NDArray: + if isinstance(arr, xr.DataArray): + arr = arr.values + if arr.dtype == "object": + mask = arr != None # noqa: E711 + else: + mask = ~np.ma.masked_invalid(arr).mask + mask = mask & (arr != FILL_DNODATA) + return mask + - Parameters - ---------- - value : xr.DataArray - The input array to iterate over sparsely - include_zeros : bool, optional - If True, include zero values in iteration. Default False. +def data2list(value: list | xr.DataArray | xr.Dataset): + """ + Yield record tuples from a list, `DataArray` or `Dataset`. Yields ------ tuple - Tuples of (layer, row, col, value) with 1-based indexing for MF6 + Tuples of (*cellid, *values) or (*values) depending on spatial dimensions """ - from flopy4.mf6.constants import FILL_DNODATA - if not include_zeros: - mask = (value != 0) & (value != FILL_DNODATA) & ~np.isnan(value) - else: - mask = (value != FILL_DNODATA) & ~np.isnan(value) + if isinstance(value, list): + for item in value: + yield item + return + + if isinstance(value, xr.Dataset): + yield from dataset2list(value) + return + + # handle scalar + if value.ndim == 0: + if not np.isnan(value.item()) and value.item() is not None: + yield (value.item(),) + return + spatial_dims = [d for d in value.dims if d in ("nlay", "nrow", "ncol", "nnodes")] + has_spatial_dims = len(spatial_dims) > 0 + mask = nonempty(value) indices = np.where(mask) values = value.values[mask] for i, val in enumerate(values): - idx_1based = tuple(idx[i] + 1 for idx in indices) - yield idx_1based + (val,) - - -def keystring2list(value: xr.DataArray): - """ - Generator for object arrays containing structured data (keystrings). - Yields structured records for non-null entries. - - Parameters - ---------- - value : xr.DataArray - Array with object dtype containing structured data - - Yields - ------ - tuple - Tuples of (layer, row, col, *structured_values) with 1-based indexing - """ - coord_arrays = np.meshgrid(*[np.arange(s) for s in value.shape], indexing="ij") - flat_values = value.values.flat - flat_coords = zip(*[arr.flat for arr in coord_arrays]) - for coords, val in zip(flat_coords, flat_values): - if val is not None and not pd.isna(val): - coords_1based = tuple(c + 1 for c in coords) - if hasattr(val, "_asdict"): # Named tuple - yield coords_1based + tuple(val._asdict().values()) - elif isinstance(val, dict): - yield coords_1based + tuple(val.values()) - else: - yield coords_1based + (val,) + if has_spatial_dims: + cellid = tuple(idx[i] + 1 for idx in indices) + result = cellid + (val,) + else: + result = (val,) + yield result -def keystring2list_multifield(field_arrays: dict, period_idx: int): +def dataset2list(value: xr.Dataset): """ - Combines multiple fields (e.g., elev, cond) for a given stress period - - Parameters - ---------- - field_arrays : dict - Dictionary of field_name -> xarray.DataArray - period_idx : int - Time period index (0-based) + Yield record tuples from an xarray Dataset. For regular/tabular list-based format. Yields ------ tuple - Tuples of (layer, row, col, field1_value, field2_value, ...) - with 1-based indexing + Tuples of (*cellid, *values) or (*values) depending on spatial dimensions """ - if not field_arrays: + if value is None or not any(value.data_vars): return - # determine spatial structure from first array - first_field = next(iter(field_arrays.values())) - if not isinstance(first_field, (np.ndarray, xr.DataArray)): + # handle scalar + first_arr = next(iter(value.data_vars.values())) + if first_arr.ndim == 0: + field_vals = [] + for field_name in value.data_vars.keys(): + field_val = value[field_name] + if hasattr(field_val, "item"): + field_vals.append(field_val.item()) + else: + field_vals.append(field_val) + yield tuple(field_vals) return - # get period slice - period_slices: dict[str, Any] = {} - for field_name, field_array in field_arrays.items(): - if isinstance(field_array, xr.DataArray): - period_data = field_array.isel(nper=period_idx) - period_slices[field_name] = period_data.values - elif isinstance(field_array, np.ndarray): - period_slices[field_name] = field_array[period_idx] - - # Find all locations where at least one field has meaningful data + # build mask combined_mask: Any = None - for field_name, period_data in period_slices.items(): - meaningful_mask = ( - (period_data != 0) & (period_data != FILL_DNODATA) & ~np.isnan(period_data) - ) - if combined_mask is None: - combined_mask = meaningful_mask - else: - combined_mask = combined_mask | meaningful_mask - + for field_name, arr in value.data_vars.items(): + mask = nonempty(arr) + combined_mask = mask if combined_mask is None else combined_mask | mask if combined_mask is None or not np.any(combined_mask): return + spatial_dims = [d for d in first_arr.dims if d in ("nlay", "nrow", "ncol", "nnodes")] + has_spatial_dims = len(spatial_dims) > 0 indices = np.where(combined_mask) for i in range(len(indices[0])): - idx_1based = tuple(idx[i] + 1 for idx in indices) - field_values = [] - for field_name in field_arrays.keys(): - period_data = period_slices[field_name] - val = period_data[tuple(idx[i] for idx in indices)] - field_values.append(val) - - yield idx_1based + tuple(field_values) + field_vals = [] + for field_name in value.data_vars.keys(): + field_val = value[field_name][tuple(idx[i] for idx in indices)] + if hasattr(field_val, "item"): + field_vals.append(field_val.item()) + else: + field_vals.append(field_val) + if has_spatial_dims: + cellid = tuple(idx[i] + 1 for idx in indices) + yield cellid + tuple(field_vals) + else: + yield tuple(field_vals) + + +def data2keystring(value: dict | xr.Dataset): + """ + Yield record tuples from a dict or dataset. For irregular list-based format, i.e. keystrings. + + Yields + ------ + tuple + Tuples of (field_name, value) for use with record macro + """ + if isinstance(value, dict): + if not value: + return + for field_name, field_val in value.items(): + yield (field_name.upper(), field_val) + elif isinstance(value, xr.Dataset): + if value is None or not any(value.data_vars): + return + + for field_name in value.data_vars.keys(): + field_val = value[field_name] + if hasattr(field_val, "item"): + val = field_val.item() + else: + val = field_val + yield (field_name.upper(), val) diff --git a/flopy4/mf6/gwf/oc.py b/flopy4/mf6/gwf/oc.py index e37cebe2..b50f715d 100644 --- a/flopy4/mf6/gwf/oc.py +++ b/flopy4/mf6/gwf/oc.py @@ -23,11 +23,11 @@ class Format: @define(slots=False) class Steps: - all: bool = field() - first: bool = field() - last: bool = field() - steps: list[int] = field() - frequency: int = field() + all: bool = field(default=True) + first: bool | None = field(default=None) + last: bool | None = field(default=None) + steps: list[int] | None = field(default=None) + frequency: int | None = field(default=None) @define(slots=False) class Period: @@ -51,34 +51,38 @@ class Period: ) format: Optional[Format] = field(block="options", default=None, init=False) save_head: Optional[NDArray[np.object_]] = array( - Steps, + object, block="period", default="all", dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), reader="urword", + format="keystring", ) save_budget: Optional[NDArray[np.object_]] = array( - Steps, + object, block="period", default="all", dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), reader="urword", + format="keystring", ) print_head: Optional[NDArray[np.object_]] = array( - Steps, + object, block="period", default="all", dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), reader="urword", + format="keystring", ) print_budget: Optional[NDArray[np.object_]] = array( - Steps, + object, block="period", default="all", dims=("nper",), converter=Converter(dict_to_array, takes_self=True, takes_field=True), reader="urword", + format="keystring", ) diff --git a/flopy4/mf6/spec.py b/flopy4/mf6/spec.py index c618c450..4153b9e2 100644 --- a/flopy4/mf6/spec.py +++ b/flopy4/mf6/spec.py @@ -110,12 +110,15 @@ def array( on_setattr=None, block: str | None = None, reader: Reader = "readarray", + format: str | None = None, ): """Define an array field.""" if block: metadata = metadata or {} metadata["block"] = block metadata["reader"] = reader + if format: + metadata["format"] = format return flopy_array( cls=cls, dims=dims, diff --git a/test/test_codec.py b/test/test_codec.py index 4d94a58a..8d24a6e3 100644 --- a/test/test_codec.py +++ b/test/test_codec.py @@ -1,36 +1,7 @@ -import numpy as np -import xarray as xr - from flopy4.mf6.codec import dumps from flopy4.mf6.converter import COMPONENT_CONVERTER -def test_list_template_rendering(): - """Test that list blocks render correctly with sparse arrays.""" - - nnodes = 9 - data = np.full((1, nnodes), 1e30) - - data[0, 4] = -500.0 - data[0, 8] = -250.0 - - wel_data = xr.DataArray( - data, dims=["nper", "nnodes"], coords={"node": ("nnodes", range(nnodes))} - ) - - test_data = {"period": {"q": wel_data}} - - result = dumps(test_data) - print("List template result:") - print(result) - - assert "BEGIN PERIOD" in result - assert "END PERIOD" in result - assert "5 -500.0" in result # Node 5 (1-based) with -500.0 - assert "9 -250.0" in result # Node 9 (1-based) with -250.0 - assert "1e+30" not in result - - def test_dumps_ic(): from flopy4.mf6.gwf import Dis, Gwf, Ic @@ -113,6 +84,7 @@ def test_dumps_chd(): ) result = dumps(COMPONENT_CONVERTER.unstructure(chd)) + print(result) assert "BEGIN PERIOD 1" in result assert "END PERIOD 1" in result @@ -123,7 +95,8 @@ def test_dumps_chd(): assert len(lines) == 2 assert "1 10.0" in result # First CHD cell - node 1 assert "100 20.0" in result # Second CHD cell - node 100 - assert result + assert "1e+30" not in result + assert "1.0e+30" not in result def test_dumps_wel_sparse(): @@ -156,10 +129,12 @@ def test_dumps_wel_sparse(): lines = [line.strip() for line in period_section.split("\n") if line.strip()] assert len(lines) == 3 - result_lower = result.lower() - assert "-100" in result_lower or "100" in result_lower - assert "-50" in result_lower or "50" in result_lower - assert "25" in result_lower + # node q (nodes are 1-based) + assert "24 -100.0" in result # (0,2,3) -> node 24 + assert "158 -50.0" in result # (1,5,7) -> node 158 + assert "282 25.0" in result # (2,8,1) -> node 282 + assert "1e+30" not in result + assert "1.0e+30" not in result def test_dumps_drn_sparse_multiperiod(): @@ -211,11 +186,14 @@ def test_dumps_drn_sparse_multiperiod(): assert len(period1_lines) == 2 assert len(period2_lines) == 3 - assert "5 10.0 1.0" in result - assert "46 8.0 2.0" in result - assert "7 12.0 1.5" in result - assert "14 9.0 0.8" in result - assert "43 7.0 2.2" in result + # node elev cond + assert "5 10.0 1.0" in result # Period 1: (0,0,4) + assert "46 8.0 2.0" in result # Period 1: (1,4,0) + assert "7 12.0 1.5" in result # Period 2: (0,1,1) + assert "14 9.0 0.8" in result # Period 2: (0,2,3) + assert "43 7.0 2.2" in result # Period 2: (1,3,2) + assert "1e+30" not in result + assert "1.0e+30" not in result def test_dumps_chd_sparse_realistic(): @@ -242,9 +220,11 @@ def test_dumps_chd_sparse_realistic(): lines = [line.strip() for line in period_section.split("\n") if line.strip()] assert len(lines) == 24 - assert "100" in result - assert "95" in result - assert "98" in result + assert "100.0" in result # Left boundary + assert "95.0" in result # Right boundary + assert "98.0" in result # Bottom boundary + assert "1e+30" not in result + assert "1.0e+30" not in result def test_dumps_wel_with_auxiliary(): @@ -279,5 +259,8 @@ def test_dumps_wel_with_auxiliary(): lines = [line.strip() for line in period_section.split("\n") if line.strip()] assert len(lines) == 2 - assert "-75" in result or "75" in result - assert "-25" in result or "25" in result + # node q aux_value + assert "8 -75.0 1.0" in result # (0,1,2) -> node 8, q=-75.0, aux=1.0 + assert "45 -25.0 2.0" in result # (1,3,4) -> node 45, q=-25.0, aux=2.0 + assert "1e+30" not in result + assert "1.0e+30" not in result diff --git a/test/test_filters.py b/test/test_filters.py deleted file mode 100644 index eaaa3cc3..00000000 --- a/test/test_filters.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Tests for flopy4.mf6.filters module.""" - -import numpy as np -import xarray as xr - -from flopy4.mf6.filters import array2list, keystring2list - - -def test_array2list(): - data = np.zeros((3, 4, 5)) - data[0, 1, 2] = 100.0 - data[1, 3, 0] = -50.0 - data[2, 0, 4] = 25.5 - - arr = xr.DataArray(data, dims=["layer", "row", "col"]) - sparse_data = list(array2list(arr)) - assert len(sparse_data) == 3 - expected = {(1, 2, 3, 100.0), (2, 4, 1, -50.0), (3, 1, 5, 25.5)} - assert set(sparse_data) == expected - - -def test_array2list_with_zeros(): - data = np.array([[0.0, 1.0], [2.0, 0.0]]) - arr = xr.DataArray(data, dims=["row", "col"]) - - sparse_no_zeros = list(array2list(arr)) - assert len(sparse_no_zeros) == 2 - assert set(sparse_no_zeros) == {(1, 2, 1.0), (2, 1, 2.0)} - - sparse_with_zeros = list(array2list(arr, include_zeros=True)) - assert len(sparse_with_zeros) == 4 - expected_with_zeros = {(1, 1, 0.0), (1, 2, 1.0), (2, 1, 2.0), (2, 2, 0.0)} - assert set(sparse_with_zeros) == expected_with_zeros - - -def test_array2list_nan_handling(): - data = np.array([[1.0, np.nan], [0.0, 2.0]]) - arr = xr.DataArray(data, dims=["row", "col"]) - sparse_data = list(array2list(arr, include_zeros=True)) - expected = {(1, 1, 1.0), (2, 1, 0.0), (2, 2, 2.0)} - assert set(sparse_data) == expected - - -def test_keystring2list(): - data = np.full((2, 3), None, dtype=object) - data[0, 1] = {"rate": -100.0, "aux1": 1, "aux2": 2} - data[1, 2] = {"rate": -200.0, "aux1": 3, "aux2": 4} - - arr = xr.DataArray(data, dims=["row", "col"]) - sparse_data = list(keystring2list(arr)) - assert len(sparse_data) == 2 - - rows = {entry[:2] for entry in sparse_data} - assert rows == {(1, 2), (2, 3)} - - for entry in sparse_data: - if entry[:2] == (1, 2): - assert entry[2:] == (-100.0, 1, 2) - elif entry[:2] == (2, 3): - assert entry[2:] == (-200.0, 3, 4) - - -def test_keystring2list_with_namedtuple(): - from collections import namedtuple - - WelData = namedtuple("WelData", ["rate", "aux1", "aux2"]) - - data = np.full((2, 2), None, dtype=object) - data[0, 0] = WelData(-100.0, 1, 2) - data[1, 1] = WelData(-200.0, 3, 4) - - arr = xr.DataArray(data, dims=["row", "col"]) - sparse_data = list(keystring2list(arr)) - - assert len(sparse_data) == 2 - expected = {(1, 1, -100.0, 1, 2), (2, 2, -200.0, 3, 4)} - assert set(sparse_data) == expected - - -def test_array2list_1d(): - """Test sparse iteration with 1D arrays.""" - data = np.array([0, 5, 0, 10]) - arr = xr.DataArray(data, dims=["index"]) - - sparse_data = list(array2list(arr)) - expected = {(2, 5), (4, 10)} # 1-based indexing - assert set(sparse_data) == expected - - -def test_array2list_2d(): - """Test sparse iteration with 2D arrays.""" - data = np.array([[1, 0], [0, 2]]) - arr = xr.DataArray(data, dims=["row", "col"]) - - sparse_data = list(array2list(arr)) - expected = {(1, 1, 1), (2, 2, 2)} # 1-based indexing - assert set(sparse_data) == expected