Skip to content

Commit 8b9b8e4

Browse files
committed
closer
1 parent b1abe67 commit 8b9b8e4

File tree

12 files changed

+444
-548
lines changed

12 files changed

+444
-548
lines changed
File renamed without changes.

flopy4/io/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""
2+
Core IO framework. Program interfaces can plug in bespoke
3+
load/write methods for particular components and formats.
4+
5+
Most of this module is stolen/simplified from astropy, at:
6+
- https://github.com/astropy/astropy/tree/main/astropy/io.
7+
"""
8+
9+
from flopy4.io.framework import IOMethod, Loader, Writer
10+
from flopy4.io.registry import DEFAULT_REGISTRY
11+
12+
__all__ = ["IOMethod", "Loader", "Writer", "DEFAULT_REGISTRY"]

flopy4/io/framework.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import Literal
2+
3+
from flopy4.io.registry import Registry
4+
5+
IOMethodName = Literal["load", "write"]
6+
7+
8+
class IO:
9+
def __init__(self, instance, cls, method_name: IOMethodName, registry: Registry = None):
10+
self._registry = registry
11+
self._instance = instance
12+
self._cls = cls
13+
self._method_name: IOMethodName = method_name
14+
15+
@property
16+
def registry(self):
17+
return self._registry
18+
19+
def list_formats(self, out=None):
20+
formats = self._registry.get_formats(self._cls, self._method_name)
21+
22+
if out is None:
23+
formats.pprint(max_lines=-1, max_width=-1)
24+
else:
25+
out.write("\n".join(formats.pformat(max_lines=-1, max_width=-1)))
26+
27+
return out
28+
29+
30+
class IOMethod(property):
31+
def __get__(self, instance, owner_cls):
32+
return self.fget(instance, owner_cls)
33+
34+
35+
class Loader(IO):
36+
def __init__(self, instance, cls):
37+
super().__init__(instance, cls, "load", registry=None)
38+
39+
def __call__(self, *args, **kwargs) -> None:
40+
return self.registry.load(self._cls, *args, **kwargs)
41+
42+
43+
class Writer(IO):
44+
def __init__(self, instance, cls):
45+
super().__init__(instance, cls, "write", registry=None)
46+
47+
def __call__(self, *args, **kwargs) -> None:
48+
return self.registry.write(self._cls, *args, **kwargs)

flopy4/io/registry.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
class Registry:
2+
def __init__(self):
3+
self._loaders = {}
4+
self._writers = {}
5+
6+
def get_loader(self, cls, format=None):
7+
return next(
8+
iter(
9+
[
10+
fn
11+
for ((fmt, cls_), fn) in self._loaders.items()
12+
if fmt == format and issubclass(cls, cls_)
13+
]
14+
)
15+
)
16+
17+
def get_writer(self, cls, format=None):
18+
return next(
19+
iter(
20+
[
21+
fn
22+
for ((fmt, cls_), fn) in self._writers.items()
23+
if fmt == format and issubclass(cls, cls_)
24+
]
25+
)
26+
)
27+
28+
def register_loader(self, cls, format, function):
29+
if format in self._loaders:
30+
raise ValueError(f"Reader for format {format} already registered.")
31+
self._loaders[cls, format] = (cls, function)
32+
33+
def register_writer(self, cls, format, function):
34+
if format in self._writers:
35+
raise ValueError(f"Writer for format {format} already registered.")
36+
self._writers[cls, format] = (cls, function)
37+
38+
def load(self, cls, *args, format=None, **kwargs):
39+
return self.get_loader(cls, format)(*args, **kwargs)
40+
41+
def write(self, cls, *args, format=None, **kwargs):
42+
return self.get_writer(cls, format)(*args, **kwargs)
43+
44+
45+
DEFAULT_REGISTRY = Registry()

flopy4/mf6/codec.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import sys
2+
3+
import numpy as np
4+
from jinja2 import Environment, PackageLoader
5+
6+
from flopy4.io import DEFAULT_REGISTRY
7+
from flopy4.mf6 import filters
8+
from flopy4.mf6.component import Component
9+
from flopy4.mf6.spec import blocks_dict, fields_dict
10+
11+
JINJA_ENV = Environment(
12+
loader=PackageLoader("flopy4.mf6"),
13+
trim_blocks=True,
14+
lstrip_blocks=True,
15+
)
16+
JINJA_ENV.filters["fieldkind"] = filters.fieldkind
17+
JINJA_ENV.filters["fieldvalue"] = filters.fieldvalue
18+
JINJA_ENV.filters["arraydelayed"] = filters.arraydelayed
19+
JINJA_ENV.filters["array2string"] = filters.array2string
20+
JINJA_TEMPLATE_NAME = "blocks.jinja"
21+
22+
23+
def _load_ascii(self) -> None:
24+
# TODO
25+
pass
26+
27+
28+
def _write_ascii(self) -> None:
29+
cls = type(self)
30+
fields = fields_dict(cls)
31+
blocks = blocks_dict(cls)
32+
template = JINJA_ENV.get_template(JINJA_TEMPLATE_NAME)
33+
iterator = template.generate(fields=fields, blocks=blocks, data=unstructure(self.data)) # type: ignore
34+
# are these printoptions always applicable?
35+
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
36+
# TODO don't hardcode the filename, maybe a filename attribute?
37+
with open(self.path / self.name, "w") as f: # type: ignore
38+
f.writelines(iterator)
39+
40+
41+
# TODO: where to do this? probably not here..on plugin discovery?
42+
DEFAULT_REGISTRY.register_loader(Component, "ascii", _load_ascii)
43+
DEFAULT_REGISTRY.register_writer(Component, "ascii", _write_ascii)

flopy4/mf6/component.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from xattree import xattree
55

6-
from flopy4.mf6.io import ComponentReader, ComponentWriter, IOMethod
6+
from flopy4.io import IOMethod, Loader, Writer
77

88
COMPONENTS = {}
99
"""MF6 component registry."""
@@ -14,12 +14,17 @@ class Component(ABC, MutableMapping):
1414
"""
1515
Base class for MF6 components.
1616
17+
Notes
18+
-----
19+
All subclasses of `Component` must be decorated with `xattree`.
20+
1721
We use the `children` attribute provided by `xattree`. We know
18-
children are also `Component`s, but mypy does not. How to fix?
22+
children are also `Component`s, but mypy does not. TODO: fix??
1923
"""
2024

2125
@classmethod
2226
def __attrs_init_subclass__(cls):
27+
# add class to the component registry
2328
COMPONENTS[cls.__name__.lower()] = cls
2429

2530
def __getitem__(self, key):
@@ -37,10 +42,10 @@ def __iter__(self):
3742
def __len__(self):
3843
return len(self.children) # type: ignore
3944

40-
_read = IOMethod(ComponentReader) # type: ignore
41-
_write = IOMethod(ComponentWriter) # type: ignore
45+
_load = IOMethod(Loader) # type: ignore
46+
_write = IOMethod(Writer) # type: ignore
4247

43-
def read(self, format=None) -> None:
48+
def load(self, format=None) -> None:
4449
self._read(format=format)
4550
for child in self.children.values(): # type: ignore
4651
child.read(format=format)

flopy4/mf6/filters.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,41 @@
1-
import types
2-
from typing import Union, get_args, get_origin
3-
41
import numpy as np
52
import xarray as xr
6-
import xattree
3+
from attrs import Attribute
74
from jinja2 import pass_context
85
from numpy.typing import NDArray
9-
from xattree import Xattribute
10-
11-
12-
def fieldkind(field: Xattribute) -> str:
13-
"""Get the kind of a field."""
14-
if isinstance(field, (xattree.Array, xattree.Coord)):
15-
return "array"
16-
if isinstance(field, xattree.Dim):
17-
return "scalar"
18-
if isinstance(field, xattree.Child):
19-
raise TypeError(f"Child field {field.name} unsupported in this context")
20-
type_ = field.type
21-
if type_ is None:
22-
raise TypeError(f"Field {field.name} has no type")
23-
if issubclass(type_, xattree.Scalar):
24-
return "scalar"
25-
args = get_args(type_)
26-
origin = get_origin(type_)
27-
if origin in (Union, types.UnionType):
28-
if args[-1] is types.NoneType: # Optional
29-
type_ = args[0]
30-
assert type_ is not None
31-
else:
32-
return "union"
33-
if issubclass(type_, xattree.Scalar):
34-
return "scalar"
35-
if issubclass(type_, xattree.Attr):
36-
return "record"
37-
raise TypeError(f"Unsupported field type {type_} for field {field.name}")
6+
7+
8+
def fieldkind(field: Attribute) -> str:
9+
"""
10+
Get a field's `xattree` kind. Kind is either:
11+
12+
- 'child' for child fields
13+
- 'array' for array fields
14+
- 'coord' for coordinate array fields
15+
- 'dim' for integer fields describing a dimension's size
16+
- 'attr' for all other fields
17+
"""
18+
if meta := field.metadata is None:
19+
raise TypeError(f"Field {field.name} has no metadata")
20+
if xatmeta := meta.get("xattree", None) is None:
21+
raise TypeError(f"Field {field.name} has no xattree metadata")
22+
if kind := xatmeta.get("kind", None) is None:
23+
raise TypeError(f"Field {field.name} has no kind")
24+
return kind
3825

3926

4027
@pass_context
41-
def fieldvalue(ctx, field: Xattribute):
28+
def fieldvalue(ctx, field: Attribute):
29+
"""Get a field's value from the data tree via the template context."""
4230
return ctx["data"][field.name]
4331

4432

4533
def arraydelayed(value: xr.DataArray):
46-
for block in value.data.to_delayed():
47-
block_data = block.compute()
48-
yield block_data
34+
"""Yield chunks (lines) from a Dask array."""
35+
for chunk in value.data.to_delayed():
36+
yield chunk.compute()
4937

5038

5139
def array2string(value: NDArray) -> str:
40+
"""Convert an array to a string."""
5241
return np.array2string(value, separator=" ")[1:-1] # remove brackets

flopy4/mf6/io.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

0 commit comments

Comments
 (0)