Skip to content

Commit a8738fe

Browse files
State: quick fix to v0 (NOAA-GFDL#245)
* Use `tmpdir` of `pytest` for a cleaner utest * Allow `None` as a valid State attribute * Deal with `None` iinput in move & copy Enhance errors on copy
1 parent 64c4694 commit a8738fe

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

ndsl/quantity/state.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44
from pathlib import Path
5-
from typing import TYPE_CHECKING, Any, Callable, Self
5+
from typing import TYPE_CHECKING, Any, Callable, Self, TypeAlias
66

77
import dacite
88
import xarray as xr
@@ -13,6 +13,8 @@
1313
if TYPE_CHECKING:
1414
from ndsl import QuantityFactory
1515

16+
StateMemoryMapping: TypeAlias = dict[str, dict | ArrayLike | None]
17+
1618

1719
@dataclasses.dataclass
1820
class State:
@@ -86,7 +88,7 @@ def ones(cls, quantity_factory: QuantityFactory) -> Self:
8688
def copy_memory(
8789
cls,
8890
quantity_factory: QuantityFactory,
89-
memory_map: dict[str, Any],
91+
memory_map: StateMemoryMapping,
9092
) -> Self:
9193
"""Allocate all quantities and fill their value based
9294
on the given memory map. See `update_from_memory`"""
@@ -100,7 +102,7 @@ def copy_memory(
100102
def move_memory(
101103
cls,
102104
quantity_factory: QuantityFactory,
103-
memory_map: dict[str, Any],
105+
memory_map: StateMemoryMapping,
104106
*,
105107
check_shape_and_strides: bool = True,
106108
) -> Self:
@@ -147,10 +149,14 @@ class InnerA
147149

148150
def _update_from_memory_recursive(
149151
state: State,
150-
memory_map: dict[str, dict | ArrayLike],
152+
memory_map: StateMemoryMapping,
151153
):
152154
for name, array in memory_map.items():
153-
if isinstance(array, dict):
155+
if array is None:
156+
raise TypeError(
157+
f"State memory copy: illegal copy from None for attribute {name}"
158+
)
159+
elif isinstance(array, dict):
154160
_update_from_memory_recursive(state.__getattribute__(name), array)
155161
else:
156162
try:
@@ -165,7 +171,7 @@ def _update_from_memory_recursive(
165171

166172
def update_move_memory(
167173
self,
168-
memory_map: dict[str, dict | ArrayLike],
174+
memory_map: StateMemoryMapping,
169175
*,
170176
check_shape_and_strides: bool = True,
171177
) -> None:
@@ -204,11 +210,11 @@ class InnerA
204210
shape and strides as the original quantity
205211
"""
206212

207-
def _update_zero_copy_recursive(
208-
state: State, memory_map: dict[str, dict | ArrayLike]
209-
):
213+
def _update_zero_copy_recursive(state: State, memory_map: StateMemoryMapping):
210214
for name, array in memory_map.items():
211-
if isinstance(array, dict):
215+
if array is None:
216+
state.__setattr__(name, None)
217+
elif isinstance(array, dict):
212218
_update_zero_copy_recursive(state.__getattribute__(name), array)
213219
else:
214220
quantity = state.__getattribute__(name)
@@ -227,8 +233,11 @@ def _update_zero_copy_recursive(
227233
f" Strides: {array.strides} != {quantity.data.strides}"
228234
)
229235
raise e
230-
231-
quantity.data = array
236+
try:
237+
quantity.data = array
238+
except Exception as e:
239+
e.add_note(f" Error on {name} for {type(state)}")
240+
raise e
232241

233242
_update_zero_copy_recursive(self, memory_map)
234243

tests/quantity/test_state.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ class InnerB:
4747
)
4848

4949

50-
def test_state():
50+
def test_state(tmpdir):
5151
_, quantity_factory = get_factories_single_tile(
5252
5, 5, 3, 0, backend="dace:cpu_kfirst"
5353
)
5454

5555
microphys_state = CodeState.zeros(quantity_factory)
5656
microphys_state.inner_A.A.field[:] = 42.42
57-
microphys_state.to_netcdf()
57+
microphys_state.to_netcdf(Path(tmpdir))
5858
microphys_state2 = CodeState.zeros(quantity_factory)
59-
microphys_state2.update_from_netcdf(Path("./"))
59+
microphys_state2.update_from_netcdf(Path(tmpdir))
6060
assert (microphys_state2.inner_A.A.field[:] == 42.42).all()
6161
a = np.ones((5, 5, 3))
6262
b = np.ones((5, 5, 3))

0 commit comments

Comments
 (0)