Skip to content

Commit 038ca7a

Browse files
committed
closer
1 parent 8d8dd54 commit 038ca7a

File tree

8 files changed

+653
-453
lines changed

8 files changed

+653
-453
lines changed

flopy4/mf6/component.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from collections.abc import MutableMapping
33

44
from xattree import xattree
5-
from flopy4.io import Writer
5+
6+
from flopy4.mf6.io import Writer
67

78
COMPONENTS = {}
89
"""MF6 component registry."""

flopy4/mf6/filters.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,52 @@
1-
from collections.abc import Iterable, Mapping
2-
from inspect import isclass
31
import types
42
from typing import Union, get_args, get_origin
3+
54
import numpy as np
65
import xarray as xr
6+
import xattree
77
from jinja2 import pass_context
8-
9-
from attrs import Attribute
8+
from numpy.typing import NDArray
109
from xattree import Xattribute
1110

1211

12+
def fieldkind(field: Xattribute) -> str:
13+
"""Get the kind of a field."""
14+
if isinstance(field, (xattree.Array, xattree.Coord)):
15+
return "array"
16+
if isinstance(field, xattree.Dim):
17+
return "scalar"
18+
if isinstance(field, xattree.Child):
19+
raise TypeError(f"Child field {field.name} unsupported in this context")
20+
type_ = field.type
21+
if type_ is None:
22+
raise TypeError(f"Field {field.name} has no type")
23+
if issubclass(type_, xattree.Scalar):
24+
return "scalar"
25+
args = get_args(type_)
26+
origin = get_origin(type_)
27+
if origin in (Union, types.UnionType):
28+
if args[-1] is types.NoneType: # Optional
29+
type_ = args[0]
30+
assert type_ is not None
31+
else:
32+
return "union"
33+
if issubclass(type_, xattree.Scalar):
34+
return "scalar"
35+
if issubclass(type_, xattree.Attr):
36+
return "record"
37+
raise TypeError(f"Unsupported field type {type_} for field {field.name}")
38+
39+
1340
@pass_context
14-
def value(ctx, field: Xattribute):
15-
"""Return the kind of the field."""
16-
# TODO
17-
pass
41+
def fieldvalue(ctx, field: Xattribute):
42+
return ctx["data"][field.name]
1843

1944

20-
def dask_expand(data: xr.DataArray):
21-
for block in data.data.to_delayed():
45+
def arraydelayed(value: xr.DataArray):
46+
for block in value.data.to_delayed():
2247
block_data = block.compute()
2348
yield block_data
2449

2550

26-
def nparray2string(data: np.ndarray):
27-
return np.array2string(data, separator=" ")[1:-1] # remove brackets
51+
def array2string(value: NDArray) -> str:
52+
return np.array2string(value, separator=" ")[1:-1] # remove brackets

flopy4/mf6/io.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,44 @@
11
import sys
2+
23
import numpy as np
34
from jinja2 import Environment, PackageLoader
4-
from flopy4.mf6 import filters
5-
from flopy4.mf6.spec import blocks_dict
65

6+
from flopy4.mf6 import filters
7+
from flopy4.mf6.spec import blocks_dict, fields_dict
78

89
env = Environment(
9-
loader=PackageLoader("flopy4.mf6"),
10+
loader=PackageLoader("flopy4.mf6"),
1011
trim_blocks=True,
1112
lstrip_blocks=True,
1213
)
13-
env.filters["dask_expand"] = filters.dask_expand
14-
env.filters["nparray2string"] = filters.nparray2string
14+
env.filters["fieldkind"] = filters.fieldkind
15+
env.filters["fieldvalue"] = filters.fieldvalue
16+
env.filters["arraydelayed"] = filters.arraydelayed
17+
env.filters["array2string"] = filters.array2string
1518

1619

1720
class Writer:
18-
def _write_ascii(self, path) -> None:
19-
# TODO: factor out an ascii writer separately
20-
21-
block_spec = blocks_dict(type(self))
22-
blocks = {}
23-
for block_name, block in block_spec.items():
24-
blocks[block_name] = {}
25-
for field_name, field in block.items():
26-
if field_name == "data":
27-
continue
28-
blocks[block_name][field_name] = {
29-
"spec": field,
30-
"value": getattr(self, field_name),
31-
}
21+
# TODO remove type: ignore statements below.
22+
# but idk how to properly type a mixin class.
23+
# this one assumes the presence of attributes:
24+
# - name
25+
# - path
26+
# - data
3227

28+
def _write_ascii(self) -> None:
29+
cls = type(self)
30+
fields = fields_dict(cls)
31+
blocks = blocks_dict(cls)
3332
template = env.get_template("blocks.jinja")
34-
iterator = template.generate(blocks=blocks)
35-
with np.printoptions(
36-
precision=4, linewidth=sys.maxsize, threshold=sys.maxsize
37-
):
38-
with open(path, "w") as f:
33+
iterator = template.generate(fields=fields, blocks=blocks, data=self.data) # type: ignore
34+
# are these printoptions always applicable?
35+
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
36+
# TODO don't hardcode the filename, maybe a filename attribute?
37+
with open(self.path / self.name, "w") as f: # type: ignore
3938
f.writelines(iterator)
4039

4140
def write(self) -> None:
41+
# TODO: factor out an ascii writer separately
4242
self._write_ascii()
4343
for child in self.children.values(): # type: ignore
4444
child.write()

flopy4/mf6/spec.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,20 @@
33
These include field decorators and introspection functions.
44
"""
55

6-
from attrs import NOTHING, Attribute, fields_dict
7-
from xattree import array as xattree_array
8-
from xattree import coord as xattree_coord
9-
from xattree import dim as xattree_dim
10-
from xattree import field as xattree_field
6+
from attrs import NOTHING, Attribute
7+
from xattree import (
8+
array as xattree_array,
9+
)
10+
from xattree import (
11+
coord as xattree_coord,
12+
)
13+
from xattree import (
14+
dim as xattree_dim,
15+
)
16+
from xattree import (
17+
field as xattree_field,
18+
)
19+
from xattree import fields_dict as xattree_fields_dict
1120

1221

1322
def field(
@@ -139,11 +148,24 @@ def blocks_dict(cls) -> dict[str, Block]:
139148
(field) name to `attrs.Attribute`.
140149
"""
141150
fields = fields_dict(cls)
142-
fields = {k: v for k, v in fields.items() if "block" in v.metadata}
143151
blocks: dict[str, Block] = {}
144152
for k, v in fields.items():
145153
block = v.metadata["block"]
146154
if block not in blocks:
147155
blocks[block] = {}
148156
blocks[block][k] = v
149157
return dict(sorted(blocks.items(), key=_block_sort_key))
158+
159+
160+
def fields(cls) -> list[Attribute]:
161+
"""Return an ordered list of fields for a component class."""
162+
return list(fields_dict(cls).values())
163+
164+
165+
def fields_dict(cls) -> dict[str, Attribute]:
166+
"""
167+
Return an ordered dictionary of fields for a component class,
168+
whose keys are field names. Each field is an `attrs.Attribute`.
169+
"""
170+
fields = xattree_fields_dict(cls)
171+
return {k: v for k, v in fields.items() if "block" in v.metadata}

flopy4/mf6/templates/blocks.jinja

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
{% for block in blocks %}
2-
BEGIN {{ block.name }} {% if block.index is not none %}{{ block.index }}{% endif %}
3-
{% for field in block %}
1+
{% for block_name, block_ in blocks.items() %}
2+
BEGIN {{ block_name }}
3+
{% for field in block_.values() %}
44
{{ macros.field(field) }}
55
{% endfor %}
6-
END {{ block.name }}
6+
END {{ block_name }}
77

88
{% endfor %}

flopy4/mf6/templates/macros.jinja

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{% macro field(field) %}
2-
{% if field|kind == 'scalar' %}
2+
{% if field|fieldkind == 'scalar' %}
33
{{ macros.scalar(field) }}
4-
{% elif field|kind == 'union' %}
4+
{% elif field|fieldkind == 'union' %}
55
{{ macros.union(field) }}
6-
{% elif field|kind == 'record' %}
6+
{% elif field|fieldkind == 'record' %}
77
{{ macros.record(field) }}
8-
{% elif field|kind == 'array' %}
8+
{% elif field|fieldkind == 'array' %}
99
{{ macros.array(field) }}
10-
{% elif field|kind == 'list' %}
10+
{% elif field|fieldkind == 'list' %}
1111
{{ macros.list(field) }}
1212
{% endif %}
1313
{% endmacro %}
@@ -48,8 +48,8 @@ CONSTANT
4848
{% endmacro %}
4949

5050
{% macro internal_array(field) %}
51-
{% for chunk in field.value|dask_expand %}
52-
{{ chunk|nparray2string }}
51+
{% for chunk in field|value|arraydelayed %}
52+
{{ chunk|array2string }}
5353
{% endfor %}
5454
{% endmacro %}
5555

0 commit comments

Comments
 (0)