Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions flopy4/mf6/binding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from attrs import define

from flopy4.mf6.component import Component
from flopy4.mf6.exchange import Exchange
from flopy4.mf6.model import Model
from flopy4.mf6.package import Package
from flopy4.mf6.solution import Solution


@define
class Binding:
"""
An MF6 component binding: a record representation of the
component for writing to a parent component's name file.
"""

type: str
fname: str
terms: tuple[str, ...] | None = None

def to_tuple(self) -> tuple[str, ...]:
if self.terms and any(self.terms):
return (self.type, self.fname, *self.terms)
else:
return (self.type, self.fname)

@classmethod
def from_component(cls, component: Component) -> "Binding":
def _get_binding_type(component: Component) -> str:
cls_name = component.__class__.__name__
if isinstance(component, Exchange):
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
elif isinstance(component, Solution):
return f"{component.slntype}6"
else:
return f"{cls_name.upper()}6"

def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
if isinstance(component, Exchange):
return (component.exgmnamea, component.exgmnameb) # type: ignore
elif isinstance(component, Solution):
return tuple(component.models)
elif isinstance(component, (Model, Package)):
return (component.name,) # type: ignore
return None

return cls(
type=_get_binding_type(component),
fname=component.filename or component.default_filename(),
terms=_get_binding_terms(component),
)
4 changes: 2 additions & 2 deletions flopy4/mf6/codec/writer/templates/macros.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
{% elif type == 'array' %}
{{ array(name, value, how=value|array_how) }}
{% elif type == 'list' %}
{{ list(name, value) }}
{{ list(value) }}
{% endif %}
{% endmacro %}

Expand All @@ -22,7 +22,7 @@
{{ inset ~ value|join(" ") -}}
{% endmacro %}

{% macro list(name, value) %}
{% macro list(value) %}
{% for row in (value|data2list) %}
{{ record(row) }}{% if not loop.last %}{{ "\n" }}{% endif %}
{%- endfor %}
Expand Down
114 changes: 33 additions & 81 deletions flopy4/mf6/converter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Iterable, MutableMapping
from collections.abc import Iterable, Mapping
from datetime import datetime
from pathlib import Path
from typing import Any
Expand All @@ -7,68 +7,20 @@
import sparse
import xarray as xr
import xattree
from attrs import define
from cattrs import Converter
from modflow_devtools.dfn.schema.block import block_sort_key
from numpy.typing import NDArray
from xattree import get_xatspec

from flopy4.adapters import get_nn
from flopy4.mf6.binding import Binding
from flopy4.mf6.component import Component
from flopy4.mf6.config import SPARSE_THRESHOLD
from flopy4.mf6.constants import FILL_DNODATA
from flopy4.mf6.context import Context
from flopy4.mf6.exchange import Exchange
from flopy4.mf6.model import Model
from flopy4.mf6.package import Package
from flopy4.mf6.solution import Solution
from flopy4.mf6.spec import fields_dict


@define
class _Binding:
"""
An MF6 component binding: a record representation of the
component for writing to a parent component's name file.
"""

type: str
fname: str
terms: tuple[str, ...] | None = None

def to_tuple(self):
if self.terms and any(self.terms):
return (self.type, self.fname, *self.terms)
else:
return (self.type, self.fname)

@classmethod
def from_component(cls, component: Component) -> "_Binding":
def _get_binding_type(component: Component) -> str:
cls_name = component.__class__.__name__
if isinstance(component, Exchange):
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
elif isinstance(component, Solution):
return f"{component.slntype}6"
else:
return f"{cls_name.upper()}6"

def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
if isinstance(component, Exchange):
return (component.exgmnamea, component.exgmnameb) # type: ignore
elif isinstance(component, Solution):
return tuple(component.models)
elif isinstance(component, (Model, Package)):
return (component.name,) # type: ignore
return None

return cls(
type=_get_binding_type(component),
fname=component.filename or component.default_filename(),
terms=_get_binding_terms(component),
)


def _attach_field_metadata(
dataset: xr.Dataset, component_type: type, field_names: list[str]
) -> None:
Expand All @@ -88,43 +40,43 @@ def _path_to_tuple(field_name: str, path_value: Path) -> tuple:
return (field_name.upper(), "FILEOUT", str(path_value))


def unstructure_component(value: Component) -> dict[str, Any]:
blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore
blocks: dict[str, dict[str, Any]] = {}
def get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
if not isinstance(value, Context):
return {}

blocks = {} # type: ignore
xatspec = xattree.get_xatspec(type(value))

# Handle child component bindings before converting to dict
if isinstance(value, Context):
for field_name, child_spec in xatspec.children.items():
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
block_name = child_spec.metadata["block"] # type: ignore
field_value = getattr(value, field_name, None)

if block_name not in blocks:
blocks[block_name] = {}

if isinstance(field_value, Component):
components = [_Binding.from_component(field_value).to_tuple()]
elif isinstance(field_value, MutableMapping):
components = [
_Binding.from_component(comp).to_tuple()
for comp in field_value.values()
if comp is not None
]
elif isinstance(field_value, Iterable):
components = [
_Binding.from_component(comp).to_tuple()
for comp in field_value
if comp is not None
]
else:
continue
for child_name, child_spec in xatspec.children.items():
if (child := getattr(value, child_name, None)) is None:
continue
if (block_name := child_spec.metadata["block"]) not in blocks: # type: ignore
blocks[block_name] = {}
match child:
case Component():
blocks[block_name][child_name] = [Binding.from_component(child).to_tuple()]
case Mapping():
blocks[block_name][child_name] = [
Binding.from_component(c).to_tuple() for c in child.values() if c is not None
]
case Iterable():
blocks[block_name][child_name] = [
Binding.from_component(c).to_tuple() for c in child if c is not None
]
case _:
raise ValueError(f"Unexpected child type: {type(child)}")

return blocks

if components:
blocks[block_name][field_name] = components

def unstructure_component(value: Component) -> dict[str, Any]:
blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore
blocks: dict[str, dict[str, Any]] = {}
xatspec = xattree.get_xatspec(type(value))
data = xattree.asdict(value)

blocks.update(binding_blocks := get_binding_blocks(value))

for block_name, block in blockspec.items():
if block_name not in blocks:
blocks[block_name] = {}
Expand Down
Loading
Loading