22
33import dataclasses
44from pathlib import Path
5- from typing import TYPE_CHECKING , Any , Callable , Self
5+ from typing import TYPE_CHECKING , Any , Callable , Self , TypeAlias
66
77import dacite
88import xarray as xr
1313if TYPE_CHECKING :
1414 from ndsl import QuantityFactory
1515
16+ StateMemoryMapping : TypeAlias = dict [str , dict | ArrayLike | None ]
17+
1618
1719@dataclasses .dataclass
1820class State :
@@ -86,7 +88,7 @@ def ones(cls, quantity_factory: QuantityFactory) -> Self:
8688 def copy_memory (
8789 cls ,
8890 quantity_factory : QuantityFactory ,
89- memory_map : dict [ str , Any ] ,
91+ memory_map : StateMemoryMapping ,
9092 ) -> Self :
9193 """Allocate all quantities and fill their value based
9294 on the given memory map. See `update_from_memory`"""
@@ -100,7 +102,7 @@ def copy_memory(
100102 def move_memory (
101103 cls ,
102104 quantity_factory : QuantityFactory ,
103- memory_map : dict [ str , Any ] ,
105+ memory_map : StateMemoryMapping ,
104106 * ,
105107 check_shape_and_strides : bool = True ,
106108 ) -> Self :
@@ -147,10 +149,14 @@ class InnerA
147149
148150 def _update_from_memory_recursive (
149151 state : State ,
150- memory_map : dict [ str , dict | ArrayLike ] ,
152+ memory_map : StateMemoryMapping ,
151153 ):
152154 for name , array in memory_map .items ():
153- if isinstance (array , dict ):
155+ if array is None :
156+ raise TypeError (
157+ f"State memory copy: illegal copy from None for attribute { name } "
158+ )
159+ elif isinstance (array , dict ):
154160 _update_from_memory_recursive (state .__getattribute__ (name ), array )
155161 else :
156162 try :
@@ -165,7 +171,7 @@ def _update_from_memory_recursive(
165171
166172 def update_move_memory (
167173 self ,
168- memory_map : dict [ str , dict | ArrayLike ] ,
174+ memory_map : StateMemoryMapping ,
169175 * ,
170176 check_shape_and_strides : bool = True ,
171177 ) -> None :
@@ -204,11 +210,11 @@ class InnerA
204210 shape and strides as the original quantity
205211 """
206212
207- def _update_zero_copy_recursive (
208- state : State , memory_map : dict [str , dict | ArrayLike ]
209- ):
213+ def _update_zero_copy_recursive (state : State , memory_map : StateMemoryMapping ):
210214 for name , array in memory_map .items ():
211- if isinstance (array , dict ):
215+ if array is None :
216+ state .__setattr__ (name , None )
217+ elif isinstance (array , dict ):
212218 _update_zero_copy_recursive (state .__getattribute__ (name ), array )
213219 else :
214220 quantity = state .__getattribute__ (name )
@@ -227,8 +233,11 @@ def _update_zero_copy_recursive(
227233 f" Strides: { array .strides } != { quantity .data .strides } "
228234 )
229235 raise e
230-
231- quantity .data = array
236+ try :
237+ quantity .data = array
238+ except Exception as e :
239+ e .add_note (f" Error on { name } for { type (state )} " )
240+ raise e
232241
233242 _update_zero_copy_recursive (self , memory_map )
234243
0 commit comments