Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions flopy4/mf6/codec.py

This file was deleted.

71 changes: 71 additions & 0 deletions flopy4/mf6/codec/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import sys
from os import PathLike
from typing import Any

import numpy as np
from jinja2 import Environment, PackageLoader

from flopy4.mf6 import filters
from flopy4.mf6.codec.converter import structure_array, unstructure_array

_JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6"),
trim_blocks=True,
lstrip_blocks=True,
)
_JINJA_ENV.filters["blocks"] = filters.blocks
_JINJA_ENV.filters["field_type"] = filters.field_type
_JINJA_ENV.filters["field_value"] = filters.field_value
_JINJA_ENV.filters["is_list"] = filters.is_list
_JINJA_ENV.filters["array_how"] = filters.array_how
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
_JINJA_ENV.filters["array2string"] = filters.array2string

_JINJA_TEMPLATE_NAME = "blocks.jinja"

_PRINT_OPTIONS = {
"precision": 4,
"linewidth": sys.maxsize,
"threshold": sys.maxsize,
}


def unstructure(data):
# TODO unstructure arrays into sparse dicts
# TODO combine OC fields into list input as defined in the MF6 dfn
# TODO return a dictionary instead of the component itself, then
# update filters to use dictinoary access instead of getattr()
return data


def loads(data: str) -> Any:
# TODO
pass


def load(path: str | PathLike) -> Any:
# TODO
pass


def dumps(data) -> str:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
return template.render(dfn=type(data).dfn, data=unstructure(data))


def dump(data, path: str | PathLike) -> None:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
iterator = template.generate(dfn=type(data).dfn, data=unstructure(data))
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
f.writelines(iterator)


__all__ = [
"structure_array",
"unstructure_array",
"loads",
"load",
"dumps",
"dump",
]
38 changes: 37 additions & 1 deletion flopy4/mf6/converters.py → flopy4/mf6/codec/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
import numpy as np
import sparse
from numpy.typing import NDArray
from xarray import DataArray
from xattree import get_xatspec

from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA


def convert_array(value, self_, field) -> NDArray:
# TODO: convert to a cattrs structuring hook so we don't have to
# apply separately to all array fields?
def structure_array(value, self_, field) -> NDArray:
"""
Convert a sparse dictionary representation of an array to a
dense numpy array or a sparse COO array.
"""

if not isinstance(value, dict):
# if not a dict, assume it's a numpy array
# and let xarray deal with it if it isn't
Expand Down Expand Up @@ -91,3 +99,31 @@ def _get_nn(cellid):
# a[(nn,)] = v

return final(a)


def unstructure_array(value: DataArray) -> dict:
"""
Convert a dense numpy array or a sparse COO array to a sparse
dictionary representation suitable for serialization into the
MF6 list-based input format.
"""
# make sure dim 'kper' is present
if "kper" not in value.dims:
raise ValueError("array must have 'kper' dimension")

if isinstance(value.data, sparse.COO):
coords = value.coords
data = value.data
else:
coords = np.array(np.nonzero(value)).T # type: ignore
data = value[tuple(coords.T)] # type: ignore
if not coords.size: # type: ignore
return {}
match value.ndim:
case 1:
return {k: v for k, v in zip(coords[:, 0], data)} # type: ignore
case 2:
return {(k, j): v for (k, j), v in zip(coords, data)} # type: ignore
case 3:
return {(k, i, j): v for (k, i, j), v in zip(coords, data)} # type: ignore
return {}
78 changes: 63 additions & 15 deletions flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Hashable, Mapping
from io import StringIO
from typing import Any

import numpy as np
Expand Down Expand Up @@ -36,28 +38,74 @@ def field_value(ctx, field: Field):
return getattr(ctx["data"], field["name"])


def array_delay(value: xr.DataArray, chunks=None):
def array_how(value: xr.DataArray) -> str:
return "internal"


def array_chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = None):
"""
Yield chunks from an array. Each chunk becomes a line in the file.
If the array is not already chunked, it is chunked using the given
chunk size. If no chunk size is provided, the entire array becomes
a single chunk.
Yield chunks from an array of up to 3 dimensions. If the
array is not already chunked, split it into chunks of the
specified sizes, given as a dictionary mapping dimension
names to chunk sizes.

If chunk sizes are not specified, chunk the array with at
most 2 dimensions per chunk, where:

- If the array is 3D, assume the first dimension is the
vertical (i.e. layers) and the others horizontal (rows and
columns, in that order), and yield a chunk per layer, such
that an array with indices (k, i, j) becomes k chunks, each
of shape (i, j).

- If the array is 1D or 2D, yield it as a single chunk.
"""

if value.chunks is None:
chunk_shape = chunks or {dim: size for dim, size in zip(value.dims, value.shape)}
value = value.chunk(chunk_shape)
if chunks is None:
match value.ndim:
case 1:
# 1D array, single chunk
chunks = {value.dims[0]: value.shape[0]}
case 2:
# 2D array, single chunk
chunks = {value.dims[0]: value.shape[0], value.dims[1]: value.shape[1]}
case 3:
# 3D array, chunk for each layer
chunks = {
value.dims[0]: 1,
value.dims[1]: value.shape[1],
value.dims[2]: value.shape[2],
}
value = value.chunk(chunks)
for chunk in value.data.blocks:
yield chunk.compute()


def array2string(value: NDArray) -> str:
"""Convert an array to a string."""
s = np.array2string(value, separator=" ")
if value.shape != ():
s = s[1:-1] # remove brackets
return s.replace("'", "").replace('"', "") # remove quotes
"""
Convert an array to a string. The array can be 1D or 2D.
If the array is 1D, it is converted to a 1-line string,
with elements separated by whitespace. If the array is
2D, each row becomes a line in the string.
"""
buffer = StringIO()
value = np.asarray(value)
if value.ndim > 2:
raise ValueError("Only 1D and 2D arrays are supported.")
if value.ndim == 1:
# add an axis to 1d arrays so np.savetxt writes elements on 1 line
value = value[None]
format = (
"%d"
if np.issubdtype(value.dtype, np.integer)
else "%f"
if np.issubdtype(value.dtype, np.floating)
else "%s"
)
np.savetxt(buffer, value, fmt=format, delimiter=" ")
return buffer.getvalue().strip()


def is_dict(value: Any) -> bool:
"""Check if the value is a dictionary."""
return isinstance(value, dict)
def is_list(value: Any) -> bool:
return isinstance(value, list)
10 changes: 5 additions & 5 deletions flopy4/mf6/gwf/chd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.converters import convert_array
from flopy4.mf6.codec import structure_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand Down Expand Up @@ -40,7 +40,7 @@ class Steps:
"nnodes",
),
default=None,
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
aux: Optional[NDArray[np.floating]] = array(
block="period",
Expand All @@ -49,7 +49,7 @@ class Steps:
"nnodes",
),
default=None,
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
boundname: Optional[NDArray[np.str_]] = array(
block="period",
Expand All @@ -58,12 +58,12 @@ class Steps:
"nnodes",
),
default=None,
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
steps: Optional[NDArray[np.object_]] = array(
Steps,
block="period",
dims=("nper", "nnodes"),
default=None,
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
12 changes: 6 additions & 6 deletions flopy4/mf6/gwf/dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.converters import convert_array
from flopy4.mf6.codec import structure_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, dim, field

Expand Down Expand Up @@ -42,31 +42,31 @@ class Dis(Package):
block="griddata",
default=1.0,
dims=("ncol",),
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
delc: NDArray[np.floating] = array(
block="griddata",
default=1.0,
dims=("nrow",),
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
top: NDArray[np.floating] = array(
block="griddata",
default=1.0,
dims=("nrow", "ncol"),
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
botm: NDArray[np.floating] = array(
block="griddata",
default=0.0,
dims=("nlay", "nrow", "ncol"),
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
idomain: NDArray[np.integer] = array(
block="griddata",
default=1,
dims=("nlay", "nrow", "ncol"),
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
nnodes: int = dim(
coord="node",
Expand Down
4 changes: 2 additions & 2 deletions flopy4/mf6/gwf/ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numpy.typing import NDArray
from xattree import xattree

from flopy4.mf6.converters import convert_array
from flopy4.mf6.codec import structure_array
from flopy4.mf6.package import Package
from flopy4.mf6.spec import array, field

Expand All @@ -14,7 +14,7 @@ class Ic(Package):
block="packagedata",
dims=("nnodes",),
default=1.0,
converter=Converter(convert_array, takes_self=True, takes_field=True),
converter=Converter(structure_array, takes_self=True, takes_field=True),
)
export_array_ascii: bool = field(block="options", default=False)
export_array_netcdf: bool = field(block="options", default=False)
Loading
Loading