Skip to content

Commit 78500e8

Browse files
committed
various. pass component directly into template so we can just use getattr() and let xattree do the datatree lookup instead of doing it ourselves. also clean up the component dfn attribute implementation. prep for more work on the templates.
1 parent 30a7fb4 commit 78500e8

File tree

9 files changed

+580
-540
lines changed

9 files changed

+580
-540
lines changed

flopy4/mf6/codec.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55

66
from flopy4.mf6 import filters
77
from flopy4.mf6.component import Component
8-
from flopy4.mf6.spec import blocks_dict, fields_dict
98
from flopy4.uio import DEFAULT_REGISTRY
109

1110
JINJA_ENV = Environment(
1211
loader=PackageLoader("flopy4.mf6"),
1312
trim_blocks=True,
1413
lstrip_blocks=True,
1514
)
16-
JINJA_ENV.filters["field_kind"] = filters.field_kind
17-
JINJA_ENV.filters["fieldvalue"] = filters.fieldvalue
18-
JINJA_ENV.filters["arraydelayed"] = filters.arraydelayed
15+
JINJA_ENV.filters["blocks"] = filters.blocks
16+
JINJA_ENV.filters["field_kind"] = filters.field_type
17+
JINJA_ENV.filters["field_value"] = filters.field_value
18+
JINJA_ENV.filters["array_delay"] = filters.array_delay
1919
JINJA_ENV.filters["array2string"] = filters.array2string
2020
JINJA_ENV.filters["is_dict"] = filters.is_dict
2121
JINJA_TEMPLATE_NAME = "blocks.jinja"
@@ -27,11 +27,8 @@ def _load_ascii(self) -> None:
2727

2828

2929
def _write_ascii(self) -> None:
30-
cls = type(self)
31-
fields = fields_dict(cls)
32-
blocks = blocks_dict(cls)
3330
template = JINJA_ENV.get_template(JINJA_TEMPLATE_NAME)
34-
iterator = template.generate(fields=fields, blocks=blocks, data=unstructure(self.data)) # type: ignore
31+
iterator = template.generate(dfn=type(self).dfn, data=self)
3532
# are these printoptions always applicable?
3633
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
3734
# TODO don't hardcode the filename, maybe a filename attribute?

flopy4/mf6/component.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from abc import ABC
22
from collections.abc import MutableMapping
33

4-
from attrs import Attribute
5-
from modflow_devtools.dfn import Dfn, Var
4+
from modflow_devtools.dfn import Dfn, Field
65
from xattree import xattree
76

8-
from flopy4.mf6.spec import fields_dict
7+
from flopy4.mf6.spec import fields_dict, to_dfn_field
98
from flopy4.uio import IO, Loader, Writer
109

1110
COMPONENTS = {}
@@ -30,11 +29,8 @@ class Component(ABC, MutableMapping):
3029

3130
@classmethod
3231
def __attrs_init_subclass__(cls):
33-
# add class to the component registry
3432
COMPONENTS[cls.__name__.lower()] = cls
35-
36-
def __attrs_post_init__(self):
37-
self._dfn = self._get_dfn()
33+
cls.dfn = cls.get_dfn()
3834

3935
def __getitem__(self, key):
4036
return self.children[key] # type: ignore
@@ -51,37 +47,22 @@ def __iter__(self):
5147
def __len__(self):
5248
return len(self.children) # type: ignore
5349

54-
@property
55-
def dfn(self) -> Dfn:
56-
"""Return the component's definition."""
57-
return self._dfn
58-
59-
def _get_dfn(self) -> Dfn:
60-
def _to_dfn_spec(attribute: Attribute) -> Var:
61-
return Var(
62-
name=attribute.name,
63-
type=attribute.type,
64-
shape=attribute.metadata.get("dims", None),
65-
block=attribute.metadata.get("block", None),
66-
default=attribute.default,
67-
children={k: _to_dfn_spec(v) for k, v in fields_dict(attribute.type)} # type: ignore
68-
if attribute.metadata.get("kind", None) == "child" # type: ignore
69-
else None, # type: ignore
70-
)
71-
72-
fields = {k: _to_dfn_spec(v) for k, v in fields_dict(self.__class__).items()}
73-
blocks: dict[str, dict[str, Var]] = {}
74-
for k, v in fields.items():
75-
if (block := v.get("block", None)) is not None:
76-
blocks.setdefault(block, {})[k] = v
50+
@classmethod
51+
def get_dfn(cls) -> Dfn:
52+
fields = {field_name: to_dfn_field(field) for field_name, field in fields_dict(cls).items()}
53+
blocks: dict[str, dict[str, Field]] = {}
54+
for field_name, field in fields.items():
55+
if (block := field.get("block", None)) is not None:
56+
blocks.setdefault(block, {})[field_name] = field
7757
else:
78-
blocks[k] = v
58+
blocks[field_name] = field
59+
7960
return Dfn(
80-
name=self.name, # type: ignore
81-
advanced=getattr(self, "advanced_package", False),
82-
multi=getattr(self, "multi_package", False),
83-
ref=getattr(self, "sub_package", None),
84-
sln=getattr(self, "solution_package", None),
61+
name=cls.__name__.lower(),
62+
advanced=getattr(cls, "advanced_package", False),
63+
multi=getattr(cls, "multi_package", False),
64+
ref=getattr(cls, "sub_package", None),
65+
sln=getattr(cls, "solution_package", None),
8566
**blocks,
8667
)
8768

flopy4/mf6/filters.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,30 @@
22

33
import numpy as np
44
import xarray as xr
5-
from attrs import Attribute
65
from jinja2 import pass_context
6+
from modflow_devtools.dfn import Dfn, Field
77
from numpy.typing import NDArray
88

99

10-
def field_kind(field: Attribute) -> str:
11-
"""
12-
Get a field's `xattree` kind. Kind is either:
10+
def blocks(dfn: Dfn) -> dict:
11+
return {k: v for k, v in dfn.items() if k not in Dfn.__annotations__}
12+
1313

14-
- 'child' for child fields
15-
- 'array' for array fields
16-
- 'coord' for coordinate array fields
17-
- 'dim' for integer fields describing a dimension's size
18-
- 'attr' for all other fields
14+
def field_type(field: Field) -> str:
15+
"""
16+
Get a field's type as defined by the MODFLOW 6 input definition language:
17+
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types
1918
"""
20-
if (meta := field.metadata) is None:
21-
raise TypeError(f"Field {field.name} has no metadata")
22-
if (xatmeta := meta.get("xattree", None)) is None:
23-
raise TypeError(f"Field {field.name} has no xattree metadata")
24-
if "kind" not in xatmeta:
25-
raise TypeError(f"Field {field.name} has no kind")
26-
return xatmeta.get("kind") or "attr"
19+
return field["type"]
2720

2821

2922
@pass_context
30-
def fieldvalue(ctx, field: Attribute):
31-
"""Get a field's value from the data tree via the template context."""
32-
return ctx["data"].attrs.get(field.name) or ctx["data"].get(field.name)
23+
def field_value(ctx, field: Field):
24+
"""Get a field's value via the template context."""
25+
return getattr(ctx["data"], field["name"])
3326

3427

35-
def arraydelayed(value: xr.DataArray):
28+
def array_delay(value: xr.DataArray):
3629
"""Yield chunks (lines) from a Dask array."""
3730
# TODO: Determine a good chunk size,
3831
# because if the underlying array is only numpy, it will stay one block.

flopy4/mf6/spec.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
These include field decorators and introspection functions.
44
"""
55

6+
import builtins
7+
import types
8+
from typing import Union, get_args, get_origin
9+
10+
import numpy as np
611
from attrs import NOTHING, Attribute
12+
from modflow_devtools.dfn import Field, FieldType
713

814
from flopy4.spec import array as flopy_array
915
from flopy4.spec import coord as flopy_coord
@@ -162,3 +168,63 @@ def fields_dict(cls) -> dict[str, Attribute]:
162168
"""
163169
fields = flopy_fields_dict(cls)
164170
return {k: v for k, v in fields.items() if "block" in v.metadata}
171+
172+
173+
def get_dfn_field_type(attribute: Attribute) -> FieldType:
174+
"""
175+
Get a `xattree` field's type as defined by the MODFLOW 6 input
176+
definition language:
177+
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types
178+
179+
The type of the field is determined from `xattree` metadata.
180+
"""
181+
if (xatmeta := attribute.metadata.get("xattree", None)) is None:
182+
raise ValueError(f"Attribute {attribute.name} in {attribute.name} has no xattree metadata.")
183+
kind = xatmeta["kind"]
184+
match kind:
185+
case "child":
186+
raise ValueError(f"Top-level field should not be a child: {attribute.name}")
187+
case "array":
188+
return "recarray"
189+
case "coord":
190+
return "recarray"
191+
case "dim":
192+
return "integer"
193+
case "attr":
194+
match attribute.type:
195+
case builtins.str | np.str_:
196+
return "string"
197+
case builtins.bool | np.bool:
198+
return "keyword"
199+
case builtins.int | np.integer:
200+
return "integer"
201+
case builtins.float | np.floating:
202+
return "double precision"
203+
204+
case t if (
205+
get_origin(t) in (Union, types.UnionType) and get_args(t)[-1] is types.NoneType
206+
):
207+
return "union"
208+
case _:
209+
return "record"
210+
raise ValueError(f"Could not map {attribute.name} to a valid MF6 type.")
211+
212+
213+
def to_dfn_field(attribute: Attribute) -> Field:
214+
"""
215+
Convert a `xattree` field specification to a field as defined by the
216+
MODFLOW 6 input definition language:
217+
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types.
218+
"""
219+
if (xatmeta := attribute.metadata.get("xattree", None)) is None:
220+
raise ValueError(f"Attribute {attribute.name} in {attribute.name} has no xattree metadata.")
221+
return Field(
222+
name=attribute.name,
223+
type=get_dfn_field_type(attribute),
224+
shape=xatmeta.get("dims", None),
225+
block=attribute.metadata.get("block", None),
226+
default=attribute.default,
227+
children={k: to_dfn_field(v) for k, v in fields_dict(attribute.type)} # type: ignore
228+
if attribute.metadata.get("kind", None) == "child" # type: ignore
229+
else None, # type: ignore
230+
)

flopy4/mf6/templates/blocks.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% import 'macros.jinja' as macros with context %}
2-
{% for block_name, block_ in blocks.items() %}
2+
{% for block_name, block_ in (dfn|blocks).items() %}
33
BEGIN {{ block_name }}
44
{% for field in block_.values() %}
55
{{ macros.field(field) }}

flopy4/mf6/templates/macros.jinja

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,55 @@
11
{% macro field(field) %}
22
{% set kind = field|field_kind %}
33
{% if kind == 'attr' %}
4-
{% if field.type|is_dict %}
5-
{{ record(field) }}
6-
{% else %}
7-
{{ scalar(field) }}
4+
{% if field.type|is_dict %}{{ record(field) }}{% else %}{{ scalar(field) }}
85
{% endif %}
96
{% elif kind in ['array', 'coord'] %}
107
{{ array(field) }}
118
{% elif kind == 'dim' %}
129
{{ scalar(field) }}
1310
{% elif kind == 'child' %}
1411
{# TODO #}
15-
{% endif %}
12+
{% endif -%}
1613
{%- endmacro %}
1714

18-
{% macro scalar(field) %}
19-
{% set value = field|fieldvalue %}
20-
{% if value is not none %}
21-
{{ field.name }} {{ value }}
22-
{% endif %}
15+
{% macro scalar(field) -%}
16+
{% set value = field|field_value %}
17+
{% if value is not none %}{{ field.name }} {{ value }}{% endif %}
2318
{%- endmacro %}
2419

2520
{% macro keystring(field) %} {# union #}
26-
{% for item in (field|fieldvalue).items() %}
21+
{% for item in (field|field_value).items() %}
2722
{{ field(item) }}
2823
{% endfor %}
2924
{%- endmacro %}
3025

3126
{% macro record(field) %}
32-
{% for item in field|fieldvalue %}
27+
{% for item in field|field_value %}
3328
{% if item.tagged %}{{ item.name }} {% endif %}{{ field(field) }}
3429
{% endfor %}
3530
{%- endmacro %}
3631

3732
{% macro array(field, how="internal") %}
3833
{% if how == "layered constant" %}
3934
{{ field.name }} LAYERED
40-
{% for val in field|fieldvalue %}
35+
{% for val in field|field_value %}
4136
CONSTANT
4237
{% endfor %}
4338
{% elif how == "constant" %}
44-
{{ field.name }} CONSTANT {{ field|fieldvalue }}
39+
{{ field.name }} CONSTANT {{ field|field_value }}
4540
{% elif how == "layered" %}
4641
{% if layered %}
47-
{{ field.name }}{% for val in field|fieldvalue %} {{ val }}{% endfor %}
42+
{{ field.name }}{% for val in field|field_value %} {{ val }}{% endfor %}
4843
{% endif %}
4944
{% elif how == "internal" %}
5045
{{ field.name }} {{ internal_array(field) }}
5146
{% elif how == "external" %}
52-
{{ field.name}} OPEN/CLOSE {{ field|fieldvalue }}
47+
{{ field.name}} OPEN/CLOSE {{ field|field_value }}
5348
{% endif %}
5449
{%- endmacro %}
5550

5651
{% macro internal_array(field) %}
57-
{% for chunk in field|fieldvalue|arraydelayed %}
52+
{% for chunk in field|field_value|array_delay %}
5853
{{ chunk|array2string }}
5954
{% endfor %}
6055
{%- endmacro %}

0 commit comments

Comments
 (0)