Skip to content

Commit b1abe67

Browse files
committed
steal io registry stuff from astropy
1 parent 60afd78 commit b1abe67

File tree

2 files changed

+124
-27
lines changed

2 files changed

+124
-27
lines changed

flopy4/mf6/component.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33

44
from xattree import xattree
55

6-
from flopy4.mf6.io import Writer
6+
from flopy4.mf6.io import ComponentReader, ComponentWriter, IOMethod
77

88
COMPONENTS = {}
99
"""MF6 component registry."""
1010

1111

1212
@xattree
13-
class Component(ABC, MutableMapping, Writer):
13+
class Component(ABC, MutableMapping):
14+
"""
15+
Base class for MF6 components.
16+
17+
We use the `children` attribute provided by `xattree`. We know
18+
children are also `Component`s, but mypy does not. How to fix?
19+
"""
20+
1421
@classmethod
1522
def __attrs_init_subclass__(cls):
1623
COMPONENTS[cls.__name__.lower()] = cls
@@ -29,3 +36,16 @@ def __iter__(self):
2936

3037
def __len__(self):
3138
return len(self.children) # type: ignore
39+
40+
_read = IOMethod(ComponentReader) # type: ignore
41+
_write = IOMethod(ComponentWriter) # type: ignore
42+
43+
def read(self, format=None) -> None:
44+
self._read(format=format)
45+
for child in self.children.values(): # type: ignore
46+
child.read(format=format)
47+
48+
def write(self, format=None) -> None:
49+
self._write(format=format)
50+
for child in self.children.values(): # type: ignore
51+
child.write(format=format)

flopy4/mf6/io.py

Lines changed: 102 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,97 @@
77
from flopy4.mf6 import filters
88
from flopy4.mf6.spec import blocks_dict, fields_dict
99

10+
# below stolen/simplified from https://github.com/astropy/astropy/tree/main/astropy/io
11+
# just sketching things.. haven't plugged it all together yet
12+
13+
14+
class Registry:
15+
def __init__(self):
16+
self._readers = {}
17+
self._writers = {}
18+
19+
def get_reader(self, cls, format=None):
20+
return next(
21+
iter(
22+
[
23+
fn
24+
for ((fmt, cls_), fn) in self._readers.items()
25+
if fmt == format and issubclass(cls, cls_)
26+
]
27+
)
28+
)
29+
30+
def get_writer(self, cls, format=None):
31+
return next(
32+
iter(
33+
[
34+
fn
35+
for ((fmt, cls_), fn) in self._writers.items()
36+
if fmt == format and issubclass(cls, cls_)
37+
]
38+
)
39+
)
40+
41+
def register_reader(self, cls, format, function):
42+
if format in self._readers:
43+
raise ValueError(f"Reader for format {format} already registered.")
44+
self._readers[cls, format] = (cls, function)
45+
46+
def register_writer(self, cls, format, function):
47+
if format in self._writers:
48+
raise ValueError(f"Writer for format {format} already registered.")
49+
self._writers[cls, format] = (cls, function)
50+
51+
def read(self, cls, *args, format=None, **kwargs):
52+
return self.get_reader(cls, format)(*args, **kwargs)
53+
54+
def write(self, cls, *args, format=None, **kwargs):
55+
return self.get_writer(cls, format)(*args, **kwargs)
56+
57+
58+
class IO:
59+
def __init__(self, instance, cls, method_name, registry=None):
60+
self._registry = registry
61+
self._instance = instance
62+
self._cls = cls
63+
self._method_name = method_name # 'read' or 'write'
64+
65+
@property
66+
def registry(self):
67+
return self._registry
68+
69+
def list_formats(self, out=None):
70+
formats = self._registry.get_formats(self._cls, self._method_name)
71+
72+
if out is None:
73+
formats.pprint(max_lines=-1, max_width=-1)
74+
else:
75+
out.write("\n".join(formats.pformat(max_lines=-1, max_width=-1)))
76+
77+
return out
78+
79+
80+
class IOMethod(property):
81+
def __get__(self, instance, owner_cls):
82+
return self.fget(instance, owner_cls)
83+
84+
85+
class ComponentReader(IO):
86+
def __init__(self, instance, cls):
87+
super().__init__(instance, cls, "read", registry=None)
88+
89+
def __call__(self, *args, **kwargs) -> None:
90+
return self.registry.read(self._cls, *args, **kwargs)
91+
92+
93+
class ComponentWriter(IO):
94+
def __init__(self, instance, cls):
95+
super().__init__(instance, cls, "write", registry=None)
96+
97+
def __call__(self, *args, **kwargs) -> None:
98+
return self.registry.write(self._cls, *args, **kwargs)
99+
100+
10101
env = Environment(
11102
loader=PackageLoader("flopy4.mf6"),
12103
trim_blocks=True,
@@ -18,28 +109,14 @@
18109
env.filters["array2string"] = filters.array2string
19110

20111

21-
class Writer:
22-
# TODO remove type: ignore statements below.
23-
# but idk how to properly type a mixin class.
24-
# this one assumes the presence of attributes:
25-
# - name
26-
# - path
27-
# - data
28-
29-
def _write_ascii(self) -> None:
30-
cls = type(self)
31-
fields = fields_dict(cls)
32-
blocks = blocks_dict(cls)
33-
template = env.get_template("blocks.jinja")
34-
iterator = template.generate(fields=fields, blocks=blocks, data=unstructure(self.data)) # type: ignore
35-
# are these printoptions always applicable?
36-
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
37-
# TODO don't hardcode the filename, maybe a filename attribute?
38-
with open(self.path / self.name, "w") as f: # type: ignore
39-
f.writelines(iterator)
40-
41-
def write(self) -> None:
42-
# TODO: factor out an ascii writer separately
43-
self._write_ascii()
44-
for child in self.children.values(): # type: ignore
45-
child.write()
112+
def _write_ascii(self) -> None:
113+
cls = type(self)
114+
fields = fields_dict(cls)
115+
blocks = blocks_dict(cls)
116+
template = env.get_template("blocks.jinja")
117+
iterator = template.generate(fields=fields, blocks=blocks, data=unstructure(self.data)) # type: ignore
118+
# are these printoptions always applicable?
119+
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
120+
# TODO don't hardcode the filename, maybe a filename attribute?
121+
with open(self.path / self.name, "w") as f: # type: ignore
122+
f.writelines(iterator)

0 commit comments

Comments
 (0)