Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
44 changes: 44 additions & 0 deletions flopy4/mf6/codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import sys

import numpy as np
from jinja2 import Environment, PackageLoader

from flopy4.mf6 import filters
from flopy4.mf6.component import Component
from flopy4.mf6.spec import blocks_dict, fields_dict
from flopy4.uio import DEFAULT_REGISTRY

JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6"),
trim_blocks=True,
lstrip_blocks=True,
)
JINJA_ENV.filters["field_kind"] = filters.field_kind
JINJA_ENV.filters["fieldvalue"] = filters.fieldvalue
JINJA_ENV.filters["arraydelayed"] = filters.arraydelayed
JINJA_ENV.filters["array2string"] = filters.array2string
JINJA_ENV.filters["is_dict"] = filters.is_dict
JINJA_TEMPLATE_NAME = "blocks.jinja"


def _load_ascii(self) -> None:
# TODO
pass


def _write_ascii(self) -> None:
cls = type(self)
fields = fields_dict(cls)
blocks = blocks_dict(cls)
template = JINJA_ENV.get_template(JINJA_TEMPLATE_NAME)
iterator = template.generate(fields=fields, blocks=blocks, data=unstructure(self.data)) # type: ignore
# are these printoptions always applicable?
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
# TODO don't hardcode the filename, maybe a filename attribute?
with open(self.path / self.name, "w") as f: # type: ignore
f.writelines(iterator)


# TODO: where to do this? probably not here..on plugin discovery?
DEFAULT_REGISTRY.register_loader(Component, "ascii", _load_ascii)
DEFAULT_REGISTRY.register_writer(Component, "ascii", _write_ascii)
20 changes: 18 additions & 2 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,31 @@
from xattree import xattree

from flopy4.mf6.spec import fields_dict
from flopy4.uio import IO, Loader, Writer

COMPONENTS = {}
"""MF6 component registry."""


@xattree
class Component(ABC, MutableMapping):
"""
Base class for MF6 components.

Notes
-----
All subclasses of `Component` must be decorated with `xattree`.

We use the `children` attribute provided by `xattree`. We know
children are also `Component`s, but mypy does not. TODO: fix??
"""

_load = IO(Loader) # type: ignore
_write = IO(Writer) # type: ignore

@classmethod
def __attrs_init_subclass__(cls):
# add class to the component registry
COMPONENTS[cls.__name__.lower()] = cls

def __attrs_post_init__(self):
Expand Down Expand Up @@ -70,11 +86,11 @@ def _to_dfn_spec(attribute: Attribute) -> Var:
)

def load(self) -> None:
# TODO: load
self._load(format=format)
for child in self.children.values(): # type: ignore
child.load()

def write(self) -> None:
# TODO: write
self._write(format=format)
for child in self.children.values(): # type: ignore
child.write()
50 changes: 50 additions & 0 deletions flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any

import numpy as np
import xarray as xr
from attrs import Attribute
from jinja2 import pass_context
from numpy.typing import NDArray


def field_kind(field: Attribute) -> str:
"""
Get a field's `xattree` kind. Kind is either:

- 'child' for child fields
- 'array' for array fields
- 'coord' for coordinate array fields
- 'dim' for integer fields describing a dimension's size
- 'attr' for all other fields
"""
if (meta := field.metadata) is None:
raise TypeError(f"Field {field.name} has no metadata")
if (xatmeta := meta.get("xattree", None)) is None:
raise TypeError(f"Field {field.name} has no xattree metadata")
if "kind" not in xatmeta:
raise TypeError(f"Field {field.name} has no kind")
return xatmeta.get("kind") or "attr"


@pass_context
def fieldvalue(ctx, field: Attribute):
"""Get a field's value from the data tree via the template context."""
return ctx["data"].attrs.get(field.name) or ctx["data"].get(field.name)


def arraydelayed(value: xr.DataArray):
"""Yield chunks (lines) from a Dask array."""
# TODO: Determine a good chunk size,
# because if the underlying array is only numpy, it will stay one block.
for chunk in value.chunk():
yield chunk.compute()


def array2string(value: NDArray) -> str:
"""Convert an array to a string."""
return np.array2string(value, separator=" ")[1:-1] # remove brackets


def is_dict(value: Any) -> bool:
"""Check if the value is a dictionary."""
return isinstance(value, dict)
1 change: 0 additions & 1 deletion flopy4/mf6/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def blocks_dict(cls) -> dict[str, Block]:
(field) name to `attrs.Attribute`.
"""
fields = fields_dict(cls)
fields = {k: v for k, v in fields.items() if "block" in v.metadata}
blocks: dict[str, Block] = {}
for k, v in fields.items():
block = v.metadata["block"]
Expand Down
8 changes: 8 additions & 0 deletions flopy4/mf6/templates/blocks.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{% import 'macros.jinja' as macros with context %}
{% for block_name, block_ in blocks.items() %}
BEGIN {{ block_name }}
{% for field in block_.values() %}
{{ macros.field(field) }}
{% endfor %}
END {{ block_name }}
{% endfor %}
64 changes: 64 additions & 0 deletions flopy4/mf6/templates/macros.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{% macro field(field) %}
{% set kind = field|field_kind %}
{% if kind == 'attr' %}
{% if field.type|is_dict %}
{{ record(field) }}
{% else %}
{{ scalar(field) }}
{% endif %}
{% elif kind in ['array', 'coord'] %}
{{ array(field) }}
{% elif kind == 'dim' %}
{{ scalar(field) }}
{% elif kind == 'child' %}
{# TODO #}
{% endif %}
{%- endmacro %}

{% macro scalar(field) %}
{% set value = field|fieldvalue %}
{% if value is not none %}
{{ field.name }} {{ value }}
{% endif %}
{%- endmacro %}

{% macro keystring(field) %} {# union #}
{% for item in (field|fieldvalue).items() %}
{{ field(item) }}
{% endfor %}
{%- endmacro %}

{% macro record(field) %}
{% for item in field|fieldvalue %}
{% if item.tagged %}{{ item.name }} {% endif %}{{ field(field) }}
{% endfor %}
{%- endmacro %}

{% macro array(field, how="internal") %}
{% if how == "layered constant" %}
{{ field.name }} LAYERED
{% for val in field|fieldvalue %}
CONSTANT
{% endfor %}
{% elif how == "constant" %}
{{ field.name }} CONSTANT {{ field|fieldvalue }}
{% elif how == "layered" %}
{% if layered %}
{{ field.name }}{% for val in field|fieldvalue %} {{ val }}{% endfor %}
{% endif %}
{% elif how == "internal" %}
{{ field.name }} {{ internal_array(field) }}
{% elif how == "external" %}
{{ field.name}} OPEN/CLOSE {{ field|fieldvalue }}
{% endif %}
{%- endmacro %}

{% macro internal_array(field) %}
{% for chunk in field|fieldvalue|arraydelayed %}
{{ chunk|array2string }}
{% endfor %}
{%- endmacro %}

{% macro list(field) %}
{# TODO #}
{%- endmacro %}
2 changes: 1 addition & 1 deletion flopy4/mf6/utils/cbc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from attrs import define
from flopy.discretization import StructuredGrid

from flopy4.structured_grid import StructuredGridWrapper
from flopy4.discretization.structured_grid import StructuredGridWrapper

from .grid_utils import get_coords

Expand Down
119 changes: 119 additions & 0 deletions flopy4/uio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Unified IO framework. Program interfaces can plug in custom
load/write routines for pairs of component class and format.

Most of this module is stolen/simplified from astropy, at:
- https://github.com/astropy/astropy/tree/main/astropy/io.
"""

__all__ = ["IO", "Loader", "Writer", "DEFAULT_REGISTRY"]


from typing import Literal


class Registry:
"""
Registry for IO operations. Loaders and writers
are registered by format and the class they are
associated with.
"""

def __init__(self):
self._loaders = {}
self._writers = {}

def get_loader(self, cls, format=None):
return next(
iter(
[
fn
for ((fmt, cls_), fn) in self._loaders.items()
if fmt == format and issubclass(cls, cls_)
]
)
)

def get_writer(self, cls, format=None):
return next(
iter(
[
fn
for ((fmt, cls_), fn) in self._writers.items()
if fmt == format and issubclass(cls, cls_)
]
)
)

def register_loader(self, cls, format, function):
if format in self._loaders:
raise ValueError(f"Loader for format {format} already registered.")
self._loaders[cls, format] = (cls, function)

def register_writer(self, cls, format, function):
if format in self._writers:
raise ValueError(f"Writer for format {format} already registered.")
self._writers[cls, format] = (cls, function)

def load(self, cls, *args, format=None, **kwargs):
return self.get_loader(cls, format)(*args, **kwargs)

def write(self, cls, *args, format=None, **kwargs):
return self.get_writer(cls, format)(*args, **kwargs)


DEFAULT_REGISTRY = Registry()


Op = Literal["load", "write"]


class IO(property):
"""Wrap a file IO descriptor as a property."""

def __get__(self, instance, owner_cls):
return self.fget(instance, owner_cls)


class IODescriptor:
"""Base class for file IO operations, implemented as descriptors."""

def __init__(self, instance, cls, op: Op, registry: Registry | None = None):
self._registry = registry or DEFAULT_REGISTRY
self._instance = instance
self._cls = cls
self._op: Op = op

@property
def registry(self):
return self._registry

def list_formats(self, out=None):
formats = self._registry.get_formats(self._cls, self._op)

if out is None:
formats.pprint(max_lines=-1, max_width=-1)
else:
out.write("\n".join(formats.pformat(max_lines=-1, max_width=-1)))

return out


class Loader(IODescriptor):
"""Descriptor for loading data from file."""

def __init__(self, instance, cls):
super().__init__(instance, cls, "load", registry=None)

def __call__(self, *args, **kwargs) -> None:
return self.registry.load(self._cls, *args, **kwargs)


class Writer(IODescriptor):
"""Descriptor for writing data to file."""

def __init__(self, instance, cls):
super().__init__(instance, cls, "write", registry=None)

def __call__(self, *args, **kwargs) -> None:
return self.registry.write(self._cls, *args, **kwargs)
Loading
Loading