Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions flopy4/mf6/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@

from flopy4.mf6 import filters
from flopy4.mf6.component import Component
from flopy4.mf6.spec import blocks_dict, fields_dict
from flopy4.uio import DEFAULT_REGISTRY

JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6"),
trim_blocks=True,
lstrip_blocks=True,
)
JINJA_ENV.filters["field_kind"] = filters.field_kind
JINJA_ENV.filters["fieldvalue"] = filters.fieldvalue
JINJA_ENV.filters["arraydelayed"] = filters.arraydelayed
JINJA_ENV.filters["blocks"] = filters.blocks
JINJA_ENV.filters["field_kind"] = filters.field_type
JINJA_ENV.filters["field_value"] = filters.field_value
JINJA_ENV.filters["array_delay"] = filters.array_delay
JINJA_ENV.filters["array2string"] = filters.array2string
JINJA_ENV.filters["is_dict"] = filters.is_dict
JINJA_TEMPLATE_NAME = "blocks.jinja"
Expand All @@ -27,11 +27,8 @@ def _load_ascii(self) -> None:


def _write_ascii(self) -> None:
cls = type(self)
fields = fields_dict(cls)
blocks = blocks_dict(cls)
template = JINJA_ENV.get_template(JINJA_TEMPLATE_NAME)
iterator = template.generate(fields=fields, blocks=blocks, data=unstructure(self.data)) # type: ignore
iterator = template.generate(dfn=type(self).dfn, data=self)
# are these printoptions always applicable?
with np.printoptions(precision=4, linewidth=sys.maxsize, threshold=sys.maxsize):
# TODO don't hardcode the filename, maybe a filename attribute?
Expand Down
53 changes: 17 additions & 36 deletions flopy4/mf6/component.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from abc import ABC
from collections.abc import MutableMapping

from attrs import Attribute
from modflow_devtools.dfn import Dfn, Var
from modflow_devtools.dfn import Dfn, Field
from xattree import xattree

from flopy4.mf6.spec import fields_dict
from flopy4.mf6.spec import fields_dict, to_dfn_field
from flopy4.uio import IO, Loader, Writer

COMPONENTS = {}
Expand All @@ -30,11 +29,8 @@ class Component(ABC, MutableMapping):

@classmethod
def __attrs_init_subclass__(cls):
# add class to the component registry
COMPONENTS[cls.__name__.lower()] = cls

def __attrs_post_init__(self):
self._dfn = self._get_dfn()
cls.dfn = cls.get_dfn()

def __getitem__(self, key):
return self.children[key] # type: ignore
Expand All @@ -51,37 +47,22 @@ def __iter__(self):
def __len__(self):
return len(self.children) # type: ignore

@property
def dfn(self) -> Dfn:
"""Return the component's definition."""
return self._dfn

def _get_dfn(self) -> Dfn:
def _to_dfn_spec(attribute: Attribute) -> Var:
return Var(
name=attribute.name,
type=attribute.type,
shape=attribute.metadata.get("dims", None),
block=attribute.metadata.get("block", None),
default=attribute.default,
children={k: _to_dfn_spec(v) for k, v in fields_dict(attribute.type)} # type: ignore
if attribute.metadata.get("kind", None) == "child" # type: ignore
else None, # type: ignore
)

fields = {k: _to_dfn_spec(v) for k, v in fields_dict(self.__class__).items()}
blocks: dict[str, dict[str, Var]] = {}
for k, v in fields.items():
if (block := v.get("block", None)) is not None:
blocks.setdefault(block, {})[k] = v
@classmethod
def get_dfn(cls) -> Dfn:
fields = {field_name: to_dfn_field(field) for field_name, field in fields_dict(cls).items()}
blocks: dict[str, dict[str, Field]] = {}
for field_name, field in fields.items():
if (block := field.get("block", None)) is not None:
blocks.setdefault(block, {})[field_name] = field
else:
blocks[k] = v
blocks[field_name] = field

return Dfn(
name=self.name, # type: ignore
advanced=getattr(self, "advanced_package", False),
multi=getattr(self, "multi_package", False),
ref=getattr(self, "sub_package", None),
sln=getattr(self, "solution_package", None),
name=cls.__name__.lower(),
advanced=getattr(cls, "advanced_package", False),
multi=getattr(cls, "multi_package", False),
ref=getattr(cls, "sub_package", None),
sln=getattr(cls, "solution_package", None),
**blocks,
)

Expand Down
33 changes: 13 additions & 20 deletions flopy4/mf6/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,30 @@

import numpy as np
import xarray as xr
from attrs import Attribute
from jinja2 import pass_context
from modflow_devtools.dfn import Dfn, Field
from numpy.typing import NDArray


def field_kind(field: Attribute) -> str:
"""
Get a field's `xattree` kind. Kind is either:
def blocks(dfn: Dfn) -> dict:
return {k: v for k, v in dfn.items() if k not in Dfn.__annotations__}


- 'child' for child fields
- 'array' for array fields
- 'coord' for coordinate array fields
- 'dim' for integer fields describing a dimension's size
- 'attr' for all other fields
def field_type(field: Field) -> str:
"""
Get a field's type as defined by the MODFLOW 6 input definition language:
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types
"""
if (meta := field.metadata) is None:
raise TypeError(f"Field {field.name} has no metadata")
if (xatmeta := meta.get("xattree", None)) is None:
raise TypeError(f"Field {field.name} has no xattree metadata")
if "kind" not in xatmeta:
raise TypeError(f"Field {field.name} has no kind")
return xatmeta.get("kind") or "attr"
return field["type"]


@pass_context
def fieldvalue(ctx, field: Attribute):
"""Get a field's value from the data tree via the template context."""
return ctx["data"].attrs.get(field.name) or ctx["data"].get(field.name)
def field_value(ctx, field: Field):
"""Get a field's value via the template context."""
return getattr(ctx["data"], field["name"])


def arraydelayed(value: xr.DataArray):
def array_delay(value: xr.DataArray):
"""Yield chunks (lines) from a Dask array."""
# TODO: Determine a good chunk size,
# because if the underlying array is only numpy, it will stay one block.
Expand Down
66 changes: 66 additions & 0 deletions flopy4/mf6/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
These include field decorators and introspection functions.
"""

import builtins
import types
from typing import Union, get_args, get_origin

import numpy as np
from attrs import NOTHING, Attribute
from modflow_devtools.dfn import Field, FieldType

from flopy4.spec import array as flopy_array
from flopy4.spec import coord as flopy_coord
Expand Down Expand Up @@ -162,3 +168,63 @@ def fields_dict(cls) -> dict[str, Attribute]:
"""
fields = flopy_fields_dict(cls)
return {k: v for k, v in fields.items() if "block" in v.metadata}


def get_dfn_field_type(attribute: Attribute) -> FieldType:
"""
Get a `xattree` field's type as defined by the MODFLOW 6 input
definition language:
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types

The type of the field is determined from `xattree` metadata.
"""
if (xatmeta := attribute.metadata.get("xattree", None)) is None:
raise ValueError(f"Attribute {attribute.name} in {attribute.name} has no xattree metadata.")
kind = xatmeta["kind"]
match kind:
case "child":
raise ValueError(f"Top-level field should not be a child: {attribute.name}")
case "array":
return "recarray"
case "coord":
return "recarray"
case "dim":
return "integer"
case "attr":
match attribute.type:
case builtins.str | np.str_:
return "string"
case builtins.bool | np.bool:
return "keyword"
case builtins.int | np.integer:
return "integer"
case builtins.float | np.floating:
return "double precision"

case t if (
get_origin(t) in (Union, types.UnionType) and get_args(t)[-1] is types.NoneType
):
return "union"
case _:
return "record"
raise ValueError(f"Could not map {attribute.name} to a valid MF6 type.")


def to_dfn_field(attribute: Attribute) -> Field:
"""
Convert a `xattree` field specification to a field as defined by the
MODFLOW 6 input definition language:
https://modflow6.readthedocs.io/en/stable/_dev/dfn.html#variable-types.
"""
if (xatmeta := attribute.metadata.get("xattree", None)) is None:
raise ValueError(f"Attribute {attribute.name} in {attribute.name} has no xattree metadata.")
return Field(
name=attribute.name,
type=get_dfn_field_type(attribute),
shape=xatmeta.get("dims", None),
block=attribute.metadata.get("block", None),
default=attribute.default,
children={k: to_dfn_field(v) for k, v in fields_dict(attribute.type)} # type: ignore
if attribute.metadata.get("kind", None) == "child" # type: ignore
else None, # type: ignore
)
2 changes: 1 addition & 1 deletion flopy4/mf6/templates/blocks.jinja
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% import 'macros.jinja' as macros with context %}
{% for block_name, block_ in blocks.items() %}
{% for block_name, block_ in (dfn|blocks).items() %}
BEGIN {{ block_name }}
{% for field in block_.values() %}
{{ macros.field(field) }}
Expand Down
29 changes: 12 additions & 17 deletions flopy4/mf6/templates/macros.jinja
Original file line number Diff line number Diff line change
@@ -1,60 +1,55 @@
{% macro field(field) %}
{% set kind = field|field_kind %}
{% if kind == 'attr' %}
{% if field.type|is_dict %}
{{ record(field) }}
{% else %}
{{ scalar(field) }}
{% if field.type|is_dict %}{{ record(field) }}{% else %}{{ scalar(field) }}
{% endif %}
{% elif kind in ['array', 'coord'] %}
{{ array(field) }}
{% elif kind == 'dim' %}
{{ scalar(field) }}
{% elif kind == 'child' %}
{# TODO #}
{% endif %}
{% endif -%}
{%- endmacro %}

{% macro scalar(field) %}
{% set value = field|fieldvalue %}
{% if value is not none %}
{{ field.name }} {{ value }}
{% endif %}
{% macro scalar(field) -%}
{% set value = field|field_value %}
{% if value is not none %}{{ field.name }} {{ value }}{% endif %}
{%- endmacro %}

{% macro keystring(field) %} {# union #}
{% for item in (field|fieldvalue).items() %}
{% for item in (field|field_value).items() %}
{{ field(item) }}
{% endfor %}
{%- endmacro %}

{% macro record(field) %}
{% for item in field|fieldvalue %}
{% for item in field|field_value %}
{% if item.tagged %}{{ item.name }} {% endif %}{{ field(field) }}
{% endfor %}
{%- endmacro %}

{% macro array(field, how="internal") %}
{% if how == "layered constant" %}
{{ field.name }} LAYERED
{% for val in field|fieldvalue %}
{% for val in field|field_value %}
CONSTANT
{% endfor %}
{% elif how == "constant" %}
{{ field.name }} CONSTANT {{ field|fieldvalue }}
{{ field.name }} CONSTANT {{ field|field_value }}
{% elif how == "layered" %}
{% if layered %}
{{ field.name }}{% for val in field|fieldvalue %} {{ val }}{% endfor %}
{{ field.name }}{% for val in field|field_value %} {{ val }}{% endfor %}
{% endif %}
{% elif how == "internal" %}
{{ field.name }} {{ internal_array(field) }}
{% elif how == "external" %}
{{ field.name}} OPEN/CLOSE {{ field|fieldvalue }}
{{ field.name}} OPEN/CLOSE {{ field|field_value }}
{% endif %}
{%- endmacro %}

{% macro internal_array(field) %}
{% for chunk in field|fieldvalue|arraydelayed %}
{% for chunk in field|field_value|array_delay %}
{{ chunk|array2string }}
{% endfor %}
{%- endmacro %}
Expand Down
Loading
Loading