Skip to content

Commit 30a7fb4

Browse files
authored
input file writing (#136)
first hack at an input file writer. may later move to devtools to live with the input file reader. just getting the basics down. also draft a unified io framework supporting arbitrary formats mappable to component types.
1 parent 2c01529 commit 30a7fb4

File tree

14 files changed

+479
-109
lines changed

14 files changed

+479
-109
lines changed
File renamed without changes.

flopy4/mf6/codec.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import sys
2+
3+
import numpy as np
4+
from jinja2 import Environment, PackageLoader
5+
6+
from flopy4.mf6 import filters
7+
from flopy4.mf6.component import Component
8+
from flopy4.mf6.spec import blocks_dict, fields_dict
9+
from flopy4.uio import DEFAULT_REGISTRY
10+
11+
JINJA_ENV = Environment(
12+
loader=PackageLoader("flopy4.mf6"),
13+
trim_blocks=True,
14+
lstrip_blocks=True,
15+
)
16+
JINJA_ENV.filters["field_kind"] = filters.field_kind
17+
JINJA_ENV.filters["fieldvalue"] = filters.fieldvalue
18+
JINJA_ENV.filters["arraydelayed"] = filters.arraydelayed
19+
JINJA_ENV.filters["array2string"] = filters.array2string
20+
JINJA_ENV.filters["is_dict"] = filters.is_dict
21+
JINJA_TEMPLATE_NAME = "blocks.jinja"
22+
23+
24+
def _load_ascii(self) -> None:
25+
# TODO
26+
pass
27+
28+
29+
def _write_ascii(self) -> None:
30+
cls = type(self)
31+
fields = fields_dict(cls)
32+
blocks = blocks_dict(cls)
33+
template = JINJA_ENV.get_template(JINJA_TEMPLATE_NAME)
34+
iterator = template.generate(fields=fields, blocks=blocks, data=unstructure(self.data)) # type: ignore
35+
# are these printoptions always applicable?
36+
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
37+
# TODO don't hardcode the filename, maybe a filename attribute?
38+
with open(self.path / self.name, "w") as f: # type: ignore
39+
f.writelines(iterator)
40+
41+
42+
# TODO: where to do this? probably not here..on plugin discovery?
43+
DEFAULT_REGISTRY.register_loader(Component, "ascii", _load_ascii)
44+
DEFAULT_REGISTRY.register_writer(Component, "ascii", _write_ascii)

flopy4/mf6/component.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,31 @@
66
from xattree import xattree
77

88
from flopy4.mf6.spec import fields_dict
9+
from flopy4.uio import IO, Loader, Writer
910

1011
COMPONENTS = {}
1112
"""MF6 component registry."""
1213

1314

1415
@xattree
1516
class Component(ABC, MutableMapping):
17+
"""
18+
Base class for MF6 components.
19+
20+
Notes
21+
-----
22+
All subclasses of `Component` must be decorated with `xattree`.
23+
24+
We use the `children` attribute provided by `xattree`. We know
25+
children are also `Component`s, but mypy does not. TODO: fix??
26+
"""
27+
28+
_load = IO(Loader) # type: ignore
29+
_write = IO(Writer) # type: ignore
30+
1631
@classmethod
1732
def __attrs_init_subclass__(cls):
33+
# add class to the component registry
1834
COMPONENTS[cls.__name__.lower()] = cls
1935

2036
def __attrs_post_init__(self):
@@ -70,11 +86,11 @@ def _to_dfn_spec(attribute: Attribute) -> Var:
7086
)
7187

7288
def load(self) -> None:
73-
# TODO: load
89+
self._load(format=format)
7490
for child in self.children.values(): # type: ignore
7591
child.load()
7692

7793
def write(self) -> None:
78-
# TODO: write
94+
self._write(format=format)
7995
for child in self.children.values(): # type: ignore
8096
child.write()

flopy4/mf6/filters.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
import xarray as xr
5+
from attrs import Attribute
6+
from jinja2 import pass_context
7+
from numpy.typing import NDArray
8+
9+
10+
def field_kind(field: Attribute) -> str:
11+
"""
12+
Get a field's `xattree` kind. Kind is either:
13+
14+
- 'child' for child fields
15+
- 'array' for array fields
16+
- 'coord' for coordinate array fields
17+
- 'dim' for integer fields describing a dimension's size
18+
- 'attr' for all other fields
19+
"""
20+
if (meta := field.metadata) is None:
21+
raise TypeError(f"Field {field.name} has no metadata")
22+
if (xatmeta := meta.get("xattree", None)) is None:
23+
raise TypeError(f"Field {field.name} has no xattree metadata")
24+
if "kind" not in xatmeta:
25+
raise TypeError(f"Field {field.name} has no kind")
26+
return xatmeta.get("kind") or "attr"
27+
28+
29+
@pass_context
30+
def fieldvalue(ctx, field: Attribute):
31+
"""Get a field's value from the data tree via the template context."""
32+
return ctx["data"].attrs.get(field.name) or ctx["data"].get(field.name)
33+
34+
35+
def arraydelayed(value: xr.DataArray):
36+
"""Yield chunks (lines) from a Dask array."""
37+
# TODO: Determine a good chunk size,
38+
# because if the underlying array is only numpy, it will stay one block.
39+
for chunk in value.chunk():
40+
yield chunk.compute()
41+
42+
43+
def array2string(value: NDArray) -> str:
44+
"""Convert an array to a string."""
45+
return np.array2string(value, separator=" ")[1:-1] # remove brackets
46+
47+
48+
def is_dict(value: Any) -> bool:
49+
"""Check if the value is a dictionary."""
50+
return isinstance(value, dict)

flopy4/mf6/spec.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ def blocks_dict(cls) -> dict[str, Block]:
141141
(field) name to `attrs.Attribute`.
142142
"""
143143
fields = fields_dict(cls)
144-
fields = {k: v for k, v in fields.items() if "block" in v.metadata}
145144
blocks: dict[str, Block] = {}
146145
for k, v in fields.items():
147146
block = v.metadata["block"]

flopy4/mf6/templates/blocks.jinja

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{% import 'macros.jinja' as macros with context %}
2+
{% for block_name, block_ in blocks.items() %}
3+
BEGIN {{ block_name }}
4+
{% for field in block_.values() %}
5+
{{ macros.field(field) }}
6+
{% endfor %}
7+
END {{ block_name }}
8+
{% endfor %}

flopy4/mf6/templates/macros.jinja

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
{% macro field(field) %}
2+
{% set kind = field|field_kind %}
3+
{% if kind == 'attr' %}
4+
{% if field.type|is_dict %}
5+
{{ record(field) }}
6+
{% else %}
7+
{{ scalar(field) }}
8+
{% endif %}
9+
{% elif kind in ['array', 'coord'] %}
10+
{{ array(field) }}
11+
{% elif kind == 'dim' %}
12+
{{ scalar(field) }}
13+
{% elif kind == 'child' %}
14+
{# TODO #}
15+
{% endif %}
16+
{%- endmacro %}
17+
18+
{% macro scalar(field) %}
19+
{% set value = field|fieldvalue %}
20+
{% if value is not none %}
21+
{{ field.name }} {{ value }}
22+
{% endif %}
23+
{%- endmacro %}
24+
25+
{% macro keystring(field) %} {# union #}
26+
{% for item in (field|fieldvalue).items() %}
27+
{{ field(item) }}
28+
{% endfor %}
29+
{%- endmacro %}
30+
31+
{% macro record(field) %}
32+
{% for item in field|fieldvalue %}
33+
{% if item.tagged %}{{ item.name }} {% endif %}{{ field(field) }}
34+
{% endfor %}
35+
{%- endmacro %}
36+
37+
{% macro array(field, how="internal") %}
38+
{% if how == "layered constant" %}
39+
{{ field.name }} LAYERED
40+
{% for val in field|fieldvalue %}
41+
CONSTANT
42+
{% endfor %}
43+
{% elif how == "constant" %}
44+
{{ field.name }} CONSTANT {{ field|fieldvalue }}
45+
{% elif how == "layered" %}
46+
{% if layered %}
47+
{{ field.name }}{% for val in field|fieldvalue %} {{ val }}{% endfor %}
48+
{% endif %}
49+
{% elif how == "internal" %}
50+
{{ field.name }} {{ internal_array(field) }}
51+
{% elif how == "external" %}
52+
{{ field.name}} OPEN/CLOSE {{ field|fieldvalue }}
53+
{% endif %}
54+
{%- endmacro %}
55+
56+
{% macro internal_array(field) %}
57+
{% for chunk in field|fieldvalue|arraydelayed %}
58+
{{ chunk|array2string }}
59+
{% endfor %}
60+
{%- endmacro %}
61+
62+
{% macro list(field) %}
63+
{# TODO #}
64+
{%- endmacro %}

flopy4/mf6/utils/cbc_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from attrs import define
1313
from flopy.discretization import StructuredGrid
1414

15-
from flopy4.structured_grid import StructuredGridWrapper
15+
from flopy4.discretization.structured_grid import StructuredGridWrapper
1616

1717
from .grid_utils import get_coords
1818

flopy4/uio.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Unified IO framework. Program interfaces can plug in custom
3+
load/write routines for pairs of component class and format.
4+
5+
Most of this module is stolen/simplified from astropy, at:
6+
- https://github.com/astropy/astropy/tree/main/astropy/io.
7+
"""
8+
9+
__all__ = ["IO", "Loader", "Writer", "DEFAULT_REGISTRY"]
10+
11+
12+
from typing import Literal
13+
14+
15+
class Registry:
16+
"""
17+
Registry for IO operations. Loaders and writers
18+
are registered by format and the class they are
19+
associated with.
20+
"""
21+
22+
def __init__(self):
23+
self._loaders = {}
24+
self._writers = {}
25+
26+
def get_loader(self, cls, format=None):
27+
return next(
28+
iter(
29+
[
30+
fn
31+
for ((fmt, cls_), fn) in self._loaders.items()
32+
if fmt == format and issubclass(cls, cls_)
33+
]
34+
)
35+
)
36+
37+
def get_writer(self, cls, format=None):
38+
return next(
39+
iter(
40+
[
41+
fn
42+
for ((fmt, cls_), fn) in self._writers.items()
43+
if fmt == format and issubclass(cls, cls_)
44+
]
45+
)
46+
)
47+
48+
def register_loader(self, cls, format, function):
49+
if format in self._loaders:
50+
raise ValueError(f"Loader for format {format} already registered.")
51+
self._loaders[cls, format] = (cls, function)
52+
53+
def register_writer(self, cls, format, function):
54+
if format in self._writers:
55+
raise ValueError(f"Writer for format {format} already registered.")
56+
self._writers[cls, format] = (cls, function)
57+
58+
def load(self, cls, *args, format=None, **kwargs):
59+
return self.get_loader(cls, format)(*args, **kwargs)
60+
61+
def write(self, cls, *args, format=None, **kwargs):
62+
return self.get_writer(cls, format)(*args, **kwargs)
63+
64+
65+
DEFAULT_REGISTRY = Registry()
66+
67+
68+
Op = Literal["load", "write"]
69+
70+
71+
class IO(property):
72+
"""Wrap a file IO descriptor as a property."""
73+
74+
def __get__(self, instance, owner_cls):
75+
return self.fget(instance, owner_cls)
76+
77+
78+
class IODescriptor:
79+
"""Base class for file IO operations, implemented as descriptors."""
80+
81+
def __init__(self, instance, cls, op: Op, registry: Registry | None = None):
82+
self._registry = registry or DEFAULT_REGISTRY
83+
self._instance = instance
84+
self._cls = cls
85+
self._op: Op = op
86+
87+
@property
88+
def registry(self):
89+
return self._registry
90+
91+
def list_formats(self, out=None):
92+
formats = self._registry.get_formats(self._cls, self._op)
93+
94+
if out is None:
95+
formats.pprint(max_lines=-1, max_width=-1)
96+
else:
97+
out.write("\n".join(formats.pformat(max_lines=-1, max_width=-1)))
98+
99+
return out
100+
101+
102+
class Loader(IODescriptor):
103+
"""Descriptor for loading data from file."""
104+
105+
def __init__(self, instance, cls):
106+
super().__init__(instance, cls, "load", registry=None)
107+
108+
def __call__(self, *args, **kwargs) -> None:
109+
return self.registry.load(self._cls, *args, **kwargs)
110+
111+
112+
class Writer(IODescriptor):
113+
"""Descriptor for writing data to file."""
114+
115+
def __init__(self, instance, cls):
116+
super().__init__(instance, cls, "write", registry=None)
117+
118+
def __call__(self, *args, **kwargs) -> None:
119+
return self.registry.write(self._cls, *args, **kwargs)

0 commit comments

Comments
 (0)