Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions flopy4/mf6/attr_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Attribute hooks for attrs on_setattr callbacks."""

import numpy as np
from attrs import fields

from flopy4.mf6.constants import FILL_DNODATA


def update_maxbound(instance, attribute, new_value):
"""
Generalized function to update maxbound when period block arrays change.

This function automatically finds all period block arrays in the instance
and calculates maxbound based on the maximum number of non-default values
across all arrays.

Args:
instance: The package instance
attribute: The attribute being set (from attrs on_setattr)
new_value: The new value being set

Returns:
The new_value (unchanged)
"""

period_arrays = []
instance_fields = fields(instance.__class__)
for field in instance_fields:
if field.metadata and field.metadata.get("block") == "period" and "dims" in field.metadata:
period_arrays.append(field.name)

maxbound_values = []
for array_name in period_arrays:
if attribute and attribute.name == array_name:
array_val = new_value
else:
array_val = getattr(instance, array_name, None)

if array_val is not None:
array_data = (
array_val if array_val.data.shape == array_val.shape else array_val.todense()
)

if array_data.dtype.kind in ["U", "S"]: # String arrays
non_default_count = len(np.where(array_data != "")[0])
else: # Numeric arrays
non_default_count = len(np.where(array_data != FILL_DNODATA)[0])

maxbound_values.append(non_default_count)
if maxbound_values:
instance.maxbound = max(maxbound_values)

return new_value
86 changes: 1 addition & 85 deletions flopy4/mf6/codec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,6 @@
import sys
from os import PathLike
from typing import Any

import numpy as np
import xattree
from cattrs import Converter
from jinja2 import Environment, PackageLoader

from flopy4.mf6 import filters
from flopy4.mf6.codec.converter import (
unstructure_array,
unstructure_chd,
unstructure_component,
unstructure_oc,
unstructure_tdis,
)

_JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6"),
trim_blocks=True,
lstrip_blocks=True,
)
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
_JINJA_ENV.filters["field_type"] = filters.field_type
_JINJA_ENV.filters["field_value"] = filters.field_value
_JINJA_ENV.filters["array_how"] = filters.array_how
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
_JINJA_ENV.filters["array2string"] = filters.array2string

_JINJA_TEMPLATE_NAME = "blocks.jinja"

_PRINT_OPTIONS = {
"precision": 4,
"linewidth": sys.maxsize,
"threshold": sys.maxsize,
}


def _make_converter() -> Converter:
# TODO: document what is converter's responsibility vs Jinja's
# TODO: how can we make sure writing remains lazy for list input?
# don't eagerly unstructure to dict, lazily access from the template?

from flopy4.mf6.component import Component
from flopy4.mf6.gwf.chd import Chd
from flopy4.mf6.gwf.oc import Oc
from flopy4.mf6.tdis import Tdis

converter = Converter()
converter.register_unstructure_hook_factory(xattree.has, lambda _: xattree.asdict)
converter.register_unstructure_hook(Component, unstructure_component)
converter.register_unstructure_hook(Tdis, unstructure_tdis)
converter.register_unstructure_hook(Chd, unstructure_chd)
converter.register_unstructure_hook(Oc, unstructure_oc)
return converter


_CONVERTER = _make_converter()


def loads(data: str) -> Any:
# TODO
pass


def load(path: str | PathLike) -> Any:
# TODO
pass


def dumps(data) -> str:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
return template.render(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))


def dump(data, path: str | PathLike) -> None:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
iterator = template.generate(dfn=type(data).dfn, data=_CONVERTER.unstructure(data))
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
f.writelines(iterator)

from flopy4.mf6.codec.writer import dump, dumps, load, loads

__all__ = [
"unstructure_array",
"loads",
"load",
"dumps",
Expand Down
154 changes: 0 additions & 154 deletions flopy4/mf6/codec/converter.py

This file was deleted.

Empty file.
52 changes: 52 additions & 0 deletions flopy4/mf6/codec/writer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import sys
from os import PathLike
from typing import Any

import numpy as np
from jinja2 import Environment, PackageLoader

from flopy4.mf6 import filters

_JINJA_ENV = Environment(
loader=PackageLoader("flopy4.mf6.codec.writer"),
trim_blocks=True,
lstrip_blocks=True,
)
_JINJA_ENV.filters["dict_blocks"] = filters.dict_blocks
_JINJA_ENV.filters["list_blocks"] = filters.list_blocks
_JINJA_ENV.filters["array_how"] = filters.array_how
_JINJA_ENV.filters["array_chunks"] = filters.array_chunks
_JINJA_ENV.filters["array2string"] = filters.array2string
_JINJA_ENV.filters["field_type"] = filters.field_type
_JINJA_ENV.filters["array2list"] = filters.array2list
_JINJA_ENV.filters["keystring2list"] = filters.keystring2list
_JINJA_ENV.filters["keystring2list_multifield"] = filters.keystring2list_multifield
_JINJA_TEMPLATE_NAME = "blocks.jinja"
_PRINT_OPTIONS = {
"precision": 4,
"linewidth": sys.maxsize,
"threshold": sys.maxsize,
}


def loads(data: str) -> Any:
# TODO
pass


def load(path: str | PathLike) -> Any:
# TODO
pass


def dumps(data) -> str:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
return template.render(data=data)


def dump(data, path: str | PathLike) -> None:
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
iterator = template.generate(data=data)
with np.printoptions(**_PRINT_OPTIONS), open(path, "w") as f: # type: ignore
f.writelines(iterator)
13 changes: 13 additions & 0 deletions flopy4/mf6/codec/writer/templates/blocks.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{% import 'macros.jinja' as macros with context %}
{% for block_name, block_value in (data|dict_blocks).items() %}
BEGIN {{ block_name.upper() }}
{% for field_name, field_value in block_value.items() if (field_value) is not none -%}
{{ macros.field(field_name, field_value) }}
{%- endfor %}
END {{ block_name.upper() }}

{% endfor %}

{% for block_name, block_value in (data|list_blocks).items() -%}
{{ macros.list(block_name, block_value, multi=block_name in ["period"]) }}
{%- endfor %}
Loading
Loading