Skip to content

Commit 4729eb9

Browse files
authored
factor out binding module and function in converter, update pixi env (#204)
1 parent 6dd988e commit 4729eb9

File tree

4 files changed

+873
-963
lines changed

4 files changed

+873
-963
lines changed

flopy4/mf6/binding.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from attrs import define
2+
3+
from flopy4.mf6.component import Component
4+
from flopy4.mf6.exchange import Exchange
5+
from flopy4.mf6.model import Model
6+
from flopy4.mf6.package import Package
7+
from flopy4.mf6.solution import Solution
8+
9+
10+
@define
11+
class Binding:
12+
"""
13+
An MF6 component binding: a record representation of the
14+
component for writing to a parent component's name file.
15+
"""
16+
17+
type: str
18+
fname: str
19+
terms: tuple[str, ...] | None = None
20+
21+
def to_tuple(self) -> tuple[str, ...]:
22+
if self.terms and any(self.terms):
23+
return (self.type, self.fname, *self.terms)
24+
else:
25+
return (self.type, self.fname)
26+
27+
@classmethod
28+
def from_component(cls, component: Component) -> "Binding":
29+
def _get_binding_type(component: Component) -> str:
30+
cls_name = component.__class__.__name__
31+
if isinstance(component, Exchange):
32+
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
33+
elif isinstance(component, Solution):
34+
return f"{component.slntype}6"
35+
else:
36+
return f"{cls_name.upper()}6"
37+
38+
def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
39+
if isinstance(component, Exchange):
40+
return (component.exgmnamea, component.exgmnameb) # type: ignore
41+
elif isinstance(component, Solution):
42+
return tuple(component.models)
43+
elif isinstance(component, (Model, Package)):
44+
return (component.name,) # type: ignore
45+
return None
46+
47+
return cls(
48+
type=_get_binding_type(component),
49+
fname=component.filename or component.default_filename(),
50+
terms=_get_binding_terms(component),
51+
)

flopy4/mf6/codec/writer/templates/macros.jinja

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
{% elif type == 'array' %}
1010
{{ array(name, value, how=value|array_how) }}
1111
{% elif type == 'list' %}
12-
{{ list(name, value) }}
12+
{{ list(value) }}
1313
{% endif %}
1414
{% endmacro %}
1515

@@ -22,7 +22,7 @@
2222
{{ inset ~ value|join(" ") -}}
2323
{% endmacro %}
2424

25-
{% macro list(name, value) %}
25+
{% macro list(value) %}
2626
{% for row in (value|data2list) %}
2727
{{ record(row) }}{% if not loop.last %}{{ "\n" }}{% endif %}
2828
{%- endfor %}

flopy4/mf6/converter.py

Lines changed: 33 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Iterable, MutableMapping
1+
from collections.abc import Iterable, Mapping
22
from datetime import datetime
33
from pathlib import Path
44
from typing import Any
@@ -7,68 +7,20 @@
77
import sparse
88
import xarray as xr
99
import xattree
10-
from attrs import define
1110
from cattrs import Converter
1211
from modflow_devtools.dfn.schema.block import block_sort_key
1312
from numpy.typing import NDArray
1413
from xattree import get_xatspec
1514

1615
from flopy4.adapters import get_nn
16+
from flopy4.mf6.binding import Binding
1717
from flopy4.mf6.component import Component
1818
from flopy4.mf6.config import SPARSE_THRESHOLD
1919
from flopy4.mf6.constants import FILL_DNODATA
2020
from flopy4.mf6.context import Context
21-
from flopy4.mf6.exchange import Exchange
22-
from flopy4.mf6.model import Model
23-
from flopy4.mf6.package import Package
24-
from flopy4.mf6.solution import Solution
2521
from flopy4.mf6.spec import fields_dict
2622

2723

28-
@define
29-
class _Binding:
30-
"""
31-
An MF6 component binding: a record representation of the
32-
component for writing to a parent component's name file.
33-
"""
34-
35-
type: str
36-
fname: str
37-
terms: tuple[str, ...] | None = None
38-
39-
def to_tuple(self):
40-
if self.terms and any(self.terms):
41-
return (self.type, self.fname, *self.terms)
42-
else:
43-
return (self.type, self.fname)
44-
45-
@classmethod
46-
def from_component(cls, component: Component) -> "_Binding":
47-
def _get_binding_type(component: Component) -> str:
48-
cls_name = component.__class__.__name__
49-
if isinstance(component, Exchange):
50-
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
51-
elif isinstance(component, Solution):
52-
return f"{component.slntype}6"
53-
else:
54-
return f"{cls_name.upper()}6"
55-
56-
def _get_binding_terms(component: Component) -> tuple[str, ...] | None:
57-
if isinstance(component, Exchange):
58-
return (component.exgmnamea, component.exgmnameb) # type: ignore
59-
elif isinstance(component, Solution):
60-
return tuple(component.models)
61-
elif isinstance(component, (Model, Package)):
62-
return (component.name,) # type: ignore
63-
return None
64-
65-
return cls(
66-
type=_get_binding_type(component),
67-
fname=component.filename or component.default_filename(),
68-
terms=_get_binding_terms(component),
69-
)
70-
71-
7224
def _attach_field_metadata(
7325
dataset: xr.Dataset, component_type: type, field_names: list[str]
7426
) -> None:
@@ -88,43 +40,43 @@ def _path_to_tuple(field_name: str, path_value: Path) -> tuple:
8840
return (field_name.upper(), "FILEOUT", str(path_value))
8941

9042

91-
def unstructure_component(value: Component) -> dict[str, Any]:
92-
blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore
93-
blocks: dict[str, dict[str, Any]] = {}
43+
def get_binding_blocks(value: Component) -> dict[str, dict[str, list[tuple[str, ...]]]]:
44+
if not isinstance(value, Context):
45+
return {}
46+
47+
blocks = {} # type: ignore
9448
xatspec = xattree.get_xatspec(type(value))
9549

96-
# Handle child component bindings before converting to dict
97-
if isinstance(value, Context):
98-
for field_name, child_spec in xatspec.children.items():
99-
if hasattr(child_spec, "metadata") and "block" in child_spec.metadata: # type: ignore
100-
block_name = child_spec.metadata["block"] # type: ignore
101-
field_value = getattr(value, field_name, None)
102-
103-
if block_name not in blocks:
104-
blocks[block_name] = {}
105-
106-
if isinstance(field_value, Component):
107-
components = [_Binding.from_component(field_value).to_tuple()]
108-
elif isinstance(field_value, MutableMapping):
109-
components = [
110-
_Binding.from_component(comp).to_tuple()
111-
for comp in field_value.values()
112-
if comp is not None
113-
]
114-
elif isinstance(field_value, Iterable):
115-
components = [
116-
_Binding.from_component(comp).to_tuple()
117-
for comp in field_value
118-
if comp is not None
119-
]
120-
else:
121-
continue
50+
for child_name, child_spec in xatspec.children.items():
51+
if (child := getattr(value, child_name, None)) is None:
52+
continue
53+
if (block_name := child_spec.metadata["block"]) not in blocks: # type: ignore
54+
blocks[block_name] = {}
55+
match child:
56+
case Component():
57+
blocks[block_name][child_name] = [Binding.from_component(child).to_tuple()]
58+
case Mapping():
59+
blocks[block_name][child_name] = [
60+
Binding.from_component(c).to_tuple() for c in child.values() if c is not None
61+
]
62+
case Iterable():
63+
blocks[block_name][child_name] = [
64+
Binding.from_component(c).to_tuple() for c in child if c is not None
65+
]
66+
case _:
67+
raise ValueError(f"Unexpected child type: {type(child)}")
68+
69+
return blocks
12270

123-
if components:
124-
blocks[block_name][field_name] = components
12571

72+
def unstructure_component(value: Component) -> dict[str, Any]:
73+
blockspec = dict(sorted(value.dfn.blocks.items(), key=block_sort_key)) # type: ignore
74+
blocks: dict[str, dict[str, Any]] = {}
75+
xatspec = xattree.get_xatspec(type(value))
12676
data = xattree.asdict(value)
12777

78+
blocks.update(binding_blocks := get_binding_blocks(value))
79+
12880
for block_name, block in blockspec.items():
12981
if block_name not in blocks:
13082
blocks[block_name] = {}

0 commit comments

Comments
 (0)