Skip to content

Commit 3db8f7a

Browse files
committed
closer
1 parent dc3ffbd commit 3db8f7a

File tree

8 files changed

+607
-415
lines changed

8 files changed

+607
-415
lines changed

flopy4/mf6/component.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from collections.abc import MutableMapping
33

44
from xattree import xattree
5-
from flopy4.io import Writer
5+
6+
from flopy4.mf6.io import Writer
67

78
COMPONENTS = {}
89
"""MF6 component registry."""

flopy4/mf6/filters.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,52 @@
1-
from collections.abc import Iterable, Mapping
2-
from inspect import isclass
31
import types
42
from typing import Union, get_args, get_origin
3+
54
import numpy as np
65
import xarray as xr
6+
import xattree
77
from jinja2 import pass_context
8-
9-
from attrs import Attribute
8+
from numpy.typing import NDArray
109
from xattree import Xattribute
1110

1211

12+
def fieldkind(field: Xattribute) -> str:
13+
"""Get the kind of a field."""
14+
if isinstance(field, (xattree.Array, xattree.Coord)):
15+
return "array"
16+
if isinstance(field, xattree.Dim):
17+
return "scalar"
18+
if isinstance(field, xattree.Child):
19+
raise TypeError(f"Child field {field.name} unsupported in this context")
20+
type_ = field.type
21+
if type_ is None:
22+
raise TypeError(f"Field {field.name} has no type")
23+
if issubclass(type_, xattree.Scalar):
24+
return "scalar"
25+
args = get_args(type_)
26+
origin = get_origin(type_)
27+
if origin in (Union, types.UnionType):
28+
if args[-1] is types.NoneType: # Optional
29+
type_ = args[0]
30+
assert type_ is not None
31+
else:
32+
return "union"
33+
if issubclass(type_, xattree.Scalar):
34+
return "scalar"
35+
if issubclass(type_, xattree.Attr):
36+
return "record"
37+
raise TypeError(f"Unsupported field type {type_} for field {field.name}")
38+
39+
1340
@pass_context
14-
def value(ctx, field: Xattribute):
15-
"""Return the kind of the field."""
16-
# TODO
17-
pass
41+
def fieldvalue(ctx, field: Xattribute):
42+
return ctx["data"][field.name]
1843

1944

20-
def dask_expand(data: xr.DataArray):
21-
for block in data.data.to_delayed():
45+
def arraydelayed(value: xr.DataArray):
46+
for block in value.data.to_delayed():
2247
block_data = block.compute()
2348
yield block_data
2449

2550

26-
def nparray2string(data: np.ndarray):
27-
return np.array2string(data, separator=" ")[1:-1] # remove brackets
51+
def array2string(value: NDArray) -> str:
52+
return np.array2string(value, separator=" ")[1:-1] # remove brackets

flopy4/mf6/io.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,44 @@
11
import sys
2+
23
import numpy as np
34
from jinja2 import Environment, PackageLoader
4-
from flopy4.mf6 import filters
5-
from flopy4.mf6.spec import blocks_dict
65

6+
from flopy4.mf6 import filters
7+
from flopy4.mf6.spec import blocks_dict, fields_dict
78

89
env = Environment(
9-
loader=PackageLoader("flopy4.mf6"),
10+
loader=PackageLoader("flopy4.mf6"),
1011
trim_blocks=True,
1112
lstrip_blocks=True,
1213
)
13-
env.filters["dask_expand"] = filters.dask_expand
14-
env.filters["nparray2string"] = filters.nparray2string
14+
env.filters["fieldkind"] = filters.fieldkind
15+
env.filters["fieldvalue"] = filters.fieldvalue
16+
env.filters["arraydelayed"] = filters.arraydelayed
17+
env.filters["array2string"] = filters.array2string
1518

1619

1720
class Writer:
18-
def _write_ascii(self, path) -> None:
19-
# TODO: factor out an ascii writer separately
20-
21-
block_spec = blocks_dict(type(self))
22-
blocks = {}
23-
for block_name, block in block_spec.items():
24-
blocks[block_name] = {}
25-
for field_name, field in block.items():
26-
if field_name == "data":
27-
continue
28-
blocks[block_name][field_name] = {
29-
"spec": field,
30-
"value": getattr(self, field_name),
31-
}
21+
# TODO remove type: ignore statements below.
22+
# but idk how to properly type a mixin class.
23+
# this one assumes the presence of attributes:
24+
# - name
25+
# - path
26+
# - data
3227

28+
def _write_ascii(self) -> None:
29+
cls = type(self)
30+
fields = fields_dict(cls)
31+
blocks = blocks_dict(cls)
3332
template = env.get_template("blocks.jinja")
34-
iterator = template.generate(blocks=blocks)
35-
with np.printoptions(
36-
precision=4, linewidth=sys.maxsize, threshold=sys.maxsize
37-
):
38-
with open(path, "w") as f:
33+
iterator = template.generate(fields=fields, blocks=blocks, data=self.data) # type: ignore
34+
# are these printoptions always applicable?
35+
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
36+
# TODO don't hardcode the filename, maybe a filename attribute?
37+
with open(self.path / self.name, "w") as f: # type: ignore
3938
f.writelines(iterator)
4039

4140
def write(self) -> None:
41+
# TODO: factor out an ascii writer separately
4242
self._write_ascii()
4343
for child in self.children.values(): # type: ignore
4444
child.write()

flopy4/mf6/spec.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
These include field decorators and introspection functions.
44
"""
55

6-
from attrs import NOTHING, Attribute, fields_dict
6+
from attrs import NOTHING, Attribute
7+
from xattree import fields_dict as xattree_fields_dict
78

89
from flopy4.spec import array as flopy_array
910
from flopy4.spec import coord as flopy_coord
@@ -140,11 +141,24 @@ def blocks_dict(cls) -> dict[str, Block]:
140141
(field) name to `attrs.Attribute`.
141142
"""
142143
fields = fields_dict(cls)
143-
fields = {k: v for k, v in fields.items() if "block" in v.metadata}
144144
blocks: dict[str, Block] = {}
145145
for k, v in fields.items():
146146
block = v.metadata["block"]
147147
if block not in blocks:
148148
blocks[block] = {}
149149
blocks[block][k] = v
150150
return dict(sorted(blocks.items(), key=_block_sort_key))
151+
152+
153+
def fields(cls) -> list[Attribute]:
154+
"""Return an ordered list of fields for a component class."""
155+
return list(fields_dict(cls).values())
156+
157+
158+
def fields_dict(cls) -> dict[str, Attribute]:
159+
"""
160+
Return an ordered dictionary of fields for a component class,
161+
whose keys are field names. Each field is an `attrs.Attribute`.
162+
"""
163+
fields = xattree_fields_dict(cls)
164+
return {k: v for k, v in fields.items() if "block" in v.metadata}

flopy4/mf6/templates/blocks.jinja

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
{% for block in blocks %}
2-
BEGIN {{ block.name }} {% if block.index is not none %}{{ block.index }}{% endif %}
3-
{% for field in block %}
1+
{% for block_name, block_ in blocks.items() %}
2+
BEGIN {{ block_name }}
3+
{% for field in block_.values() %}
44
{{ macros.field(field) }}
55
{% endfor %}
6-
END {{ block.name }}
6+
END {{ block_name }}
77

88
{% endfor %}

flopy4/mf6/templates/macros.jinja

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{% macro field(field) %}
2-
{% if field|kind == 'scalar' %}
2+
{% if field|fieldkind == 'scalar' %}
33
{{ macros.scalar(field) }}
4-
{% elif field|kind == 'union' %}
4+
{% elif field|fieldkind == 'union' %}
55
{{ macros.union(field) }}
6-
{% elif field|kind == 'record' %}
6+
{% elif field|fieldkind == 'record' %}
77
{{ macros.record(field) }}
8-
{% elif field|kind == 'array' %}
8+
{% elif field|fieldkind == 'array' %}
99
{{ macros.array(field) }}
10-
{% elif field|kind == 'list' %}
10+
{% elif field|fieldkind == 'list' %}
1111
{{ macros.list(field) }}
1212
{% endif %}
1313
{% endmacro %}
@@ -48,8 +48,8 @@ CONSTANT
4848
{% endmacro %}
4949

5050
{% macro internal_array(field) %}
51-
{% for chunk in field.value|dask_expand %}
52-
{{ chunk|nparray2string }}
51+
{% for chunk in field|value|arraydelayed %}
52+
{{ chunk|array2string }}
5353
{% endfor %}
5454
{% endmacro %}
5555

0 commit comments

Comments
 (0)