Skip to content

Commit 0c0e686

Browse files
committed
feat: implement WriteContext with context manager support (#256)
- Created WriteContext class for configurable write behavior - Supports float_precision, use_binary, binary_threshold, use_relative_paths - Works as context manager with thread-local storage - Can be attached to components or passed to write() method - Integrated with writer filters to respect precision settings - Added comprehensive tests for all functionality
1 parent fe39f4b commit 0c0e686

File tree

6 files changed

+462
-26
lines changed

6 files changed

+462
-26
lines changed

flopy4/mf6/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,16 @@ def _load_toml(path: Path) -> Component:
4040
return structure(load_toml(fp), path)
4141

4242

43-
def _write_mf6(component: Component) -> None:
43+
def _write_mf6(component: Component, context=None, **kwargs) -> None:
44+
from flopy4.mf6.write_context import WriteContext
45+
46+
# Use provided context or default
47+
ctx = context if context is not None else WriteContext.default()
48+
4449
with open(component.path, "w") as fp:
4550
data = unstructure(component)
4651
try:
47-
dump_mf6(data, fp)
52+
dump_mf6(data, fp, context=ctx)
4853
except Exception as e:
4954
raise WriteError(
5055
f"Failed to write MF6 format file for component '{component.name}' " # type: ignore

flopy4/mf6/codec/writer/__init__.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import IO
2+
from typing import IO, Optional
33

44
import numpy as np
55
from jinja2 import Environment, PackageLoader
@@ -19,21 +19,78 @@
1919
_JINJA_ENV.filters["array2string"] = writer_filters.array2string
2020
_JINJA_ENV.filters["data2list"] = writer_filters.data2list
2121
_JINJA_TEMPLATE_NAME = "blocks.jinja"
22-
_PRINT_OPTIONS = {
23-
"precision": 4,
24-
"linewidth": sys.maxsize,
25-
"threshold": sys.maxsize,
26-
}
2722

2823

29-
def dumps(data) -> str:
24+
def _get_print_options(context=None):
25+
"""Get numpy print options from WriteContext."""
26+
if context is not None:
27+
return context.to_numpy_printoptions()
28+
# Default options
29+
return {
30+
"precision": 4,
31+
"linewidth": sys.maxsize,
32+
"threshold": sys.maxsize,
33+
}
34+
35+
36+
def dumps(data, context=None) -> str:
37+
"""
38+
Serialize data to MF6 format string.
39+
40+
Parameters
41+
----------
42+
data : dict
43+
Data to serialize
44+
context : WriteContext, optional
45+
Configuration context for writing
46+
47+
Returns
48+
-------
49+
str
50+
Serialized MF6 format string
51+
"""
52+
from flopy4.mf6.write_context import WriteContext
53+
54+
# Store context in filter module for filters to access
55+
if context is None:
56+
context = WriteContext.default()
57+
writer_filters._ACTIVE_CONTEXT = context
58+
3059
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
31-
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
32-
return template.render(blocks=data)
60+
print_opts = _get_print_options(context)
61+
with np.printoptions(**print_opts): # type: ignore
62+
result = template.render(blocks=data)
3363

64+
# Clean up
65+
writer_filters._ACTIVE_CONTEXT = None
66+
return result
67+
68+
69+
def dump(data, fp: IO[str], context=None) -> None:
70+
"""
71+
Serialize data to MF6 format and write to file.
72+
73+
Parameters
74+
----------
75+
data : dict
76+
Data to serialize
77+
fp : IO[str]
78+
File pointer to write to
79+
context : WriteContext, optional
80+
Configuration context for writing
81+
"""
82+
from flopy4.mf6.write_context import WriteContext
83+
84+
# Store context in filter module for filters to access
85+
if context is None:
86+
context = WriteContext.default()
87+
writer_filters._ACTIVE_CONTEXT = context
3488

35-
def dump(data, fp: IO[str]) -> None:
3689
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
3790
iterator = template.generate(blocks=data)
38-
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
91+
print_opts = _get_print_options(context)
92+
with np.printoptions(**print_opts): # type: ignore
3993
fp.writelines(iterator)
94+
95+
# Clean up
96+
writer_filters._ACTIVE_CONTEXT = None

flopy4/mf6/codec/writer/filters.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
ArrayHow = Literal["constant", "internal", "external", "layered constant", "layered internal"]
1414

15+
# Module-level variable to store active write context
16+
_ACTIVE_CONTEXT = None
17+
1518

1619
def array_how(value: xr.DataArray) -> ArrayHow:
1720
"""
@@ -45,7 +48,11 @@ def array2const(value: xr.DataArray) -> Scalar:
4548
if np.issubdtype(value.dtype, np.integer):
4649
return value.max().item()
4750
if np.issubdtype(value.dtype, np.floating):
48-
return f"{value.max().item():.8f}"
51+
# Use precision from active context if available
52+
precision = 8 # default
53+
if _ACTIVE_CONTEXT is not None:
54+
precision = _ACTIVE_CONTEXT.float_precision
55+
return f"{value.max().item():.{precision}f}"
4956
return value.ravel()[0]
5057

5158

@@ -112,13 +119,18 @@ def array2string(value: NDArray) -> str:
112119
# add an axis to 1d arrays so np.savetxt writes elements on 1 line
113120
value = value[None]
114121
value = np.atleast_1d(value)
115-
format = (
116-
"%d"
117-
if np.issubdtype(value.dtype, np.integer)
118-
else "%.9e"
119-
if np.issubdtype(value.dtype, np.floating)
120-
else "%s"
121-
)
122+
123+
# Use precision from active context if available
124+
if np.issubdtype(value.dtype, np.floating):
125+
precision = 9 # default
126+
if _ACTIVE_CONTEXT is not None:
127+
precision = _ACTIVE_CONTEXT.float_precision
128+
format = f"%.{precision}e"
129+
elif np.issubdtype(value.dtype, np.integer):
130+
format = "%d"
131+
else:
132+
format = "%s"
133+
122134
np.savetxt(buffer, value, fmt=format, delimiter=" ")
123135
return buffer.getvalue().strip()
124136

flopy4/mf6/component.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC
22
from collections.abc import MutableMapping
33
from pathlib import Path
4-
from typing import Any, ClassVar
4+
from typing import Any, ClassVar, Optional
55

66
from attrs import fields
77
from modflow_devtools.dfn import Dfn, Field
@@ -12,6 +12,7 @@
1212
from flopy4.mf6.constants import MF6
1313
from flopy4.mf6.spec import field, fields_dict, to_field
1414
from flopy4.mf6.utils.grid import update_maxbound
15+
from flopy4.mf6.write_context import WriteContext
1516
from flopy4.uio import IO, Loader, Writer
1617

1718
COMPONENTS = {}
@@ -39,6 +40,9 @@ class Component(ABC, MutableMapping):
3940
filename: str | None = field(default=None)
4041
"""The name of the component's input file."""
4142

43+
write_context: Optional[WriteContext] = field(default=None, repr=False)
44+
"""Configuration context for writing input files."""
45+
4246
@property
4347
def path(self) -> Path:
4448
"""The path to the component's input file."""
@@ -142,16 +146,32 @@ def load(self, format: str = MF6) -> None:
142146
for child in self.children.values(): # type: ignore
143147
child.load(format=format)
144148

145-
def write(self, format: str = MF6) -> None:
146-
"""Write the component and any children."""
149+
def write(self, format: str = MF6, context: Optional[WriteContext] = None) -> None:
150+
"""
151+
Write the component and any children.
152+
153+
Parameters
154+
----------
155+
format : str, optional
156+
Output format. Default is MF6.
157+
context : WriteContext, optional
158+
Configuration context for writing. If provided, overrides
159+
the component's write_context. If neither is provided,
160+
uses the current context from the context manager stack,
161+
or default settings.
162+
"""
147163
# TODO: setting filename is a temp hack to get the parent's
148164
# name as this component's filename stem, if it has one. an
149165
# actual solution is to auto-set the filename when children
150166
# are attached to parents.
151167
self.filename = self.filename or self.default_filename()
152-
self._write(format=format)
168+
169+
# Determine active context: provided > attached > current > default
170+
active_context = context or self.write_context or WriteContext.current()
171+
172+
self._write(format=format, context=active_context)
153173
for child in self.children.values(): # type: ignore
154-
child.write(format=format)
174+
child.write(format=format, context=context)
155175

156176
def to_dict(self, blocks: bool = False, strict: bool = False) -> dict[str, Any]:
157177
"""

flopy4/mf6/write_context.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""Write context for configuring MF6 input file writing."""
2+
3+
import threading
4+
from typing import Literal, Optional
5+
6+
from attrs import define, field
7+
8+
ArrayFormat = Literal["internal", "constant", "open/close"]
9+
10+
11+
@define
12+
class WriteContext:
13+
"""
14+
Configuration context for writing MODFLOW 6 input files.
15+
16+
This class controls how simulation data is written to input files,
17+
including numeric precision, binary vs ASCII format, and path handling.
18+
19+
Can be used in three ways:
20+
1. Attached to a component: component.write_context = ctx
21+
2. Passed to write method: component.write(context=ctx)
22+
3. As a context manager for temporary configuration
23+
24+
Parameters
25+
----------
26+
use_binary : bool, optional
27+
Prefer binary files for arrays. Default is False.
28+
binary_threshold : int, optional
29+
Size threshold (in bytes) for using binary format.
30+
Arrays larger than this will be written as binary.
31+
If None, use_binary setting is used unconditionally.
32+
float_precision : int, optional
33+
Number of decimal places for float output. Default is 6.
34+
use_relative_paths : bool, optional
35+
Use relative paths in input files. Default is True.
36+
array_format : ArrayFormat, optional
37+
How to write array data: 'internal', 'constant', or 'open/close'.
38+
If None, automatically determined based on array properties.
39+
40+
Examples
41+
--------
42+
>>> # Attach to component
43+
>>> ctx = WriteContext(float_precision=8, use_binary=True)
44+
>>> sim.write_context = ctx
45+
>>> sim.write()
46+
47+
>>> # Pass to write method
48+
>>> sim.write(context=WriteContext(float_precision=4))
49+
50+
>>> # Use as context manager
51+
>>> with WriteContext(use_binary=True, float_precision=8):
52+
... sim.write()
53+
"""
54+
55+
use_binary: bool = field(default=False)
56+
binary_threshold: Optional[int] = field(default=None)
57+
float_precision: int = field(default=6)
58+
use_relative_paths: bool = field(default=True)
59+
array_format: Optional[ArrayFormat] = field(default=None)
60+
61+
# Class-level thread-local storage for context stack
62+
_context_stack: threading.local = field(init=False, factory=threading.local)
63+
64+
def __enter__(self) -> "WriteContext":
65+
"""Enter context manager, pushing this context onto the stack."""
66+
if not hasattr(self._context_stack, "stack"):
67+
self._context_stack.stack = []
68+
self._context_stack.stack.append(self)
69+
return self
70+
71+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
72+
"""Exit context manager, popping this context from the stack."""
73+
if hasattr(self._context_stack, "stack") and self._context_stack.stack:
74+
self._context_stack.stack.pop()
75+
76+
@classmethod
77+
def current(cls) -> "WriteContext":
78+
"""
79+
Get the currently active WriteContext from the context stack.
80+
81+
Returns the most recently entered context manager, or a default
82+
context if no context manager is active.
83+
84+
Returns
85+
-------
86+
WriteContext
87+
The active context, or a default context.
88+
"""
89+
# Create a class-level thread-local if it doesn't exist
90+
if not hasattr(cls, "_global_context_stack"):
91+
cls._global_context_stack = threading.local()
92+
93+
if (
94+
hasattr(cls._global_context_stack, "stack")
95+
and cls._global_context_stack.stack
96+
):
97+
return cls._global_context_stack.stack[-1]
98+
return cls.default()
99+
100+
@classmethod
101+
def default(cls) -> "WriteContext":
102+
"""
103+
Create a WriteContext with default settings.
104+
105+
Returns
106+
-------
107+
WriteContext
108+
A context with default configuration values.
109+
"""
110+
return cls()
111+
112+
def to_numpy_printoptions(self) -> dict:
113+
"""
114+
Convert WriteContext settings to numpy printoptions dict.
115+
116+
Returns
117+
-------
118+
dict
119+
Dictionary suitable for use with np.printoptions()
120+
"""
121+
import sys
122+
123+
return {
124+
"precision": self.float_precision,
125+
"linewidth": sys.maxsize,
126+
"threshold": sys.maxsize,
127+
}
128+
129+
def get_float_format(self) -> str:
130+
"""
131+
Get the float format string based on precision setting.
132+
133+
Returns
134+
-------
135+
str
136+
Format string for floating point numbers (e.g., "%.6e")
137+
"""
138+
return f"%.{self.float_precision}e"

0 commit comments

Comments
 (0)